diff --git a/new-api/.dockerignore b/new-api/.dockerignore
deleted file mode 100644
index c6dfb9274e99ce46d2ba5947ff4ec64c237317cd..0000000000000000000000000000000000000000
--- a/new-api/.dockerignore
+++ /dev/null
@@ -1,8 +0,0 @@
-.github
-.git
-*.md
-.vscode
-.gitignore
-Makefile
-docs
-.eslintcache
\ No newline at end of file
diff --git a/new-api/.env.example b/new-api/.env.example
deleted file mode 100644
index 4332abc07e287e2a71cff95118e0eb274c89f130..0000000000000000000000000000000000000000
--- a/new-api/.env.example
+++ /dev/null
@@ -1,73 +0,0 @@
-# 端口号
-# PORT=3000
-# 前端基础URL
-# FRONTEND_BASE_URL=https://your-frontend-url.com
-
-
-# 调试相关配置
-# 启用pprof
-# ENABLE_PPROF=true
-# 启用调试模式
-# DEBUG=true
-
-# 数据库相关配置
-# 数据库连接字符串
-# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
-# 日志数据库连接字符串
-# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
-# SQLite数据库路径
-# SQLITE_PATH=/path/to/sqlite.db
-# 数据库最大空闲连接数
-# SQL_MAX_IDLE_CONNS=100
-# 数据库最大打开连接数
-# SQL_MAX_OPEN_CONNS=1000
-# 数据库连接最大生命周期(秒)
-# SQL_MAX_LIFETIME=60
-
-
-# 缓存相关配置
-# Redis连接字符串
-# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
-# 同步频率(单位:秒)
-# SYNC_FREQUENCY=60
-# 内存缓存启用
-# MEMORY_CACHE_ENABLED=true
-# 渠道更新频率(单位:秒)
-# CHANNEL_UPDATE_FREQUENCY=30
-# 批量更新启用
-# BATCH_UPDATE_ENABLED=true
-# 批量更新间隔(单位:秒)
-# BATCH_UPDATE_INTERVAL=5
-
-# 任务和功能配置
-# 更新任务启用
-# UPDATE_TASK=true
-
-# 对话超时设置
-# 所有请求超时时间,单位秒,默认为0,表示不限制
-# RELAY_TIMEOUT=0
-# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
-# STREAMING_TIMEOUT=300
-
-# Gemini 识别图片 最大图片数量
-# GEMINI_VISION_MAX_IMAGE_NUM=16
-
-# 会话密钥
-# SESSION_SECRET=random_string
-
-# 其他配置
-# 生成默认token
-# GENERATE_DEFAULT_TOKEN=false
-# Cohere 安全设置
-# COHERE_SAFETY_SETTING=NONE
-# 是否统计图片token
-# GET_MEDIA_TOKEN=true
-# 是否在非流(stream=false)情况下统计图片token
-# GET_MEDIA_TOKEN_NOT_STREAM=true
-# 设置 Dify 渠道是否输出工作流和节点信息到客户端
-# DIFY_DEBUG=true
-
-
-# 节点类型
-# 如果是主节点则为master
-# NODE_TYPE=master
diff --git a/new-api/.gitignore b/new-api/.gitignore
deleted file mode 100644
index 7fa0dcad62e2a51a1b897979f026a1c5612de4f2..0000000000000000000000000000000000000000
--- a/new-api/.gitignore
+++ /dev/null
@@ -1,14 +0,0 @@
-.idea
-.vscode
-upload
-*.exe
-*.db
-build
-*.db-journal
-logs
-web/dist
-.env
-one-api
-.DS_Store
-tiktoken_cache
-.eslintcache
\ No newline at end of file
diff --git a/new-api/Dockerfile b/new-api/Dockerfile
deleted file mode 100644
index c1d0ba7947a6b343b6222e33c98b11f2bd2d1123..0000000000000000000000000000000000000000
--- a/new-api/Dockerfile
+++ /dev/null
@@ -1,35 +0,0 @@
-FROM oven/bun:latest AS builder
-
-WORKDIR /build
-COPY web/package.json .
-COPY web/bun.lock .
-RUN bun install
-COPY ./web .
-COPY ./VERSION .
-RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
-
-FROM golang:alpine AS builder2
-
-ENV GO111MODULE=on \
- CGO_ENABLED=0 \
- GOOS=linux
-
-WORKDIR /build
-
-ADD go.mod go.sum ./
-RUN go mod download
-
-COPY . .
-COPY --from=builder /build/dist ./web/dist
-RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
-
-FROM alpine
-
-RUN apk upgrade --no-cache \
- && apk add --no-cache ca-certificates tzdata ffmpeg \
- && update-ca-certificates
-
-COPY --from=builder2 /build/one-api /
-EXPOSE 3000
-WORKDIR /data
-ENTRYPOINT ["/one-api"]
diff --git a/new-api/LICENSE b/new-api/LICENSE
deleted file mode 100644
index 43854ba4bfc206d66e5f3c61e698b94a1d7e4805..0000000000000000000000000000000000000000
--- a/new-api/LICENSE
+++ /dev/null
@@ -1,103 +0,0 @@
-# **New API 许可协议 (Licensing)**
-
-本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
-
-**核心原则:**
-
-- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
-- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
-
----
-
-## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
-
-- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
-- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
-- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
-- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
-
-## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
-
-在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
-
-- **场景一:移除品牌和版权信息**
- 您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
-
-- **场景二:规避 AGPLv3 开源义务**
- 您基于 New API 进行了修改,并希望:
- - 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
- - 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
-
-- **场景三:企业政策与集成需求**
- - 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
- - 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
-
-- **场景四:需要商业支持与保障**
- 您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
-
-**获取商业许可:**
-请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
-
-## **3. 贡献 (Contributions)**
-
-- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
-- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
-- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
-
-## **4. 其他条款 (Other Terms)**
-
-- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
-- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
-
----
-
-# **New API Licensing**
-
-This project uses a **Usage-Based Dual Licensing** model.
-
-**Core Principles:**
-
-- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
-- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
-
----
-
-## **1. Open Source License: AGPLv3 – For Basic Usage**
-
-- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
-- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
-- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
-- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
-
-## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
-
-You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
-
-- **Scenario 1: Removal of Branding and Copyright**
- You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
-
-- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
- You have modified New API and wish to:
- - Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
- - Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
-
-- **Scenario 3: Enterprise Policy & Integration Needs**
- - Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
- - You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
-
-- **Scenario 4: Commercial Support and Assurances**
- You require commercial assurances not provided by AGPLv3, such as official technical support.
-
-**Obtaining a Commercial License:**
-Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
-
-## **3. Contributions**
-
-- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
-- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
-- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
-
-## **4. Other Terms**
-
-- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
-- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
diff --git a/new-api/README.en.md b/new-api/README.en.md
deleted file mode 100644
index bc19966b688ecd1571a5a4f3b9b38b92efee51fa..0000000000000000000000000000000000000000
--- a/new-api/README.en.md
+++ /dev/null
@@ -1,216 +0,0 @@
-
- 中文 | English | Français
-
-
-
-
-
-# New API
-
-🍥 Next-Generation Large Model Gateway and AI Asset Management System
-
-

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-## 📝 Project Description
-
-> [!NOTE]
-> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
-
-> [!IMPORTANT]
-> - This project is for personal learning purposes only, with no guarantee of stability or technical support.
-> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
-> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
-
-🤝 Trusted Partners
-
-No particular order
-
-
-
-
-
-
-
-
-
-## 📚 Documentation
-
-For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
-
-You can also access the AI-generated DeepWiki:
-[](https://deepwiki.com/QuantumNous/new-api)
-
-## ✨ Key Features
-
-New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
-
-1. 🎨 Brand new UI interface
-2. 🌍 Multi-language support
-3. 💰 Online recharge functionality (YiPay)
-4. 🔍 Support for querying usage quotas with keys (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
-5. 🔄 Compatible with the original One API database
-6. 💵 Support for pay-per-use model pricing
-7. ⚖️ Support for weighted random channel selection
-8. 📈 Data dashboard (console)
-9. 🔒 Token grouping and model restrictions
-10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC)
-11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
-12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime)
-13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
-14. Support for entering chat interface via /chat2link route
-15. 🧠 Support for setting reasoning effort through model name suffixes:
- 1. OpenAI o-series models
- - Add `-high` suffix for high reasoning effort (e.g.: `o3-mini-high`)
- - Add `-medium` suffix for medium reasoning effort (e.g.: `o3-mini-medium`)
- - Add `-low` suffix for low reasoning effort (e.g.: `o3-mini-low`)
- 2. Claude thinking models
- - Add `-thinking` suffix to enable thinking mode (e.g.: `claude-3-7-sonnet-20250219-thinking`)
-16. 🔄 Thinking-to-content functionality
-17. 🔄 Model rate limiting for users
-18. 💰 Cache billing support, which allows billing at a set ratio when cache is hit:
- 1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings`
- 2. Set `Prompt Cache Ratio` in the channel, range 0-1, e.g., setting to 0.5 means billing at 50% when cache is hit
- 3. Supported channels:
- - [x] OpenAI
- - [x] Azure
- - [x] DeepSeek
- - [x] Claude
-
-## Model Support
-
-This version supports multiple models, please refer to [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details:
-
-1. Third-party models **gpts** (gpt-4-gizmo-*)
-2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image)
-3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music)
-4. Custom channels, supporting full call address input
-5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
-6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
-7. Dify, currently only supports chatflow
-
-## Environment Variable Configuration
-
-For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
-
-- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
-- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
-- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
-- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
-- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
-- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true`
-- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true`
-- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE`
-- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
-- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
-- `CRYPTO_SECRET`: Encryption key used for encrypting database content
-- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
-- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
-- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
-- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
-
-## Deployment
-
-For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation):
-
-> [!TIP]
-> Latest Docker image: `calciumion/new-api:latest`
-
-### Multi-machine Deployment Considerations
-- Environment variable `SESSION_SECRET` must be set, otherwise login status will be inconsistent across multiple machines
-- If sharing Redis, `CRYPTO_SECRET` must be set, otherwise Redis content cannot be accessed across multiple machines
-
-### Deployment Requirements
-- Local database (default): SQLite (Docker deployment must mount the `/data` directory)
-- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
-
-### Deployment Methods
-
-#### Using BaoTa Panel Docker Feature
-Install BaoTa Panel (version **9.2.0** or above), find **New-API** in the application store and install it.
-[Tutorial with images](./docs/BT.md)
-
-#### Using Docker Compose (Recommended)
-```shell
-# Download the project
-git clone https://github.com/Calcium-Ion/new-api.git
-cd new-api
-# Edit docker-compose.yml as needed
-# Start
-docker-compose up -d
-```
-
-#### Using Docker Image Directly
-```shell
-# Using SQLite
-docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-
-# Using MySQL
-docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-```
-
-## Channel Retry and Cache
-Channel retry functionality has been implemented, you can set the number of retries in `Settings->Operation Settings->General Settings`. It is **recommended to enable caching**.
-
-### Cache Configuration Method
-1. `REDIS_CONN_STRING`: Set Redis as cache
-2. `MEMORY_CACHE_ENABLED`: Enable memory cache (no need to set manually if Redis is set)
-
-## API Documentation
-
-For detailed API documentation, please refer to [API Documentation](https://docs.newapi.pro/api):
-
-- [Chat API](https://docs.newapi.pro/api/openai-chat)
-- [Image API](https://docs.newapi.pro/api/openai-image)
-- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
-- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
-- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)
-
-## Related Projects
-- [One API](https://github.com/songquanpeng/one-api): Original project
-- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
-- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation AI one-stop B/C-end solution
-- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota with key
-
-Other projects based on New API:
-- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API
-- [VoAPI](https://github.com/VoAPI/VoAPI): Frontend beautified version based on New API
-
-## Help and Support
-
-If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support):
-- [Community Interaction](https://docs.newapi.pro/support/community-interaction)
-- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
-- [FAQ](https://docs.newapi.pro/support/faq)
-
-## 🌟 Star History
-
-[](https://star-history.com/#Calcium-Ion/new-api&Date)
diff --git a/new-api/README.fr.md b/new-api/README.fr.md
deleted file mode 100644
index 9b800003c4a216c1700fc2acb9216039be881462..0000000000000000000000000000000000000000
--- a/new-api/README.fr.md
+++ /dev/null
@@ -1,216 +0,0 @@
-
- 中文 | English | Français
-
-
-
-
-
-# New API
-
-🍥 Passerelle de modèles étendus de nouvelle génération et système de gestion d'actifs d'IA
-
-

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-## 📝 Description du projet
-
-> [!NOTE]
-> Il s'agit d'un projet open-source développé sur la base de [One API](https://github.com/songquanpeng/one-api)
-
-> [!IMPORTANT]
-> - Ce projet est uniquement destiné à des fins d'apprentissage personnel, sans garantie de stabilité ni de support technique.
-> - Les utilisateurs doivent se conformer aux [Conditions d'utilisation](https://openai.com/policies/terms-of-use) d'OpenAI et aux **lois et réglementations applicables**, et ne doivent pas l'utiliser à des fins illégales.
-> - Conformément aux [《Mesures provisoires pour la gestion des services d'intelligence artificielle générative》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), veuillez ne fournir aucun service d'IA générative non enregistré au public en Chine.
-
-🤝 Partenaires de confiance
-
-Sans ordre particulier
-
-
-
-
-
-
-
-
-
-## 📚 Documentation
-
-Pour une documentation détaillée, veuillez consulter notre Wiki officiel : [https://docs.newapi.pro/](https://docs.newapi.pro/)
-
-Vous pouvez également accéder au DeepWiki généré par l'IA :
-[](https://deepwiki.com/QuantumNous/new-api)
-
-## ✨ Fonctionnalités clés
-
-New API offre un large éventail de fonctionnalités, veuillez vous référer à [Présentation des fonctionnalités](https://docs.newapi.pro/wiki/features-introduction) pour plus de détails :
-
-1. 🎨 Nouvelle interface utilisateur
-2. 🌍 Prise en charge multilingue
-3. 💰 Fonctionnalité de recharge en ligne (YiPay)
-4. 🔍 Prise en charge de la recherche de quotas d'utilisation avec des clés (fonctionne avec [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
-5. 🔄 Compatible avec la base de données originale de One API
-6. 💵 Prise en charge de la tarification des modèles de paiement à l'utilisation
-7. ⚖️ Prise en charge de la sélection aléatoire pondérée des canaux
-8. 📈 Tableau de bord des données (console)
-9. 🔒 Regroupement de jetons et restrictions de modèles
-10. 🤖 Prise en charge de plus de méthodes de connexion par autorisation (LinuxDO, Telegram, OIDC)
-11. 🔄 Prise en charge des modèles Rerank (Cohere et Jina), [Documentation de l'API](https://docs.newapi.pro/api/jinaai-rerank)
-12. ⚡ Prise en charge de l'API OpenAI Realtime (y compris les canaux Azure), [Documentation de l'API](https://docs.newapi.pro/api/openai-realtime)
-13. ⚡ Prise en charge du format Claude Messages, [Documentation de l'API](https://docs.newapi.pro/api/anthropic-chat)
-14. Prise en charge de l'accès à l'interface de discussion via la route /chat2link
-15. 🧠 Prise en charge de la définition de l'effort de raisonnement via les suffixes de nom de modèle :
- 1. Modèles de la série o d'OpenAI
- - Ajouter le suffixe `-high` pour un effort de raisonnement élevé (par exemple : `o3-mini-high`)
- - Ajouter le suffixe `-medium` pour un effort de raisonnement moyen (par exemple : `o3-mini-medium`)
- - Ajouter le suffixe `-low` pour un effort de raisonnement faible (par exemple : `o3-mini-low`)
- 2. Modèles de pensée de Claude
- - Ajouter le suffixe `-thinking` pour activer le mode de pensée (par exemple : `claude-3-7-sonnet-20250219-thinking`)
-16. 🔄 Fonctionnalité de la pensée au contenu
-17. 🔄 Limitation du débit du modèle pour les utilisateurs
-18. 💰 Prise en charge de la facturation du cache, qui permet de facturer à un ratio défini lorsque le cache est atteint :
- 1. Définir l'option `Ratio de cache d'invite` dans `Paramètres système->Paramètres de fonctionnement`
- 2. Définir le `Ratio de cache d'invite` dans le canal, plage de 0 à 1, par exemple, le définir sur 0,5 signifie facturer à 50 % lorsque le cache est atteint
- 3. Canaux pris en charge :
- - [x] OpenAI
- - [x] Azure
- - [x] DeepSeek
- - [x] Claude
-
-## Prise en charge des modèles
-
-Cette version prend en charge plusieurs modèles, veuillez vous référer à [Documentation de l'API-Interface de relais](https://docs.newapi.pro/api) pour plus de détails :
-
-1. Modèles tiers **gpts** (gpt-4-gizmo-*)
-2. Canal tiers [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy), [Documentation de l'API](https://docs.newapi.pro/api/midjourney-proxy-image)
-3. Canal tiers [Suno API](https://github.com/Suno-API/Suno-API), [Documentation de l'API](https://docs.newapi.pro/api/suno-music)
-4. Canaux personnalisés, prenant en charge la saisie complète de l'adresse d'appel
-5. Modèles Rerank ([Cohere](https://cohere.ai/) et [Jina](https://jina.ai/)), [Documentation de l'API](https://docs.newapi.pro/api/jinaai-rerank)
-6. Format de messages Claude, [Documentation de l'API](https://docs.newapi.pro/api/anthropic-chat)
-7. Dify, ne prend actuellement en charge que chatflow
-
-## Configuration des variables d'environnement
-
-Pour des instructions de configuration détaillées, veuillez vous référer à [Guide d'installation-Configuration des variables d'environnement](https://docs.newapi.pro/installation/environment-variables) :
-
-- `GENERATE_DEFAULT_TOKEN` : S'il faut générer des jetons initiaux pour les utilisateurs nouvellement enregistrés, la valeur par défaut est `false`
-- `STREAMING_TIMEOUT` : Délai d'expiration de la réponse en streaming, la valeur par défaut est de 300 secondes
-- `DIFY_DEBUG` : S'il faut afficher les informations sur le flux de travail et les nœuds pour les canaux Dify, la valeur par défaut est `true`
-- `FORCE_STREAM_OPTION` : S'il faut remplacer le paramètre client stream_options, la valeur par défaut est `true`
-- `GET_MEDIA_TOKEN` : S'il faut compter les jetons d'image, la valeur par défaut est `true`
-- `GET_MEDIA_TOKEN_NOT_STREAM` : S'il faut compter les jetons d'image dans les cas sans streaming, la valeur par défaut est `true`
-- `UPDATE_TASK` : S'il faut mettre à jour les tâches asynchrones (Midjourney, Suno), la valeur par défaut est `true`
-- `COHERE_SAFETY_SETTING` : Paramètres de sécurité du modèle Cohere, les options sont `NONE`, `CONTEXTUAL`, `STRICT`, la valeur par défaut est `NONE`
-- `GEMINI_VISION_MAX_IMAGE_NUM` : Nombre maximum d'images pour les modèles Gemini, la valeur par défaut est `16`
-- `MAX_FILE_DOWNLOAD_MB` : Taille maximale de téléchargement de fichier en Mo, la valeur par défaut est `20`
-- `CRYPTO_SECRET` : Clé de chiffrement utilisée pour chiffrer le contenu de la base de données
-- `AZURE_DEFAULT_API_VERSION` : Version de l'API par défaut du canal Azure, la valeur par défaut est `2025-04-01-preview`
-- `NOTIFICATION_LIMIT_DURATION_MINUTE` : Durée de la limite de notification, la valeur par défaut est de `10` minutes
-- `NOTIFY_LIMIT_COUNT` : Nombre maximal de notifications utilisateur dans la durée spécifiée, la valeur par défaut est `2`
-- `ERROR_LOG_ENABLED=true` : S'il faut enregistrer et afficher les journaux d'erreurs, la valeur par défaut est `false`
-
-## Déploiement
-
-Pour des guides de déploiement détaillés, veuillez vous référer à [Guide d'installation-Méthodes de déploiement](https://docs.newapi.pro/installation) :
-
-> [!TIP]
-> Dernière image Docker : `calciumion/new-api:latest`
-
-### Considérations sur le déploiement multi-machines
-- La variable d'environnement `SESSION_SECRET` doit être définie, sinon l'état de connexion sera incohérent sur plusieurs machines
-- Si vous partagez Redis, `CRYPTO_SECRET` doit être défini, sinon le contenu de Redis ne pourra pas être consulté sur plusieurs machines
-
-### Exigences de déploiement
-- Base de données locale (par défaut) : SQLite (le déploiement Docker doit monter le répertoire `/data`)
-- Base de données distante : MySQL version >= 5.7.8, PgSQL version >= 9.6
-
-### Méthodes de déploiement
-
-#### Utilisation de la fonctionnalité Docker du panneau BaoTa
-Installez le panneau BaoTa (version **9.2.0** ou supérieure), recherchez **New-API** dans le magasin d'applications et installez-le.
-[Tutoriel avec des images](./docs/BT.md)
-
-#### Utilisation de Docker Compose (recommandé)
-```shell
-# Télécharger le projet
-git clone https://github.com/Calcium-Ion/new-api.git
-cd new-api
-# Modifier docker-compose.yml si nécessaire
-# Démarrer
-docker-compose up -d
-```
-
-#### Utilisation directe de l'image Docker
-```shell
-# Utilisation de SQLite
-docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-
-# Utilisation de MySQL
-docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-```
-
-## Nouvelle tentative de canal et cache
-La fonctionnalité de nouvelle tentative de canal a été implémentée, vous pouvez définir le nombre de tentatives dans `Paramètres->Paramètres de fonctionnement->Paramètres généraux`. Il est **recommandé d'activer la mise en cache**.
-
-### Méthode de configuration du cache
-1. `REDIS_CONN_STRING` : Définir Redis comme cache
-2. `MEMORY_CACHE_ENABLED` : Activer le cache mémoire (pas besoin de le définir manuellement si Redis est défini)
-
-## Documentation de l'API
-
-Pour une documentation détaillée de l'API, veuillez vous référer à [Documentation de l'API](https://docs.newapi.pro/api) :
-
-- [API de discussion](https://docs.newapi.pro/api/openai-chat)
-- [API d'image](https://docs.newapi.pro/api/openai-image)
-- [API de rerank](https://docs.newapi.pro/api/jinaai-rerank)
-- [API en temps réel](https://docs.newapi.pro/api/openai-realtime)
-- [API de discussion Claude (messages)](https://docs.newapi.pro/api/anthropic-chat)
-
-## Projets connexes
-- [One API](https://github.com/songquanpeng/one-api) : Projet original
-- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy) : Prise en charge de l'interface Midjourney
-- [chatnio](https://github.com/Deeptrain-Community/chatnio) : Solution B/C unique d'IA de nouvelle génération
-- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool) : Interroger le quota d'utilisation avec une clé
-
-Autres projets basés sur New API :
-- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon) : Version optimisée hautes performances de New API
-- [VoAPI](https://github.com/VoAPI/VoAPI) : Version embellie du frontend basée sur New API
-
-## Aide et support
-
-Si vous avez des questions, veuillez vous référer à [Aide et support](https://docs.newapi.pro/support) :
-- [Interaction avec la communauté](https://docs.newapi.pro/support/community-interaction)
-- [Commentaires sur les problèmes](https://docs.newapi.pro/support/feedback-issues)
-- [FAQ](https://docs.newapi.pro/support/faq)
-
-## 🌟 Historique des étoiles
-
-[](https://star-history.com/#Calcium-Ion/new-api&Date)
\ No newline at end of file
diff --git a/new-api/README.md b/new-api/README.md
deleted file mode 100644
index 8198d5da32ebd36a72fc41c8d771cab11bd6f32c..0000000000000000000000000000000000000000
--- a/new-api/README.md
+++ /dev/null
@@ -1,219 +0,0 @@
-
- 中文 | English | Français
-
-
-
-
-
-# New API
-
-🍥新一代大模型网关与AI资产管理系统
-
-

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-## 📝 项目说明
-
-> [!NOTE]
-> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
-
-> [!IMPORTANT]
-> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
-> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
-> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
-
-🤝 我们信任的合作伙伴
-
-排名不分先后
-
-
-
-
-
-
-
-
-
-## 📚 文档
-
-详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
-
-也可访问AI生成的DeepWiki:
-[](https://deepwiki.com/QuantumNous/new-api)
-
-## ✨ 主要特性
-
-New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
-
-1. 🎨 全新的UI界面
-2. 🌍 多语言支持
-3. 💰 支持在线充值功能(易支付)
-4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
-5. 🔄 兼容原版One API的数据库
-6. 💵 支持模型按次数收费
-7. ⚖️ 支持渠道加权随机
-8. 📈 数据看板(控制台)
-9. 🔒 令牌分组、模型限制
-10. 🤖 支持更多授权登陆方式(LinuxDO,Telegram、OIDC)
-11. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
-12. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime)
-13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
-14. 支持使用路由/chat2link进入聊天界面
-15. 🧠 支持通过模型名称后缀设置 reasoning effort:
- 1. OpenAI o系列模型
- - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- - 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
- 2. Claude 思考模型
- - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
-16. 🔄 思考转内容功能
-17. 🔄 针对用户的模型限流功能
-18. 🔄 请求格式转换功能,支持以下三种格式转换:
- 1. OpenAI Chat Completions => Claude Messages
- 2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型)
- 3. OpenAI Chat Completions => Gemini Chat
-19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
- 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
- 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
- 3. 支持的渠道:
- - [x] OpenAI
- - [x] Azure
- - [x] DeepSeek
- - [x] Claude
-
-## 模型支持
-
-此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api):
-
-1. 第三方模型 **gpts** (gpt-4-gizmo-*)
-2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image)
-3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
-4. 自定义渠道,支持填入完整调用地址
-5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
-6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
-7. Dify,当前仅支持chatflow
-
-## 环境变量配置
-
-详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
-
-- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
-- `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
-- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
-- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
-- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
-- `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true`
-- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true`
-- `COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
-- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
-- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
-- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
-- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
-- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
-- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
-- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
-
-## 部署
-
-详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation):
-
-> [!TIP]
-> 最新版Docker镜像:`calciumion/new-api:latest`
-
-### 多机部署注意事项
-- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
-- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取
-
-### 部署要求
-- 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录)
-- 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6
-
-### 部署方式
-
-#### 使用宝塔面板Docker功能部署
-安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
-[图文教程](./docs/BT.md)
-
-#### 使用Docker Compose部署(推荐)
-```shell
-# 下载项目
-git clone https://github.com/Calcium-Ion/new-api.git
-cd new-api
-# 按需编辑docker-compose.yml
-# 启动
-docker-compose up -d
-```
-
-#### 直接使用Docker镜像
-```shell
-# 使用SQLite
-docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-
-# 使用MySQL
-docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
-```
-
-## 渠道重试与缓存
-渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
-
-### 缓存设置方法
-1. `REDIS_CONN_STRING`:设置Redis作为缓存
-2. `MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置)
-
-## 接口文档
-
-详细接口文档请参考[接口文档](https://docs.newapi.pro/api):
-
-- [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat)
-- [图像接口(Image)](https://docs.newapi.pro/api/openai-image)
-- [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank)
-- [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime)
-- [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat)
-
-## 相关项目
-- [One API](https://github.com/songquanpeng/one-api):原版项目
-- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
-- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案
-- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
-
-其他基于New API的项目:
-- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
-
-## 帮助支持
-
-如有问题,请参考[帮助支持](https://docs.newapi.pro/support):
-- [社区交流](https://docs.newapi.pro/support/community-interaction)
-- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
-- [常见问题](https://docs.newapi.pro/support/faq)
-
-## 🌟 Star History
-
-[](https://star-history.com/#Calcium-Ion/new-api&Date)
diff --git a/new-api/VERSION b/new-api/VERSION
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/new-api/bin/migration_v0.2-v0.3.sql b/new-api/bin/migration_v0.2-v0.3.sql
deleted file mode 100644
index 5b18e72b4b878733c6af53a37683c8f898719564..0000000000000000000000000000000000000000
--- a/new-api/bin/migration_v0.2-v0.3.sql
+++ /dev/null
@@ -1,6 +0,0 @@
-UPDATE users
-SET quota = quota + (
- SELECT SUM(remain_quota)
- FROM tokens
- WHERE tokens.user_id = users.id
-)
diff --git a/new-api/bin/migration_v0.3-v0.4.sql b/new-api/bin/migration_v0.3-v0.4.sql
deleted file mode 100644
index 3c9893410c9fffcd381373539e589dbfc32f1158..0000000000000000000000000000000000000000
--- a/new-api/bin/migration_v0.3-v0.4.sql
+++ /dev/null
@@ -1,17 +0,0 @@
-INSERT INTO abilities (`group`, model, channel_id, enabled)
-SELECT c.`group`, m.model, c.id, 1
-FROM channels c
-CROSS JOIN (
- SELECT 'gpt-3.5-turbo' AS model UNION ALL
- SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
- SELECT 'gpt-4' AS model UNION ALL
- SELECT 'gpt-4-0314' AS model
-) AS m
-WHERE c.status = 1
- AND NOT EXISTS (
- SELECT 1
- FROM abilities a
- WHERE a.`group` = c.`group`
- AND a.model = m.model
- AND a.channel_id = c.id
-);
diff --git a/new-api/bin/time_test.sh b/new-api/bin/time_test.sh
deleted file mode 100644
index af5d23bce8db372f00bae4c450c260fed9506ae8..0000000000000000000000000000000000000000
--- a/new-api/bin/time_test.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/bin/bash
-
-if [ $# -lt 3 ]; then
- echo "Usage: time_test.sh []"
- exit 1
-fi
-
-domain=$1
-key=$2
-count=$3
-model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo
-
-total_time=0
-times=()
-
-for ((i=1; i<=count; i++)); do
- result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
- https://"$domain"/v1/chat/completions \
- -H "Content-Type: application/json" \
- -H "Authorization: Bearer $key" \
- -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
- http_code=$(echo "$result" | awk '{print $1}')
- time=$(echo "$result" | awk '{print $2}')
- echo "HTTP status code: $http_code, Time taken: $time"
- total_time=$(bc <<< "$total_time + $time")
- times+=("$time")
-done
-
-average_time=$(echo "scale=4; $total_time / $count" | bc)
-
-sum_of_squares=0
-for time in "${times[@]}"; do
- difference=$(echo "scale=4; $time - $average_time" | bc)
- square=$(echo "scale=4; $difference * $difference" | bc)
- sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
-done
-
-standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)
-
-echo "Average time: $average_time±$standard_deviation"
diff --git a/new-api/common/api_type.go b/new-api/common/api_type.go
deleted file mode 100644
index 89ea5d76c578eb80676f4e29069e0c5dea4d7466..0000000000000000000000000000000000000000
--- a/new-api/common/api_type.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package common
-
-import "one-api/constant"
-
-func ChannelType2APIType(channelType int) (int, bool) {
- apiType := -1
- switch channelType {
- case constant.ChannelTypeOpenAI:
- apiType = constant.APITypeOpenAI
- case constant.ChannelTypeAnthropic:
- apiType = constant.APITypeAnthropic
- case constant.ChannelTypeBaidu:
- apiType = constant.APITypeBaidu
- case constant.ChannelTypePaLM:
- apiType = constant.APITypePaLM
- case constant.ChannelTypeZhipu:
- apiType = constant.APITypeZhipu
- case constant.ChannelTypeAli:
- apiType = constant.APITypeAli
- case constant.ChannelTypeXunfei:
- apiType = constant.APITypeXunfei
- case constant.ChannelTypeAIProxyLibrary:
- apiType = constant.APITypeAIProxyLibrary
- case constant.ChannelTypeTencent:
- apiType = constant.APITypeTencent
- case constant.ChannelTypeGemini:
- apiType = constant.APITypeGemini
- case constant.ChannelTypeZhipu_v4:
- apiType = constant.APITypeZhipuV4
- case constant.ChannelTypeOllama:
- apiType = constant.APITypeOllama
- case constant.ChannelTypePerplexity:
- apiType = constant.APITypePerplexity
- case constant.ChannelTypeAws:
- apiType = constant.APITypeAws
- case constant.ChannelTypeCohere:
- apiType = constant.APITypeCohere
- case constant.ChannelTypeDify:
- apiType = constant.APITypeDify
- case constant.ChannelTypeJina:
- apiType = constant.APITypeJina
- case constant.ChannelCloudflare:
- apiType = constant.APITypeCloudflare
- case constant.ChannelTypeSiliconFlow:
- apiType = constant.APITypeSiliconFlow
- case constant.ChannelTypeVertexAi:
- apiType = constant.APITypeVertexAi
- case constant.ChannelTypeMistral:
- apiType = constant.APITypeMistral
- case constant.ChannelTypeDeepSeek:
- apiType = constant.APITypeDeepSeek
- case constant.ChannelTypeMokaAI:
- apiType = constant.APITypeMokaAI
- case constant.ChannelTypeVolcEngine:
- apiType = constant.APITypeVolcEngine
- case constant.ChannelTypeBaiduV2:
- apiType = constant.APITypeBaiduV2
- case constant.ChannelTypeOpenRouter:
- apiType = constant.APITypeOpenRouter
- case constant.ChannelTypeXinference:
- apiType = constant.APITypeXinference
- case constant.ChannelTypeXai:
- apiType = constant.APITypeXai
- case constant.ChannelTypeCoze:
- apiType = constant.APITypeCoze
- case constant.ChannelTypeJimeng:
- apiType = constant.APITypeJimeng
- case constant.ChannelTypeMoonshot:
- apiType = constant.APITypeMoonshot
- case constant.ChannelTypeSubmodel:
- apiType = constant.APITypeSubmodel
- }
- if apiType == -1 {
- return constant.APITypeOpenAI, false
- }
- return apiType, true
-}
diff --git a/new-api/common/constants.go b/new-api/common/constants.go
deleted file mode 100644
index 7d45cb891b02309809b0775f440ea3c130145de7..0000000000000000000000000000000000000000
--- a/new-api/common/constants.go
+++ /dev/null
@@ -1,202 +0,0 @@
-package common
-
-import (
- //"os"
- //"strconv"
- "sync"
- "time"
-
- "github.com/google/uuid"
-)
-
-var StartTime = time.Now().Unix() // unit: second
-var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
-var SystemName = "New API"
-var Footer = ""
-var Logo = ""
-var TopUpLink = ""
-
-// var ChatLink = ""
-// var ChatLink2 = ""
-var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
-var DisplayInCurrencyEnabled = true
-var DisplayTokenStatEnabled = true
-var DrawingEnabled = true
-var TaskEnabled = true
-var DataExportEnabled = true
-var DataExportInterval = 5 // unit: minute
-var DataExportDefaultTime = "hour" // unit: minute
-var DefaultCollapseSidebar = false // default value of collapse sidebar
-
-// Any options with "Secret", "Token" in its key won't be return by GetOptions
-
-var SessionSecret = uuid.New().String()
-var CryptoSecret = uuid.New().String()
-
-var OptionMap map[string]string
-var OptionMapRWMutex sync.RWMutex
-
-var ItemsPerPage = 10
-var MaxRecentItems = 100
-
-var PasswordLoginEnabled = true
-var PasswordRegisterEnabled = true
-var EmailVerificationEnabled = false
-var GitHubOAuthEnabled = false
-var LinuxDOOAuthEnabled = false
-var WeChatAuthEnabled = false
-var TelegramOAuthEnabled = false
-var TurnstileCheckEnabled = false
-var RegisterEnabled = true
-
-var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
-var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
-var EmailDomainWhitelist = []string{
- "gmail.com",
- "163.com",
- "126.com",
- "qq.com",
- "outlook.com",
- "hotmail.com",
- "icloud.com",
- "yahoo.com",
- "foxmail.com",
-}
-var EmailLoginAuthServerList = []string{
- "smtp.sendcloud.net",
- "smtp.azurecomm.net",
-}
-
-var DebugEnabled bool
-var MemoryCacheEnabled bool
-
-var LogConsumeEnabled = true
-
-var SMTPServer = ""
-var SMTPPort = 587
-var SMTPSSLEnabled = false
-var SMTPAccount = ""
-var SMTPFrom = ""
-var SMTPToken = ""
-
-var GitHubClientId = ""
-var GitHubClientSecret = ""
-var LinuxDOClientId = ""
-var LinuxDOClientSecret = ""
-var LinuxDOMinimumTrustLevel = 0
-
-var WeChatServerAddress = ""
-var WeChatServerToken = ""
-var WeChatAccountQRCodeImageURL = ""
-
-var TurnstileSiteKey = ""
-var TurnstileSecretKey = ""
-
-var TelegramBotToken = ""
-var TelegramBotName = ""
-
-var QuotaForNewUser = 0
-var QuotaForInviter = 0
-var QuotaForInvitee = 0
-var ChannelDisableThreshold = 5.0
-var AutomaticDisableChannelEnabled = false
-var AutomaticEnableChannelEnabled = false
-var QuotaRemindThreshold = 1000
-var PreConsumedQuota = 500
-
-var RetryTimes = 0
-
-//var RootUserEmail = ""
-
-var IsMasterNode bool
-
-var requestInterval int
-var RequestInterval time.Duration
-
-var SyncFrequency int // unit is second
-
-var BatchUpdateEnabled = false
-var BatchUpdateInterval int
-
-var RelayTimeout int // unit is second
-
-var GeminiSafetySetting string
-
-// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
-var CohereSafetySetting string
-
-const (
- RequestIdKey = "X-Oneapi-Request-Id"
-)
-
-const (
- RoleGuestUser = 0
- RoleCommonUser = 1
- RoleAdminUser = 10
- RoleRootUser = 100
-)
-
-func IsValidateRole(role int) bool {
- return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
-}
-
-var (
- FileUploadPermission = RoleGuestUser
- FileDownloadPermission = RoleGuestUser
- ImageUploadPermission = RoleGuestUser
- ImageDownloadPermission = RoleGuestUser
-)
-
-// All duration's unit is seconds
-// Shouldn't larger then RateLimitKeyExpirationDuration
-var (
- GlobalApiRateLimitEnable bool
- GlobalApiRateLimitNum int
- GlobalApiRateLimitDuration int64
-
- GlobalWebRateLimitEnable bool
- GlobalWebRateLimitNum int
- GlobalWebRateLimitDuration int64
-
- UploadRateLimitNum = 10
- UploadRateLimitDuration int64 = 60
-
- DownloadRateLimitNum = 10
- DownloadRateLimitDuration int64 = 60
-
- CriticalRateLimitNum = 20
- CriticalRateLimitDuration int64 = 20 * 60
-)
-
-var RateLimitKeyExpirationDuration = 20 * time.Minute
-
-const (
- UserStatusEnabled = 1 // don't use 0, 0 is the default value!
- UserStatusDisabled = 2 // also don't use 0
-)
-
-const (
- TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
- TokenStatusDisabled = 2 // also don't use 0
- TokenStatusExpired = 3
- TokenStatusExhausted = 4
-)
-
-const (
- RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
- RedemptionCodeStatusDisabled = 2 // also don't use 0
- RedemptionCodeStatusUsed = 3 // also don't use 0
-)
-
-const (
- ChannelStatusUnknown = 0
- ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
- ChannelStatusManuallyDisabled = 2 // also don't use 0
- ChannelStatusAutoDisabled = 3
-)
-
-const (
- TopUpStatusPending = "pending"
- TopUpStatusSuccess = "success"
- TopUpStatusExpired = "expired"
-)
diff --git a/new-api/common/copy.go b/new-api/common/copy.go
deleted file mode 100644
index a7bfa760dfcd4cd37942059bbc577524c02f0d10..0000000000000000000000000000000000000000
--- a/new-api/common/copy.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package common
-
-import (
- "fmt"
-
- "github.com/jinzhu/copier"
-)
-
-func DeepCopy[T any](src *T) (*T, error) {
- if src == nil {
- return nil, fmt.Errorf("copy source cannot be nil")
- }
- var dst T
- err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
- if err != nil {
- return nil, err
- }
- return &dst, nil
-}
diff --git a/new-api/common/crypto.go b/new-api/common/crypto.go
deleted file mode 100644
index a4f3028a598e7e135f3afc685a58d6a370b0210b..0000000000000000000000000000000000000000
--- a/new-api/common/crypto.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package common
-
-import (
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "golang.org/x/crypto/bcrypt"
-)
-
-func GenerateHMACWithKey(key []byte, data string) string {
- h := hmac.New(sha256.New, key)
- h.Write([]byte(data))
- return hex.EncodeToString(h.Sum(nil))
-}
-
-func GenerateHMAC(data string) string {
- h := hmac.New(sha256.New, []byte(CryptoSecret))
- h.Write([]byte(data))
- return hex.EncodeToString(h.Sum(nil))
-}
-
-func Password2Hash(password string) (string, error) {
- passwordBytes := []byte(password)
- hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
- return string(hashedPassword), err
-}
-
-func ValidatePasswordAndHash(password string, hash string) bool {
- err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
- return err == nil
-}
diff --git a/new-api/common/custom-event.go b/new-api/common/custom-event.go
deleted file mode 100644
index 976e29969f3c735ea4988f4338145f467cf9b109..0000000000000000000000000000000000000000
--- a/new-api/common/custom-event.go
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
-// Use of this source code is governed by a MIT style
-// license that can be found in the LICENSE file.
-
-package common
-
-import (
- "fmt"
- "io"
- "net/http"
- "strings"
- "sync"
-)
-
-type stringWriter interface {
- io.Writer
- writeString(string) (int, error)
-}
-
-type stringWrapper struct {
- io.Writer
-}
-
-func (w stringWrapper) writeString(str string) (int, error) {
- return w.Writer.Write([]byte(str))
-}
-
-func checkWriter(writer io.Writer) stringWriter {
- if w, ok := writer.(stringWriter); ok {
- return w
- } else {
- return stringWrapper{writer}
- }
-}
-
-// Server-Sent Events
-// W3C Working Draft 29 October 2009
-// http://www.w3.org/TR/2009/WD-eventsource-20091029/
-
-var contentType = []string{"text/event-stream"}
-var noCache = []string{"no-cache"}
-
-var fieldReplacer = strings.NewReplacer(
- "\n", "\\n",
- "\r", "\\r")
-
-var dataReplacer = strings.NewReplacer(
- "\n", "\n",
- "\r", "\\r")
-
-type CustomEvent struct {
- Event string
- Id string
- Retry uint
- Data interface{}
-
- Mutex sync.Mutex
-}
-
-func encode(writer io.Writer, event CustomEvent) error {
- w := checkWriter(writer)
- return writeData(w, event.Data)
-}
-
-func writeData(w stringWriter, data interface{}) error {
- dataReplacer.WriteString(w, fmt.Sprint(data))
- if strings.HasPrefix(data.(string), "data") {
- w.writeString("\n\n")
- }
- return nil
-}
-
-func (r CustomEvent) Render(w http.ResponseWriter) error {
- r.WriteContentType(w)
- return encode(w, r)
-}
-
-func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
- r.Mutex.Lock()
- defer r.Mutex.Unlock()
- header := w.Header()
- header["Content-Type"] = contentType
-
- if _, exist := header["Cache-Control"]; !exist {
- header["Cache-Control"] = noCache
- }
-}
diff --git a/new-api/common/database.go b/new-api/common/database.go
deleted file mode 100644
index 2fad5a2ca03bc4bb69eb05a98983ad6bfdde50fd..0000000000000000000000000000000000000000
--- a/new-api/common/database.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package common
-
-const (
- DatabaseTypeMySQL = "mysql"
- DatabaseTypeSQLite = "sqlite"
- DatabaseTypePostgreSQL = "postgres"
-)
-
-var UsingSQLite = false
-var UsingPostgreSQL = false
-var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
-var UsingMySQL = false
-var UsingClickHouse = false
-
-var SQLitePath = "one-api.db?_busy_timeout=30000"
\ No newline at end of file
diff --git a/new-api/common/email-outlook-auth.go b/new-api/common/email-outlook-auth.go
deleted file mode 100644
index 070da21105a791f7ef470c50dc689039908218fa..0000000000000000000000000000000000000000
--- a/new-api/common/email-outlook-auth.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package common
-
-import (
- "errors"
- "net/smtp"
- "strings"
-)
-
-type outlookAuth struct {
- username, password string
-}
-
-func LoginAuth(username, password string) smtp.Auth {
- return &outlookAuth{username, password}
-}
-
-func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
- return "LOGIN", []byte{}, nil
-}
-
-func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
- if more {
- switch string(fromServer) {
- case "Username:":
- return []byte(a.username), nil
- case "Password:":
- return []byte(a.password), nil
- default:
- return nil, errors.New("unknown fromServer")
- }
- }
- return nil, nil
-}
-
-func isOutlookServer(server string) bool {
- // 兼容多地区的outlook邮箱和ofb邮箱
- // 其实应该加一个Option来区分是否用LOGIN的方式登录
- // 先临时兼容一下
- return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
-}
diff --git a/new-api/common/email.go b/new-api/common/email.go
deleted file mode 100644
index 1aab000c422bdc16e85641e48e306a5e4319279c..0000000000000000000000000000000000000000
--- a/new-api/common/email.go
+++ /dev/null
@@ -1,90 +0,0 @@
-package common
-
-import (
- "crypto/tls"
- "encoding/base64"
- "fmt"
- "net/smtp"
- "slices"
- "strings"
- "time"
-)
-
-func generateMessageID() (string, error) {
- split := strings.Split(SMTPFrom, "@")
- if len(split) < 2 {
- return "", fmt.Errorf("invalid SMTP account")
- }
- domain := strings.Split(SMTPFrom, "@")[1]
- return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
-}
-
-func SendEmail(subject string, receiver string, content string) error {
- if SMTPFrom == "" { // for compatibility
- SMTPFrom = SMTPAccount
- }
- id, err2 := generateMessageID()
- if err2 != nil {
- return err2
- }
- if SMTPServer == "" && SMTPAccount == "" {
- return fmt.Errorf("SMTP 服务器未配置")
- }
- encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
- mail := []byte(fmt.Sprintf("To: %s\r\n"+
- "From: %s<%s>\r\n"+
- "Subject: %s\r\n"+
- "Date: %s\r\n"+
- "Message-ID: %s\r\n"+ // 添加 Message-ID 头
- "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
- receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
- auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
- addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
- to := strings.Split(receiver, ";")
- var err error
- if SMTPPort == 465 || SMTPSSLEnabled {
- tlsConfig := &tls.Config{
- InsecureSkipVerify: true,
- ServerName: SMTPServer,
- }
- conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
- if err != nil {
- return err
- }
- client, err := smtp.NewClient(conn, SMTPServer)
- if err != nil {
- return err
- }
- defer client.Close()
- if err = client.Auth(auth); err != nil {
- return err
- }
- if err = client.Mail(SMTPFrom); err != nil {
- return err
- }
- receiverEmails := strings.Split(receiver, ";")
- for _, receiver := range receiverEmails {
- if err = client.Rcpt(receiver); err != nil {
- return err
- }
- }
- w, err := client.Data()
- if err != nil {
- return err
- }
- _, err = w.Write(mail)
- if err != nil {
- return err
- }
- err = w.Close()
- if err != nil {
- return err
- }
- } else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
- auth = LoginAuth(SMTPAccount, SMTPToken)
- err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
- } else {
- err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
- }
- return err
-}
diff --git a/new-api/common/embed-file-system.go b/new-api/common/embed-file-system.go
deleted file mode 100644
index bf2247e1dcc6b28176659c0aea0d3a88c74c6039..0000000000000000000000000000000000000000
--- a/new-api/common/embed-file-system.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package common
-
-import (
- "embed"
- "github.com/gin-contrib/static"
- "io/fs"
- "net/http"
-)
-
-// Credit: https://github.com/gin-contrib/static/issues/19
-
-type embedFileSystem struct {
- http.FileSystem
-}
-
-func (e embedFileSystem) Exists(prefix string, path string) bool {
- _, err := e.Open(path)
- if err != nil {
- return false
- }
- return true
-}
-
-func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
- efs, err := fs.Sub(fsEmbed, targetPath)
- if err != nil {
- panic(err)
- }
- return embedFileSystem{
- FileSystem: http.FS(efs),
- }
-}
diff --git a/new-api/common/endpoint_defaults.go b/new-api/common/endpoint_defaults.go
deleted file mode 100644
index 8915d23a129f53bdf0295eb1cbbd5fcd63f1b54d..0000000000000000000000000000000000000000
--- a/new-api/common/endpoint_defaults.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package common
-
-import "one-api/constant"
-
-// EndpointInfo 描述单个端点的默认请求信息
-// path: 上游路径
-// method: HTTP 请求方式,例如 POST/GET
-// 目前均为 POST,后续可扩展
-//
-// json 标签用于直接序列化到 API 输出
-// 例如:{"path":"/v1/chat/completions","method":"POST"}
-
-type EndpointInfo struct {
- Path string `json:"path"`
- Method string `json:"method"`
-}
-
-// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
-var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
- constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
- constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
- constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
- constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
- constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
- constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
- constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"},
-}
-
-// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
-func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
- info, ok := defaultEndpointInfoMap[et]
- return info, ok
-}
diff --git a/new-api/common/endpoint_type.go b/new-api/common/endpoint_type.go
deleted file mode 100644
index d473ac7b569e22f1e48c3cd8306778b7aa3063d8..0000000000000000000000000000000000000000
--- a/new-api/common/endpoint_type.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package common
-
-import "one-api/constant"
-
-// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
-func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
- var endpointTypes []constant.EndpointType
- switch channelType {
- case constant.ChannelTypeJina:
- endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
- //case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
- // endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
- //case constant.ChannelTypeSunoAPI:
- // endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
- //case constant.ChannelTypeKling:
- // endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
- //case constant.ChannelTypeJimeng:
- // endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
- case constant.ChannelTypeAws:
- fallthrough
- case constant.ChannelTypeAnthropic:
- endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
- case constant.ChannelTypeVertexAi:
- fallthrough
- case constant.ChannelTypeGemini:
- endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
- case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
- endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
- default:
- if IsOpenAIResponseOnlyModel(modelName) {
- endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
- } else {
- endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
- }
- }
- if IsImageGenerationModel(modelName) {
- // add to first
- endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
- }
- return endpointTypes
-}
diff --git a/new-api/common/env.go b/new-api/common/env.go
deleted file mode 100644
index a22344568f2368668d88affc39c287117c5fa264..0000000000000000000000000000000000000000
--- a/new-api/common/env.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package common
-
-import (
- "fmt"
- "os"
- "strconv"
-)
-
-func GetEnvOrDefault(env string, defaultValue int) int {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- num, err := strconv.Atoi(os.Getenv(env))
- if err != nil {
- SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
- return defaultValue
- }
- return num
-}
-
-func GetEnvOrDefaultString(env string, defaultValue string) string {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- return os.Getenv(env)
-}
-
-func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
- if env == "" || os.Getenv(env) == "" {
- return defaultValue
- }
- b, err := strconv.ParseBool(os.Getenv(env))
- if err != nil {
- SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
- return defaultValue
- }
- return b
-}
diff --git a/new-api/common/gin.go b/new-api/common/gin.go
deleted file mode 100644
index 8b74c51f3153649c856e26df5796955a791b134d..0000000000000000000000000000000000000000
--- a/new-api/common/gin.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package common
-
-import (
- "bytes"
- "io"
- "net/http"
- "one-api/constant"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-const KeyRequestBody = "key_request_body"
-
-func GetRequestBody(c *gin.Context) ([]byte, error) {
- requestBody, _ := c.Get(KeyRequestBody)
- if requestBody != nil {
- return requestBody.([]byte), nil
- }
- requestBody, err := io.ReadAll(c.Request.Body)
- if err != nil {
- return nil, err
- }
- _ = c.Request.Body.Close()
- c.Set(KeyRequestBody, requestBody)
- return requestBody.([]byte), nil
-}
-
-func UnmarshalBodyReusable(c *gin.Context, v any) error {
- requestBody, err := GetRequestBody(c)
- if err != nil {
- return err
- }
- //if DebugEnabled {
- // println("UnmarshalBodyReusable request body:", string(requestBody))
- //}
- contentType := c.Request.Header.Get("Content-Type")
- if strings.HasPrefix(contentType, "application/json") {
- err = Unmarshal(requestBody, &v)
- } else {
- // skip for now
- // TODO: someday non json request have variant model, we will need to implementation this
- }
- if err != nil {
- return err
- }
- // Reset request body
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return nil
-}
-
-func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
- c.Set(string(key), value)
-}
-
-func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
- return c.Get(string(key))
-}
-
-func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
- return c.GetString(string(key))
-}
-
-func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
- return c.GetInt(string(key))
-}
-
-func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
- return c.GetBool(string(key))
-}
-
-func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
- return c.GetStringSlice(string(key))
-}
-
-func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
- return c.GetStringMap(string(key))
-}
-
-func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
- return c.GetTime(string(key))
-}
-
-func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
- if value, ok := c.Get(string(key)); ok {
- if v, ok := value.(T); ok {
- return v, true
- }
- }
- var t T
- return t, false
-}
-
-func ApiError(c *gin.Context, err error) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
-}
-
-func ApiErrorMsg(c *gin.Context, msg string) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": msg,
- })
-}
-
-func ApiSuccess(c *gin.Context, data any) {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": data,
- })
-}
diff --git a/new-api/common/go-channel.go b/new-api/common/go-channel.go
deleted file mode 100644
index 65b5537cf5d5e3edfdb925b87c41396e8ce37fcf..0000000000000000000000000000000000000000
--- a/new-api/common/go-channel.go
+++ /dev/null
@@ -1,53 +0,0 @@
-package common
-
-import (
- "time"
-)
-
-func SafeSendBool(ch chan bool, value bool) (closed bool) {
- defer func() {
- // Recover from panic if one occured. A panic would mean the channel was closed.
- if recover() != nil {
- closed = true
- }
- }()
-
- // This will panic if the channel is closed.
- ch <- value
-
- // If the code reaches here, then the channel was not closed.
- return false
-}
-
-func SafeSendString(ch chan string, value string) (closed bool) {
- defer func() {
- // Recover from panic if one occured. A panic would mean the channel was closed.
- if recover() != nil {
- closed = true
- }
- }()
-
- // This will panic if the channel is closed.
- ch <- value
-
- // If the code reaches here, then the channel was not closed.
- return false
-}
-
-// SafeSendStringTimeout send, return true, else return false
-func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
- defer func() {
- // Recover from panic if one occured. A panic would mean the channel was closed.
- if recover() != nil {
- closed = false
- }
- }()
-
- // This will panic if the channel is closed.
- select {
- case ch <- value:
- return true
- case <-time.After(time.Duration(timeout) * time.Second):
- return false
- }
-}
diff --git a/new-api/common/gopool.go b/new-api/common/gopool.go
deleted file mode 100644
index 9eac80ba5e06a7daecdcfa0f9be9c3f2100cddbb..0000000000000000000000000000000000000000
--- a/new-api/common/gopool.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package common
-
-import (
- "context"
- "fmt"
- "github.com/bytedance/gopkg/util/gopool"
- "math"
-)
-
-var relayGoPool gopool.Pool
-
-func init() {
- relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
- relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
- if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
- SafeSendBool(stopChan, true)
- }
- SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
- })
-}
-
-func RelayCtxGo(ctx context.Context, f func()) {
- relayGoPool.CtxGo(ctx, f)
-}
diff --git a/new-api/common/hash.go b/new-api/common/hash.go
deleted file mode 100644
index f7c82748384da8b1a5eb8ae6289f3fbeda3491ed..0000000000000000000000000000000000000000
--- a/new-api/common/hash.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package common
-
-import (
- "crypto/hmac"
- "crypto/sha1"
- "crypto/sha256"
- "encoding/hex"
-)
-
-func Sha256Raw(data []byte) []byte {
- h := sha256.New()
- h.Write(data)
- return h.Sum(nil)
-}
-
-func Sha1Raw(data []byte) []byte {
- h := sha1.New()
- h.Write(data)
- return h.Sum(nil)
-}
-
-func Sha1(data []byte) string {
- return hex.EncodeToString(Sha1Raw(data))
-}
-
-func HmacSha256Raw(message, key []byte) []byte {
- h := hmac.New(sha256.New, key)
- h.Write(message)
- return h.Sum(nil)
-}
-
-func HmacSha256(message, key string) string {
- return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
-}
diff --git a/new-api/common/init.go b/new-api/common/init.go
deleted file mode 100644
index 53782d56b11c582d9095ea6aa1c663c1788ed461..0000000000000000000000000000000000000000
--- a/new-api/common/init.go
+++ /dev/null
@@ -1,120 +0,0 @@
-package common
-
-import (
- "flag"
- "fmt"
- "log"
- "one-api/constant"
- "os"
- "path/filepath"
- "strconv"
- "time"
-)
-
-var (
- Port = flag.Int("port", 3000, "the listening port")
- PrintVersion = flag.Bool("version", false, "print version and exit")
- PrintHelp = flag.Bool("help", false, "print help and exit")
- LogDir = flag.String("log-dir", "./logs", "specify the log directory")
-)
-
-func printHelp() {
- fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
- fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
- fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
- fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]")
-}
-
-func InitEnv() {
- flag.Parse()
-
- if *PrintVersion {
- fmt.Println(Version)
- os.Exit(0)
- }
-
- if *PrintHelp {
- printHelp()
- os.Exit(0)
- }
-
- if os.Getenv("SESSION_SECRET") != "" {
- ss := os.Getenv("SESSION_SECRET")
- if ss == "random_string" {
- log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
- log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
- log.Fatal("Please set SESSION_SECRET to a random string.")
- } else {
- SessionSecret = ss
- }
- }
- if os.Getenv("CRYPTO_SECRET") != "" {
- CryptoSecret = os.Getenv("CRYPTO_SECRET")
- } else {
- CryptoSecret = SessionSecret
- }
- if os.Getenv("SQLITE_PATH") != "" {
- SQLitePath = os.Getenv("SQLITE_PATH")
- }
- if *LogDir != "" {
- var err error
- *LogDir, err = filepath.Abs(*LogDir)
- if err != nil {
- log.Fatal(err)
- }
- if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
- err = os.Mkdir(*LogDir, 0777)
- if err != nil {
- log.Fatal(err)
- }
- }
- }
-
- // Initialize variables from constants.go that were using environment variables
- DebugEnabled = os.Getenv("DEBUG") == "true"
- MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
- IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
-
- // Parse requestInterval and set RequestInterval
- requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
- RequestInterval = time.Duration(requestInterval) * time.Second
-
- // Initialize variables with GetEnvOrDefault
- SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
- BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
- RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
-
- // Initialize string variables with GetEnvOrDefaultString
- GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
- CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
-
- // Initialize rate limit variables
- GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
- GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
- GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
-
- GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
- GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
- GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
-
- initConstantEnv()
-}
-
-func initConstantEnv() {
- constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
- constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
- constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
- // ForceStreamOption 覆盖请求参数,强制返回usage信息
- constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
- constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
- constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
- constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
- constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
- constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
- constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
- constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
- // GenerateDefaultToken 是否生成初始令牌,默认关闭。
- constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
- // 是否启用错误日志
- constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
-}
diff --git a/new-api/common/ip.go b/new-api/common/ip.go
deleted file mode 100644
index 7b61a012e5bc2489b382c477251e237ddd5759ca..0000000000000000000000000000000000000000
--- a/new-api/common/ip.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package common
-
-import "net"
-
-func IsPrivateIP(ip net.IP) bool {
- if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
- return true
- }
-
- private := []net.IPNet{
- {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
- {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
- {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
- }
-
- for _, privateNet := range private {
- if privateNet.Contains(ip) {
- return true
- }
- }
- return false
-}
diff --git a/new-api/common/json.go b/new-api/common/json.go
deleted file mode 100644
index 531e737d65e2d0d9496c508ec2c2e8a8e0999bf2..0000000000000000000000000000000000000000
--- a/new-api/common/json.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package common
-
-import (
- "bytes"
- "encoding/json"
-)
-
-func Unmarshal(data []byte, v any) error {
- return json.Unmarshal(data, v)
-}
-
-func UnmarshalJsonStr(data string, v any) error {
- return json.Unmarshal(StringToByteSlice(data), v)
-}
-
-func DecodeJson(reader *bytes.Reader, v any) error {
- return json.NewDecoder(reader).Decode(v)
-}
-
-func Marshal(v any) ([]byte, error) {
- return json.Marshal(v)
-}
-
-func GetJsonType(data json.RawMessage) string {
- data = bytes.TrimSpace(data)
- if len(data) == 0 {
- return "unknown"
- }
- firstChar := bytes.TrimSpace(data)[0]
- switch firstChar {
- case '{':
- return "object"
- case '[':
- return "array"
- case '"':
- return "string"
- case 't', 'f':
- return "boolean"
- case 'n':
- return "null"
- default:
- return "number"
- }
-}
diff --git a/new-api/common/limiter/limiter.go b/new-api/common/limiter/limiter.go
deleted file mode 100644
index e966cc898b59f885f82498f3fb39e3387f3c2145..0000000000000000000000000000000000000000
--- a/new-api/common/limiter/limiter.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package limiter
-
-import (
- "context"
- _ "embed"
- "fmt"
- "github.com/go-redis/redis/v8"
- "one-api/common"
- "sync"
-)
-
-//go:embed lua/rate_limit.lua
-var rateLimitScript string
-
-type RedisLimiter struct {
- client *redis.Client
- limitScriptSHA string
-}
-
-var (
- instance *RedisLimiter
- once sync.Once
-)
-
-func New(ctx context.Context, r *redis.Client) *RedisLimiter {
- once.Do(func() {
- // 预加载脚本
- limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
- if err != nil {
- common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
- }
- instance = &RedisLimiter{
- client: r,
- limitScriptSHA: limitSHA,
- }
- })
-
- return instance
-}
-
-func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
- // 默认配置
- config := &Config{
- Capacity: 10,
- Rate: 1,
- Requested: 1,
- }
-
- // 应用选项模式
- for _, opt := range opts {
- opt(config)
- }
-
- // 执行限流
- result, err := rl.client.EvalSha(
- ctx,
- rl.limitScriptSHA,
- []string{key},
- config.Requested,
- config.Rate,
- config.Capacity,
- ).Int()
-
- if err != nil {
- return false, fmt.Errorf("rate limit failed: %w", err)
- }
- return result == 1, nil
-}
-
-// Config 配置选项模式
-type Config struct {
- Capacity int64
- Rate int64
- Requested int64
-}
-
-type Option func(*Config)
-
-func WithCapacity(c int64) Option {
- return func(cfg *Config) { cfg.Capacity = c }
-}
-
-func WithRate(r int64) Option {
- return func(cfg *Config) { cfg.Rate = r }
-}
-
-func WithRequested(n int64) Option {
- return func(cfg *Config) { cfg.Requested = n }
-}
diff --git a/new-api/common/limiter/lua/rate_limit.lua b/new-api/common/limiter/lua/rate_limit.lua
deleted file mode 100644
index c67fac240f040838b3119c3a4f43b82fa11ae724..0000000000000000000000000000000000000000
--- a/new-api/common/limiter/lua/rate_limit.lua
+++ /dev/null
@@ -1,44 +0,0 @@
--- 令牌桶限流器
--- KEYS[1]: 限流器唯一标识
--- ARGV[1]: 请求令牌数 (通常为1)
--- ARGV[2]: 令牌生成速率 (每秒)
--- ARGV[3]: 桶容量
-
-local key = KEYS[1]
-local requested = tonumber(ARGV[1])
-local rate = tonumber(ARGV[2])
-local capacity = tonumber(ARGV[3])
-
--- 获取当前时间(Redis服务器时间)
-local now = redis.call('TIME')
-local nowInSeconds = tonumber(now[1])
-
--- 获取桶状态
-local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
-local tokens = tonumber(bucket[1])
-local last_time = tonumber(bucket[2])
-
--- 初始化桶(首次请求或过期)
-if not tokens or not last_time then
- tokens = capacity
- last_time = nowInSeconds
-else
- -- 计算新增令牌
- local elapsed = nowInSeconds - last_time
- local add_tokens = elapsed * rate
- tokens = math.min(capacity, tokens + add_tokens)
- last_time = nowInSeconds
-end
-
--- 判断是否允许请求
-local allowed = false
-if tokens >= requested then
- tokens = tokens - requested
- allowed = true
-end
-
----- 更新桶状态并设置过期时间
-redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
---redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
-
-return allowed and 1 or 0
\ No newline at end of file
diff --git a/new-api/common/model.go b/new-api/common/model.go
deleted file mode 100644
index 181e2c3a0ca3a61fa564996220e629f20f423be5..0000000000000000000000000000000000000000
--- a/new-api/common/model.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package common
-
-import "strings"
-
-var (
- // OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
- OpenAIResponseOnlyModels = []string{
- "o3-pro",
- "o3-deep-research",
- "o4-mini-deep-research",
- }
- ImageGenerationModels = []string{
- "dall-e-3",
- "dall-e-2",
- "gpt-image-1",
- "prefix:imagen-",
- "flux-",
- "flux.1-",
- }
-)
-
-func IsOpenAIResponseOnlyModel(modelName string) bool {
- for _, m := range OpenAIResponseOnlyModels {
- if strings.Contains(modelName, m) {
- return true
- }
- }
- return false
-}
-
-func IsImageGenerationModel(modelName string) bool {
- modelName = strings.ToLower(modelName)
- for _, m := range ImageGenerationModels {
- if strings.Contains(modelName, m) {
- return true
- }
- if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
- return true
- }
- }
- return false
-}
diff --git a/new-api/common/page_info.go b/new-api/common/page_info.go
deleted file mode 100644
index 58bf2ab12d0aa1afa8fdbc48c55b9aca2fef6875..0000000000000000000000000000000000000000
--- a/new-api/common/page_info.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package common
-
-import (
- "strconv"
-
- "github.com/gin-gonic/gin"
-)
-
-type PageInfo struct {
- Page int `json:"page"` // page num 页码
- PageSize int `json:"page_size"` // page size 页大小
-
- Total int `json:"total"` // 总条数,后设置
- Items any `json:"items"` // 数据,后设置
-}
-
-func (p *PageInfo) GetStartIdx() int {
- return (p.Page - 1) * p.PageSize
-}
-
-func (p *PageInfo) GetEndIdx() int {
- return p.Page * p.PageSize
-}
-
-func (p *PageInfo) GetPageSize() int {
- return p.PageSize
-}
-
-func (p *PageInfo) GetPage() int {
- return p.Page
-}
-
-func (p *PageInfo) SetTotal(total int) {
- p.Total = total
-}
-
-func (p *PageInfo) SetItems(items any) {
- p.Items = items
-}
-
-func GetPageQuery(c *gin.Context) *PageInfo {
- pageInfo := &PageInfo{}
- // 手动获取并处理每个参数
- if page, err := strconv.Atoi(c.Query("p")); err == nil {
- pageInfo.Page = page
- }
- if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
- pageInfo.PageSize = pageSize
- }
- if pageInfo.Page < 1 {
- // 兼容
- page, _ := strconv.Atoi(c.Query("p"))
- if page != 0 {
- pageInfo.Page = page
- } else {
- pageInfo.Page = 1
- }
- }
-
- if pageInfo.PageSize == 0 {
- // 兼容
- pageSize, _ := strconv.Atoi(c.Query("ps"))
- if pageSize != 0 {
- pageInfo.PageSize = pageSize
- }
- if pageInfo.PageSize == 0 {
- pageSize, _ = strconv.Atoi(c.Query("size")) // token page
- if pageSize != 0 {
- pageInfo.PageSize = pageSize
- }
- }
- if pageInfo.PageSize == 0 {
- pageInfo.PageSize = ItemsPerPage
- }
- }
-
- if pageInfo.PageSize > 100 {
- pageInfo.PageSize = 100
- }
-
- return pageInfo
-}
diff --git a/new-api/common/pprof.go b/new-api/common/pprof.go
deleted file mode 100644
index b18fd024473526dfd960c2b134e862ac84a152eb..0000000000000000000000000000000000000000
--- a/new-api/common/pprof.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package common
-
-import (
- "fmt"
- "github.com/shirou/gopsutil/cpu"
- "os"
- "runtime/pprof"
- "time"
-)
-
-// Monitor 定时监控cpu使用率,超过阈值输出pprof文件
-func Monitor() {
- for {
- percent, err := cpu.Percent(time.Second, false)
- if err != nil {
- panic(err)
- }
- if percent[0] > 80 {
- fmt.Println("cpu usage too high")
- // write pprof file
- if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
- err := os.Mkdir("./pprof", os.ModePerm)
- if err != nil {
- SysLog("创建pprof文件夹失败 " + err.Error())
- continue
- }
- }
- f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
- if err != nil {
- SysLog("创建pprof文件失败 " + err.Error())
- continue
- }
- err = pprof.StartCPUProfile(f)
- if err != nil {
- SysLog("启动pprof失败 " + err.Error())
- continue
- }
- time.Sleep(10 * time.Second) // profile for 30 seconds
- pprof.StopCPUProfile()
- f.Close()
- }
- time.Sleep(30 * time.Second)
- }
-}
diff --git a/new-api/common/quota.go b/new-api/common/quota.go
deleted file mode 100644
index 5961d3c4939db7db05b388086d74c753b23ed568..0000000000000000000000000000000000000000
--- a/new-api/common/quota.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package common
-
-func GetTrustQuota() int {
- return int(10 * QuotaPerUnit)
-}
diff --git a/new-api/common/rate-limit.go b/new-api/common/rate-limit.go
deleted file mode 100644
index be08b6fbfeb7e90ce6daa6ced7412e26e0d72a38..0000000000000000000000000000000000000000
--- a/new-api/common/rate-limit.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package common
-
-import (
- "sync"
- "time"
-)
-
-type InMemoryRateLimiter struct {
- store map[string]*[]int64
- mutex sync.Mutex
- expirationDuration time.Duration
-}
-
-func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
- if l.store == nil {
- l.mutex.Lock()
- if l.store == nil {
- l.store = make(map[string]*[]int64)
- l.expirationDuration = expirationDuration
- if expirationDuration > 0 {
- go l.clearExpiredItems()
- }
- }
- l.mutex.Unlock()
- }
-}
-
-func (l *InMemoryRateLimiter) clearExpiredItems() {
- for {
- time.Sleep(l.expirationDuration)
- l.mutex.Lock()
- now := time.Now().Unix()
- for key := range l.store {
- queue := l.store[key]
- size := len(*queue)
- if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
- delete(l.store, key)
- }
- }
- l.mutex.Unlock()
- }
-}
-
-// Request parameter duration's unit is seconds
-func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
- l.mutex.Lock()
- defer l.mutex.Unlock()
- // [old <-- new]
- queue, ok := l.store[key]
- now := time.Now().Unix()
- if ok {
- if len(*queue) < maxRequestNum {
- *queue = append(*queue, now)
- return true
- } else {
- if now-(*queue)[0] >= duration {
- *queue = (*queue)[1:]
- *queue = append(*queue, now)
- return true
- } else {
- return false
- }
- }
- } else {
- s := make([]int64, 0, maxRequestNum)
- l.store[key] = &s
- *(l.store[key]) = append(*(l.store[key]), now)
- }
- return true
-}
diff --git a/new-api/common/redis.go b/new-api/common/redis.go
deleted file mode 100644
index 90a9ce3c4a253f3c11e006b59da1e50644052556..0000000000000000000000000000000000000000
--- a/new-api/common/redis.go
+++ /dev/null
@@ -1,327 +0,0 @@
-package common
-
-import (
- "context"
- "errors"
- "fmt"
- "os"
- "reflect"
- "strconv"
- "time"
-
- "github.com/go-redis/redis/v8"
- "gorm.io/gorm"
-)
-
-var RDB *redis.Client
-var RedisEnabled = true
-
-func RedisKeyCacheSeconds() int {
- return SyncFrequency
-}
-
-// InitRedisClient This function is called after init()
-func InitRedisClient() (err error) {
- if os.Getenv("REDIS_CONN_STRING") == "" {
- RedisEnabled = false
- SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
- return nil
- }
- if os.Getenv("SYNC_FREQUENCY") == "" {
- SysLog("SYNC_FREQUENCY not set, use default value 60")
- SyncFrequency = 60
- }
- SysLog("Redis is enabled")
- opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
- if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
- }
- opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
- RDB = redis.NewClient(opt)
-
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- _, err = RDB.Ping(ctx).Result()
- if err != nil {
- FatalLog("Redis ping test failed: " + err.Error())
- }
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
- SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
- }
- return err
-}
-
-func ParseRedisOption() *redis.Options {
- opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
- if err != nil {
- FatalLog("failed to parse Redis connection string: " + err.Error())
- }
- return opt
-}
-
-func RedisSet(key string, value string, expiration time.Duration) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
- }
- ctx := context.Background()
- return RDB.Set(ctx, key, value, expiration).Err()
-}
-
-func RedisGet(key string) (string, error) {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis GET: key=%s", key))
- }
- ctx := context.Background()
- val, err := RDB.Get(ctx, key).Result()
- return val, err
-}
-
-//func RedisExpire(key string, expiration time.Duration) error {
-// ctx := context.Background()
-// return RDB.Expire(ctx, key, expiration).Err()
-//}
-//
-//func RedisGetEx(key string, expiration time.Duration) (string, error) {
-// ctx := context.Background()
-// return RDB.GetSet(ctx, key, expiration).Result()
-//}
-
-func RedisDel(key string) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
- }
- ctx := context.Background()
- return RDB.Del(ctx, key).Err()
-}
-
-func RedisDelKey(key string) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
- }
- ctx := context.Background()
- return RDB.Del(ctx, key).Err()
-}
-
-func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
- }
- ctx := context.Background()
-
- data := make(map[string]interface{})
-
- // 使用反射遍历结构体字段
- v := reflect.ValueOf(obj).Elem()
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- field := t.Field(i)
- value := v.Field(i)
-
- // Skip DeletedAt field
- if field.Type.String() == "gorm.DeletedAt" {
- continue
- }
-
- // 处理指针类型
- if value.Kind() == reflect.Ptr {
- if value.IsNil() {
- data[field.Name] = ""
- continue
- }
- value = value.Elem()
- }
-
- // 处理布尔类型
- if value.Kind() == reflect.Bool {
- data[field.Name] = strconv.FormatBool(value.Bool())
- continue
- }
-
- // 其他类型直接转换为字符串
- data[field.Name] = fmt.Sprintf("%v", value.Interface())
- }
-
- txn := RDB.TxPipeline()
- txn.HSet(ctx, key, data)
-
- // 只有在 expiration 大于 0 时才设置过期时间
- if expiration > 0 {
- txn.Expire(ctx, key, expiration)
- }
-
- _, err := txn.Exec(ctx)
- if err != nil {
- return fmt.Errorf("failed to execute transaction: %w", err)
- }
- return nil
-}
-
-func RedisHGetObj(key string, obj interface{}) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
- }
- ctx := context.Background()
-
- result, err := RDB.HGetAll(ctx, key).Result()
- if err != nil {
- return fmt.Errorf("failed to load hash from Redis: %w", err)
- }
-
- if len(result) == 0 {
- return fmt.Errorf("key %s not found in Redis", key)
- }
-
- // Handle both pointer and non-pointer values
- val := reflect.ValueOf(obj)
- if val.Kind() != reflect.Ptr {
- return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
- }
-
- v := val.Elem()
- if v.Kind() != reflect.Struct {
- return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
- }
-
- t := v.Type()
- for i := 0; i < v.NumField(); i++ {
- field := t.Field(i)
- fieldName := field.Name
- if value, ok := result[fieldName]; ok {
- fieldValue := v.Field(i)
-
- // Handle pointer types
- if fieldValue.Kind() == reflect.Ptr {
- if value == "" {
- continue
- }
- if fieldValue.IsNil() {
- fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
- }
- fieldValue = fieldValue.Elem()
- }
-
- // Enhanced type handling for Token struct
- switch fieldValue.Kind() {
- case reflect.String:
- fieldValue.SetString(value)
- case reflect.Int, reflect.Int64:
- intValue, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
- }
- fieldValue.SetInt(intValue)
- case reflect.Bool:
- boolValue, err := strconv.ParseBool(value)
- if err != nil {
- return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
- }
- fieldValue.SetBool(boolValue)
- case reflect.Struct:
- // Special handling for gorm.DeletedAt
- if fieldValue.Type().String() == "gorm.DeletedAt" {
- if value != "" {
- timeValue, err := time.Parse(time.RFC3339, value)
- if err != nil {
- return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
- }
- fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
- }
- }
- default:
- return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
- }
- }
- }
-
- return nil
-}
-
-// RedisIncr Add this function to handle atomic increments
-func RedisIncr(key string, delta int64) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
- }
- // 检查键的剩余生存时间
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
-
- // 只有在 key 存在且有 TTL 时才需要特殊处理
- if ttl > 0 {
- ctx := context.Background()
- // 开始一个Redis事务
- txn := RDB.TxPipeline()
-
- // 减少余额
- decrCmd := txn.IncrBy(ctx, key, delta)
- if err := decrCmd.Err(); err != nil {
- return err // 如果减少失败,则直接返回错误
- }
-
- // 重新设置过期时间,使用原来的过期时间
- txn.Expire(ctx, key, ttl)
-
- // 执行事务
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
-}
-
-func RedisHIncrBy(key, field string, delta int64) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
- }
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
-
- if ttl > 0 {
- ctx := context.Background()
- txn := RDB.TxPipeline()
-
- incrCmd := txn.HIncrBy(ctx, key, field, delta)
- if err := incrCmd.Err(); err != nil {
- return err
- }
-
- txn.Expire(ctx, key, ttl)
-
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
-}
-
-func RedisHSetField(key, field string, value interface{}) error {
- if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
- }
- ttlCmd := RDB.TTL(context.Background(), key)
- ttl, err := ttlCmd.Result()
- if err != nil && !errors.Is(err, redis.Nil) {
- return fmt.Errorf("failed to get TTL: %w", err)
- }
-
- if ttl > 0 {
- ctx := context.Background()
- txn := RDB.TxPipeline()
-
- hsetCmd := txn.HSet(ctx, key, field, value)
- if err := hsetCmd.Err(); err != nil {
- return err
- }
-
- txn.Expire(ctx, key, ttl)
-
- _, err = txn.Exec(ctx)
- return err
- }
- return nil
-}
diff --git a/new-api/common/ssrf_protection.go b/new-api/common/ssrf_protection.go
deleted file mode 100644
index 9eaadc2d77a81b2cd5d4b2da49555fe487c0eace..0000000000000000000000000000000000000000
--- a/new-api/common/ssrf_protection.go
+++ /dev/null
@@ -1,327 +0,0 @@
-package common
-
-import (
- "fmt"
- "net"
- "net/url"
- "strconv"
- "strings"
-)
-
-// SSRFProtection SSRF防护配置
-type SSRFProtection struct {
- AllowPrivateIp bool
- DomainFilterMode bool // true: 白名单, false: 黑名单
- DomainList []string // domain format, e.g. example.com, *.example.com
- IpFilterMode bool // true: 白名单, false: 黑名单
- IpList []string // CIDR or single IP
- AllowedPorts []int // 允许的端口范围
- ApplyIPFilterForDomain bool // 对域名启用IP过滤
-}
-
-// DefaultSSRFProtection 默认SSRF防护配置
-var DefaultSSRFProtection = &SSRFProtection{
- AllowPrivateIp: false,
- DomainFilterMode: true,
- DomainList: []string{},
- IpFilterMode: true,
- IpList: []string{},
- AllowedPorts: []int{},
-}
-
-// isPrivateIP 检查IP是否为私有地址
-func isPrivateIP(ip net.IP) bool {
- if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
- return true
- }
-
- // 检查私有网段
- private := []net.IPNet{
- {IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
- {IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
- {IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
- {IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
- {IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
- {IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
- {IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
- }
-
- for _, privateNet := range private {
- if privateNet.Contains(ip) {
- return true
- }
- }
-
- // 检查IPv6私有地址
- if ip.To4() == nil {
- // IPv6 loopback
- if ip.Equal(net.IPv6loopback) {
- return true
- }
- // IPv6 link-local
- if strings.HasPrefix(ip.String(), "fe80:") {
- return true
- }
- // IPv6 unique local
- if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
- return true
- }
- }
-
- return false
-}
-
-// parsePortRanges 解析端口范围配置
-// 支持格式: "80", "443", "8000-9000"
-func parsePortRanges(portConfigs []string) ([]int, error) {
- var ports []int
-
- for _, config := range portConfigs {
- config = strings.TrimSpace(config)
- if config == "" {
- continue
- }
-
- if strings.Contains(config, "-") {
- // 处理端口范围 "8000-9000"
- parts := strings.Split(config, "-")
- if len(parts) != 2 {
- return nil, fmt.Errorf("invalid port range format: %s", config)
- }
-
- startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
- if err != nil {
- return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
- }
-
- endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
- if err != nil {
- return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
- }
-
- if startPort > endPort {
- return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
- }
-
- if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
- return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
- }
-
- // 添加范围内的所有端口
- for port := startPort; port <= endPort; port++ {
- ports = append(ports, port)
- }
- } else {
- // 处理单个端口 "80"
- port, err := strconv.Atoi(config)
- if err != nil {
- return nil, fmt.Errorf("invalid port number: %s", config)
- }
-
- if port < 1 || port > 65535 {
- return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
- }
-
- ports = append(ports, port)
- }
- }
-
- return ports, nil
-}
-
-// isAllowedPort 检查端口是否被允许
-func (p *SSRFProtection) isAllowedPort(port int) bool {
- if len(p.AllowedPorts) == 0 {
- return true // 如果没有配置端口限制,则允许所有端口
- }
-
- for _, allowedPort := range p.AllowedPorts {
- if port == allowedPort {
- return true
- }
- }
- return false
-}
-
-// isDomainWhitelisted 检查域名是否在白名单中
-func isDomainListed(domain string, list []string) bool {
- if len(list) == 0 {
- return false
- }
-
- domain = strings.ToLower(domain)
- for _, item := range list {
- item = strings.ToLower(strings.TrimSpace(item))
- if item == "" {
- continue
- }
- // 精确匹配
- if domain == item {
- return true
- }
- // 通配符匹配 (*.example.com)
- if strings.HasPrefix(item, "*.") {
- suffix := strings.TrimPrefix(item, "*.")
- if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
- return true
- }
- }
- }
- return false
-}
-
-func (p *SSRFProtection) isDomainAllowed(domain string) bool {
- listed := isDomainListed(domain, p.DomainList)
- if p.DomainFilterMode { // 白名单
- return listed
- }
- // 黑名单
- return !listed
-}
-
-// isIPWhitelisted 检查IP是否在白名单中
-
-func isIPListed(ip net.IP, list []string) bool {
- if len(list) == 0 {
- return false
- }
-
- for _, whitelistCIDR := range list {
- _, network, err := net.ParseCIDR(whitelistCIDR)
- if err != nil {
- // 尝试作为单个IP处理
- if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
- if ip.Equal(whitelistIP) {
- return true
- }
- }
- continue
- }
-
- if network.Contains(ip) {
- return true
- }
- }
- return false
-}
-
-// IsIPAccessAllowed 检查IP是否允许访问
-func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
- // 私有IP限制
- if isPrivateIP(ip) && !p.AllowPrivateIp {
- return false
- }
-
- listed := isIPListed(ip, p.IpList)
- if p.IpFilterMode { // 白名单
- return listed
- }
- // 黑名单
- return !listed
-}
-
-// ValidateURL 验证URL是否安全
-func (p *SSRFProtection) ValidateURL(urlStr string) error {
- // 解析URL
- u, err := url.Parse(urlStr)
- if err != nil {
- return fmt.Errorf("invalid URL format: %v", err)
- }
-
- // 只允许HTTP/HTTPS协议
- if u.Scheme != "http" && u.Scheme != "https" {
- return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
- }
-
- // 解析主机和端口
- host, portStr, err := net.SplitHostPort(u.Host)
- if err != nil {
- // 没有端口,使用默认端口
- host = u.Hostname()
- if u.Scheme == "https" {
- portStr = "443"
- } else {
- portStr = "80"
- }
- }
-
- // 验证端口
- port, err := strconv.Atoi(portStr)
- if err != nil {
- return fmt.Errorf("invalid port: %s", portStr)
- }
-
- if !p.isAllowedPort(port) {
- return fmt.Errorf("port %d is not allowed", port)
- }
-
- // 如果 host 是 IP,则跳过域名检查
- if ip := net.ParseIP(host); ip != nil {
- if !p.IsIPAccessAllowed(ip) {
- if isPrivateIP(ip) {
- return fmt.Errorf("private IP address not allowed: %s", ip.String())
- }
- if p.IpFilterMode {
- return fmt.Errorf("ip not in whitelist: %s", ip.String())
- }
- return fmt.Errorf("ip in blacklist: %s", ip.String())
- }
- return nil
- }
-
- // 先进行域名过滤
- if !p.isDomainAllowed(host) {
- if p.DomainFilterMode {
- return fmt.Errorf("domain not in whitelist: %s", host)
- }
- return fmt.Errorf("domain in blacklist: %s", host)
- }
-
- // 若未启用对域名应用IP过滤,则到此通过
- if !p.ApplyIPFilterForDomain {
- return nil
- }
-
- // 解析域名对应IP并检查
- ips, err := net.LookupIP(host)
- if err != nil {
- return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
- }
- for _, ip := range ips {
- if !p.IsIPAccessAllowed(ip) {
- if isPrivateIP(ip) && !p.AllowPrivateIp {
- return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
- }
- if p.IpFilterMode {
- return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
- }
- return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
- }
- }
- return nil
-}
-
-// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
-func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
- // 如果SSRF防护被禁用,直接返回成功
- if !enableSSRFProtection {
- return nil
- }
-
- // 解析端口范围配置
- allowedPortInts, err := parsePortRanges(allowedPorts)
- if err != nil {
- return fmt.Errorf("request reject - invalid port configuration: %v", err)
- }
-
- protection := &SSRFProtection{
- AllowPrivateIp: allowPrivateIp,
- DomainFilterMode: domainFilterMode,
- DomainList: domainList,
- IpFilterMode: ipFilterMode,
- IpList: ipList,
- AllowedPorts: allowedPortInts,
- ApplyIPFilterForDomain: applyIPFilterForDomain,
- }
- return protection.ValidateURL(urlStr)
-}
diff --git a/new-api/common/str.go b/new-api/common/str.go
deleted file mode 100644
index 1a016a51208f90551bc398accf4598fa4f19b9dd..0000000000000000000000000000000000000000
--- a/new-api/common/str.go
+++ /dev/null
@@ -1,237 +0,0 @@
-package common
-
-import (
- "encoding/base64"
- "encoding/json"
- "math/rand"
- "net/url"
- "regexp"
- "strconv"
- "strings"
- "unsafe"
-)
-
-func GetStringIfEmpty(str string, defaultValue string) string {
- if str == "" {
- return defaultValue
- }
- return str
-}
-
-func GetRandomString(length int) string {
- //rand.Seed(time.Now().UnixNano())
- key := make([]byte, length)
- for i := 0; i < length; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- return string(key)
-}
-
-func MapToJsonStr(m map[string]interface{}) string {
- bytes, err := json.Marshal(m)
- if err != nil {
- return ""
- }
- return string(bytes)
-}
-
-func StrToMap(str string) (map[string]interface{}, error) {
- m := make(map[string]interface{})
- err := Unmarshal([]byte(str), &m)
- if err != nil {
- return nil, err
- }
- return m, nil
-}
-
-func StrToJsonArray(str string) ([]interface{}, error) {
- var js []interface{}
- err := json.Unmarshal([]byte(str), &js)
- if err != nil {
- return nil, err
- }
- return js, nil
-}
-
-func IsJsonArray(str string) bool {
- var js []interface{}
- return json.Unmarshal([]byte(str), &js) == nil
-}
-
-func IsJsonObject(str string) bool {
- var js map[string]interface{}
- return json.Unmarshal([]byte(str), &js) == nil
-}
-
-func String2Int(str string) int {
- num, err := strconv.Atoi(str)
- if err != nil {
- return 0
- }
- return num
-}
-
-func StringsContains(strs []string, str string) bool {
- for _, s := range strs {
- if s == str {
- return true
- }
- }
- return false
-}
-
-// StringToByteSlice []byte only read, panic on append
-func StringToByteSlice(s string) []byte {
- tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
- tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
- return *(*[]byte)(unsafe.Pointer(&tmp2))
-}
-
-func EncodeBase64(str string) string {
- return base64.StdEncoding.EncodeToString([]byte(str))
-}
-
-func GetJsonString(data any) string {
- if data == nil {
- return ""
- }
- b, _ := json.Marshal(data)
- return string(b)
-}
-
-// MaskEmail masks a user email to prevent PII leakage in logs
-// Returns "***masked***" if email is empty, otherwise shows only the domain part
-func MaskEmail(email string) string {
- if email == "" {
- return "***masked***"
- }
-
- // Find the @ symbol
- atIndex := strings.Index(email, "@")
- if atIndex == -1 {
- // No @ symbol found, return masked
- return "***masked***"
- }
-
- // Return only the domain part with @ symbol
- return "***@" + email[atIndex+1:]
-}
-
-// maskHostTail returns the tail parts of a domain/host that should be preserved.
-// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
-func maskHostTail(parts []string) []string {
- if len(parts) < 2 {
- return parts
- }
- lastPart := parts[len(parts)-1]
- secondLastPart := parts[len(parts)-2]
- if len(lastPart) == 2 && len(secondLastPart) <= 3 {
- // Likely country code TLD like co.uk, com.cn
- return []string{secondLastPart, lastPart}
- }
- return []string{lastPart}
-}
-
-// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
-// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
-func maskHostForURL(host string) string {
- parts := strings.Split(host, ".")
- if len(parts) < 2 {
- return "***"
- }
- tail := maskHostTail(parts)
- return "***." + strings.Join(tail, ".")
-}
-
-// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
-// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
-func maskHostForPlainDomain(domain string) string {
- parts := strings.Split(domain, ".")
- if len(parts) < 2 {
- return domain
- }
- tail := maskHostTail(parts)
- numStars := len(parts) - len(tail)
- if numStars < 1 {
- numStars = 1
- }
- stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
- return stars + "." + strings.Join(tail, ".")
-}
-
-// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
-// Example:
-// http://example.com -> http://***.com
-// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
-// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
-// 192.168.1.1 -> ***.***.***.***
-// openai.com -> ***.com
-// www.openai.com -> ***.***.com
-// api.openai.com -> ***.***.com
-func MaskSensitiveInfo(str string) string {
- // Mask URLs
- urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
- str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
- u, err := url.Parse(urlStr)
- if err != nil {
- return urlStr
- }
-
- host := u.Host
- if host == "" {
- return urlStr
- }
-
- // Mask host with unified logic
- maskedHost := maskHostForURL(host)
-
- result := u.Scheme + "://" + maskedHost
-
- // Mask path
- if u.Path != "" && u.Path != "/" {
- pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
- maskedPathParts := make([]string, len(pathParts))
- for i := range pathParts {
- if pathParts[i] != "" {
- maskedPathParts[i] = "***"
- }
- }
- if len(maskedPathParts) > 0 {
- result += "/" + strings.Join(maskedPathParts, "/")
- }
- } else if u.Path == "/" {
- result += "/"
- }
-
- // Mask query parameters
- if u.RawQuery != "" {
- values, err := url.ParseQuery(u.RawQuery)
- if err != nil {
- // If can't parse query, just mask the whole query string
- result += "?***"
- } else {
- maskedParams := make([]string, 0, len(values))
- for key := range values {
- maskedParams = append(maskedParams, key+"=***")
- }
- if len(maskedParams) > 0 {
- result += "?" + strings.Join(maskedParams, "&")
- }
- }
- }
-
- return result
- })
-
- // Mask domain names without protocol (like openai.com, www.openai.com)
- domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
- str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
- return maskHostForPlainDomain(domain)
- })
-
- // Mask IP addresses
- ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
- str = ipPattern.ReplaceAllString(str, "***.***.***.***")
-
- return str
-}
diff --git a/new-api/common/sys_log.go b/new-api/common/sys_log.go
deleted file mode 100644
index 95b40ea60987ed805739487b810f228f0ca920d7..0000000000000000000000000000000000000000
--- a/new-api/common/sys_log.go
+++ /dev/null
@@ -1,55 +0,0 @@
-package common
-
-import (
- "fmt"
- "os"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func SysLog(s string) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
-}
-
-func SysError(s string) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
-}
-
-func FatalLog(v ...any) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
- os.Exit(1)
-}
-
-func LogStartupSuccess(startTime time.Time, port string) {
-
- duration := time.Since(startTime)
- durationMs := duration.Milliseconds()
-
- // Get network IPs
- networkIps := GetNetworkIps()
-
- // Print blank line for spacing
- fmt.Fprintf(gin.DefaultWriter, "\n")
-
- // Print the main success message
- fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
- fmt.Fprintf(gin.DefaultWriter, "\n")
-
- // Skip fancy startup message in container environments
- if !IsRunningInContainer() {
- // Print local URL
- fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
- }
-
- // Print network URLs
- for _, ip := range networkIps {
- fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
- }
-
- // Print blank line for spacing
- fmt.Fprintf(gin.DefaultWriter, "\n")
-}
diff --git a/new-api/common/topup-ratio.go b/new-api/common/topup-ratio.go
deleted file mode 100644
index 9c0e9c85a9e116c10fe7fbbda0eb2c4db1df7c3c..0000000000000000000000000000000000000000
--- a/new-api/common/topup-ratio.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package common
-
-import (
- "encoding/json"
-)
-
-var TopupGroupRatio = map[string]float64{
- "default": 1,
- "vip": 1,
- "svip": 1,
-}
-
-func TopupGroupRatio2JSONString() string {
- jsonBytes, err := json.Marshal(TopupGroupRatio)
- if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
- }
- return string(jsonBytes)
-}
-
-func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
- TopupGroupRatio = make(map[string]float64)
- return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
-}
-
-func GetTopupGroupRatio(name string) float64 {
- ratio, ok := TopupGroupRatio[name]
- if !ok {
- SysError("topup group ratio not found: " + name)
- return 1
- }
- return ratio
-}
diff --git a/new-api/common/totp.go b/new-api/common/totp.go
deleted file mode 100644
index 7502626041f03c969c0a0ed5b5802045868e7e34..0000000000000000000000000000000000000000
--- a/new-api/common/totp.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package common
-
-import (
- "crypto/rand"
- "fmt"
- "os"
- "strconv"
- "strings"
-
- "github.com/pquerna/otp"
- "github.com/pquerna/otp/totp"
-)
-
-const (
- // 备用码配置
- BackupCodeLength = 8 // 备用码长度
- BackupCodeCount = 4 // 生成备用码数量
-
- // 限制配置
- MaxFailAttempts = 5 // 最大失败尝试次数
- LockoutDuration = 300 // 锁定时间(秒)
-)
-
-// GenerateTOTPSecret 生成TOTP密钥和配置
-func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
- issuer := Get2FAIssuer()
- return totp.Generate(totp.GenerateOpts{
- Issuer: issuer,
- AccountName: accountName,
- Period: 30,
- Digits: otp.DigitsSix,
- Algorithm: otp.AlgorithmSHA1,
- })
-}
-
-// ValidateTOTPCode 验证TOTP验证码
-func ValidateTOTPCode(secret, code string) bool {
- // 清理验证码格式
- cleanCode := strings.ReplaceAll(code, " ", "")
- if len(cleanCode) != 6 {
- return false
- }
-
- // 验证验证码
- return totp.Validate(cleanCode, secret)
-}
-
-// GenerateBackupCodes 生成备用恢复码
-func GenerateBackupCodes() ([]string, error) {
- codes := make([]string, BackupCodeCount)
-
- for i := 0; i < BackupCodeCount; i++ {
- code, err := generateRandomBackupCode()
- if err != nil {
- return nil, err
- }
- codes[i] = code
- }
-
- return codes, nil
-}
-
-// generateRandomBackupCode 生成单个备用码
-func generateRandomBackupCode() (string, error) {
- const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
- code := make([]byte, BackupCodeLength)
-
- for i := range code {
- randomBytes := make([]byte, 1)
- _, err := rand.Read(randomBytes)
- if err != nil {
- return "", err
- }
- code[i] = charset[int(randomBytes[0])%len(charset)]
- }
-
- // 格式化为 XXXX-XXXX 格式
- return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
-}
-
-// ValidateBackupCode 验证备用码格式
-func ValidateBackupCode(code string) bool {
- // 移除所有分隔符并转为大写
- cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
- if len(cleanCode) != BackupCodeLength {
- return false
- }
-
- // 检查字符是否合法
- for _, char := range cleanCode {
- if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
- return false
- }
- }
-
- return true
-}
-
-// NormalizeBackupCode 标准化备用码格式
-func NormalizeBackupCode(code string) string {
- cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
- if len(cleanCode) == BackupCodeLength {
- return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
- }
- return code
-}
-
-// HashBackupCode 对备用码进行哈希
-func HashBackupCode(code string) (string, error) {
- normalizedCode := NormalizeBackupCode(code)
- return Password2Hash(normalizedCode)
-}
-
-// Get2FAIssuer 获取2FA发行者名称
-func Get2FAIssuer() string {
- return SystemName
-}
-
-// getEnvOrDefault 获取环境变量或默认值
-func getEnvOrDefault(key, defaultValue string) string {
- if value, exists := os.LookupEnv(key); exists {
- return value
- }
- return defaultValue
-}
-
-// ValidateNumericCode 验证数字验证码格式
-func ValidateNumericCode(code string) (string, error) {
- // 移除空格
- code = strings.ReplaceAll(code, " ", "")
-
- if len(code) != 6 {
- return "", fmt.Errorf("验证码必须是6位数字")
- }
-
- // 检查是否为纯数字
- if _, err := strconv.Atoi(code); err != nil {
- return "", fmt.Errorf("验证码只能包含数字")
- }
-
- return code, nil
-}
-
-// GenerateQRCodeData 生成二维码数据
-func GenerateQRCodeData(secret, username string) string {
- issuer := Get2FAIssuer()
- accountName := fmt.Sprintf("%s (%s)", username, issuer)
- return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
- issuer, accountName, secret, issuer)
-}
diff --git a/new-api/common/utils.go b/new-api/common/utils.go
deleted file mode 100644
index 08a90bd4bea129ad63079d098bd992968d01c23e..0000000000000000000000000000000000000000
--- a/new-api/common/utils.go
+++ /dev/null
@@ -1,384 +0,0 @@
-package common
-
-import (
- "bytes"
- "context"
- crand "crypto/rand"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "html/template"
- "io"
- "log"
- "math/big"
- "math/rand"
- "net"
- "net/url"
- "os"
- "os/exec"
- "runtime"
- "strconv"
- "strings"
- "time"
-
- "github.com/google/uuid"
- "github.com/pkg/errors"
-)
-
-func OpenBrowser(url string) {
- var err error
-
- switch runtime.GOOS {
- case "linux":
- err = exec.Command("xdg-open", url).Start()
- case "windows":
- err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
- case "darwin":
- err = exec.Command("open", url).Start()
- }
- if err != nil {
- log.Println(err)
- }
-}
-
-func GetIp() (ip string) {
- ips, err := net.InterfaceAddrs()
- if err != nil {
- log.Println(err)
- return ip
- }
-
- for _, a := range ips {
- if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
- if ipNet.IP.To4() != nil {
- ip = ipNet.IP.String()
- if strings.HasPrefix(ip, "10") {
- return
- }
- if strings.HasPrefix(ip, "172") {
- return
- }
- if strings.HasPrefix(ip, "192.168") {
- return
- }
- ip = ""
- }
- }
- }
- return
-}
-
-func GetNetworkIps() []string {
- var networkIps []string
- ips, err := net.InterfaceAddrs()
- if err != nil {
- log.Println(err)
- return networkIps
- }
-
- for _, a := range ips {
- if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
- if ipNet.IP.To4() != nil {
- ip := ipNet.IP.String()
- // Include common private network ranges
- if strings.HasPrefix(ip, "10.") ||
- strings.HasPrefix(ip, "172.") ||
- strings.HasPrefix(ip, "192.168.") {
- networkIps = append(networkIps, ip)
- }
- }
- }
- }
- return networkIps
-}
-
-// IsRunningInContainer detects if the application is running inside a container
-func IsRunningInContainer() bool {
- // Method 1: Check for .dockerenv file (Docker containers)
- if _, err := os.Stat("/.dockerenv"); err == nil {
- return true
- }
-
- // Method 2: Check cgroup for container indicators
- if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
- content := string(data)
- if strings.Contains(content, "docker") ||
- strings.Contains(content, "containerd") ||
- strings.Contains(content, "kubepods") ||
- strings.Contains(content, "/lxc/") {
- return true
- }
- }
-
- // Method 3: Check environment variables commonly set by container runtimes
- containerEnvVars := []string{
- "KUBERNETES_SERVICE_HOST",
- "DOCKER_CONTAINER",
- "container",
- }
-
- for _, envVar := range containerEnvVars {
- if os.Getenv(envVar) != "" {
- return true
- }
- }
-
- // Method 4: Check if init process is not the traditional init
- if data, err := os.ReadFile("/proc/1/comm"); err == nil {
- comm := strings.TrimSpace(string(data))
- // In containers, process 1 is often not "init" or "systemd"
- if comm != "init" && comm != "systemd" {
- // Additional check: if it's a common container entrypoint
- if strings.Contains(comm, "docker") ||
- strings.Contains(comm, "containerd") ||
- strings.Contains(comm, "runc") {
- return true
- }
- }
- }
-
- return false
-}
-
-var sizeKB = 1024
-var sizeMB = sizeKB * 1024
-var sizeGB = sizeMB * 1024
-
-func Bytes2Size(num int64) string {
- numStr := ""
- unit := "B"
- if num/int64(sizeGB) > 1 {
- numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
- unit = "GB"
- } else if num/int64(sizeMB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
- unit = "MB"
- } else if num/int64(sizeKB) > 1 {
- numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
- unit = "KB"
- } else {
- numStr = fmt.Sprintf("%d", num)
- }
- return numStr + " " + unit
-}
-
-func Seconds2Time(num int) (time string) {
- if num/31104000 > 0 {
- time += strconv.Itoa(num/31104000) + " 年 "
- num %= 31104000
- }
- if num/2592000 > 0 {
- time += strconv.Itoa(num/2592000) + " 个月 "
- num %= 2592000
- }
- if num/86400 > 0 {
- time += strconv.Itoa(num/86400) + " 天 "
- num %= 86400
- }
- if num/3600 > 0 {
- time += strconv.Itoa(num/3600) + " 小时 "
- num %= 3600
- }
- if num/60 > 0 {
- time += strconv.Itoa(num/60) + " 分钟 "
- num %= 60
- }
- time += strconv.Itoa(num) + " 秒"
- return
-}
-
-func Interface2String(inter interface{}) string {
- switch inter.(type) {
- case string:
- return inter.(string)
- case int:
- return fmt.Sprintf("%d", inter.(int))
- case float64:
- return fmt.Sprintf("%f", inter.(float64))
- case bool:
- if inter.(bool) {
- return "true"
- } else {
- return "false"
- }
- case nil:
- return ""
- }
- return fmt.Sprintf("%v", inter)
-}
-
-func UnescapeHTML(x string) interface{} {
- return template.HTML(x)
-}
-
-func IntMax(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
- }
-}
-
-func IsIP(s string) bool {
- ip := net.ParseIP(s)
- return ip != nil
-}
-
-func GetUUID() string {
- code := uuid.New().String()
- code = strings.Replace(code, "-", "", -1)
- return code
-}
-
-const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-
-func init() {
- rand.New(rand.NewSource(time.Now().UnixNano()))
-}
-
-func GenerateRandomCharsKey(length int) (string, error) {
- b := make([]byte, length)
- maxI := big.NewInt(int64(len(keyChars)))
-
- for i := range b {
- n, err := crand.Int(crand.Reader, maxI)
- if err != nil {
- return "", err
- }
- b[i] = keyChars[n.Int64()]
- }
-
- return string(b), nil
-}
-
-func GenerateRandomKey(length int) (string, error) {
- bytes := make([]byte, length*3/4) // 对于48位的输出,这里应该是36
- if _, err := crand.Read(bytes); err != nil {
- return "", err
- }
- return base64.StdEncoding.EncodeToString(bytes), nil
-}
-
-func GenerateKey() (string, error) {
- //rand.Seed(time.Now().UnixNano())
- return GenerateRandomCharsKey(48)
-}
-
-func GetRandomInt(max int) int {
- //rand.Seed(time.Now().UnixNano())
- return rand.Intn(max)
-}
-
-func GetTimestamp() int64 {
- return time.Now().Unix()
-}
-
-func GetTimeString() string {
- now := time.Now()
- return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
-}
-
-func Max(a int, b int) int {
- if a >= b {
- return a
- } else {
- return b
- }
-}
-
-func MessageWithRequestId(message string, id string) string {
- return fmt.Sprintf("%s (request id: %s)", message, id)
-}
-
-func RandomSleep() {
- // Sleep for 0-3000 ms
- time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
-}
-
-func GetPointer[T any](v T) *T {
- return &v
-}
-
-func Any2Type[T any](data any) (T, error) {
- var zero T
- bytes, err := json.Marshal(data)
- if err != nil {
- return zero, err
- }
- var res T
- err = json.Unmarshal(bytes, &res)
- if err != nil {
- return zero, err
- }
- return res, nil
-}
-
-// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
-func SaveTmpFile(filename string, data io.Reader) (string, error) {
- f, err := os.CreateTemp(os.TempDir(), filename)
- if err != nil {
- return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
- }
- defer f.Close()
-
- _, err = io.Copy(f, data)
- if err != nil {
- return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
- }
-
- return f.Name(), nil
-}
-
-// GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
- // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
- c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
- output, err := c.Output()
- if err != nil {
- return 0, errors.Wrap(err, "failed to get audio duration")
- }
- durationStr := string(bytes.TrimSpace(output))
- if durationStr == "N/A" {
- // Create a temporary output file name
- tmpFp, err := os.CreateTemp("", "audio-*"+ext)
- if err != nil {
- return 0, errors.Wrap(err, "failed to create temporary file")
- }
- tmpName := tmpFp.Name()
- // Close immediately so ffmpeg can open the file on Windows.
- _ = tmpFp.Close()
- defer os.Remove(tmpName)
-
- // ffmpeg -y -i filename -vcodec copy -acodec copy
- ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
- if err := ffmpegCmd.Run(); err != nil {
- return 0, errors.Wrap(err, "failed to run ffmpeg")
- }
-
- // Recalculate the duration of the new file
- c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
- output, err := c.Output()
- if err != nil {
- return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
- }
- durationStr = string(bytes.TrimSpace(output))
- }
- return strconv.ParseFloat(durationStr, 64)
-}
-
-// BuildURL concatenates base and endpoint, returns the complete url string
-func BuildURL(base string, endpoint string) string {
- u, err := url.Parse(base)
- if err != nil {
- return base + endpoint
- }
- end := endpoint
- if end == "" {
- end = "/"
- }
- ref, err := url.Parse(end)
- if err != nil {
- return base + endpoint
- }
- return u.ResolveReference(ref).String()
-}
diff --git a/new-api/common/validate.go b/new-api/common/validate.go
deleted file mode 100644
index 4e1888508bb8d63556bac2fff48e58203012d361..0000000000000000000000000000000000000000
--- a/new-api/common/validate.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package common
-
-import "github.com/go-playground/validator/v10"
-
-var Validate *validator.Validate
-
-func init() {
- Validate = validator.New()
-}
diff --git a/new-api/common/verification.go b/new-api/common/verification.go
deleted file mode 100644
index 2c1059aa16ca6cda95e386b79a60a3edbc90c48f..0000000000000000000000000000000000000000
--- a/new-api/common/verification.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package common
-
-import (
- "github.com/google/uuid"
- "strings"
- "sync"
- "time"
-)
-
-type verificationValue struct {
- code string
- time time.Time
-}
-
-const (
- EmailVerificationPurpose = "v"
- PasswordResetPurpose = "r"
-)
-
-var verificationMutex sync.Mutex
-var verificationMap map[string]verificationValue
-var verificationMapMaxSize = 10
-var VerificationValidMinutes = 10
-
-func GenerateVerificationCode(length int) string {
- code := uuid.New().String()
- code = strings.Replace(code, "-", "", -1)
- if length == 0 {
- return code
- }
- return code[:length]
-}
-
-func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
- verificationMutex.Lock()
- defer verificationMutex.Unlock()
- verificationMap[purpose+key] = verificationValue{
- code: code,
- time: time.Now(),
- }
- if len(verificationMap) > verificationMapMaxSize {
- removeExpiredPairs()
- }
-}
-
-func VerifyCodeWithKey(key string, code string, purpose string) bool {
- verificationMutex.Lock()
- defer verificationMutex.Unlock()
- value, okay := verificationMap[purpose+key]
- now := time.Now()
- if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
- return false
- }
- return code == value.code
-}
-
-func DeleteKey(key string, purpose string) {
- verificationMutex.Lock()
- defer verificationMutex.Unlock()
- delete(verificationMap, purpose+key)
-}
-
-// no lock inside, so the caller must lock the verificationMap before calling!
-func removeExpiredPairs() {
- now := time.Now()
- for key := range verificationMap {
- if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
- delete(verificationMap, key)
- }
- }
-}
-
-func init() {
- verificationMutex.Lock()
- defer verificationMutex.Unlock()
- verificationMap = make(map[string]verificationValue)
-}
diff --git a/new-api/constant/README.md b/new-api/constant/README.md
deleted file mode 100644
index 963cb439ca007a4637f9210245c8fbbcc1f447c4..0000000000000000000000000000000000000000
--- a/new-api/constant/README.md
+++ /dev/null
@@ -1,26 +0,0 @@
-# constant 包 (`/constant`)
-
-该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
-
-## 当前文件
-
-| 文件 | 说明 |
-|----------------------|---------------------------------------------------------------------|
-| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
-| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
-| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
-| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
-| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
-| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
-| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
-| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
-| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
-| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
-
-## 使用约定
-
-1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
-2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
-3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
-
-> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
\ No newline at end of file
diff --git a/new-api/constant/api_type.go b/new-api/constant/api_type.go
deleted file mode 100644
index 4f517fcf351e90eb300833fa8ae750fbd03bf502..0000000000000000000000000000000000000000
--- a/new-api/constant/api_type.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package constant
-
-const (
- APITypeOpenAI = iota
- APITypeAnthropic
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
- APITypeZhipuV4
- APITypeOllama
- APITypePerplexity
- APITypeAws
- APITypeCohere
- APITypeDify
- APITypeJina
- APITypeCloudflare
- APITypeSiliconFlow
- APITypeVertexAi
- APITypeMistral
- APITypeDeepSeek
- APITypeMokaAI
- APITypeVolcEngine
- APITypeBaiduV2
- APITypeOpenRouter
- APITypeXinference
- APITypeXai
- APITypeCoze
- APITypeJimeng
- APITypeMoonshot
- APITypeSubmodel
- APITypeDummy // this one is only for count, do not add any channel after this
-)
diff --git a/new-api/constant/azure.go b/new-api/constant/azure.go
deleted file mode 100644
index 20dc7083d27b124ea3d26d6d0ceae2e06119a883..0000000000000000000000000000000000000000
--- a/new-api/constant/azure.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package constant
-
-import "time"
-
-var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()
diff --git a/new-api/constant/cache_key.go b/new-api/constant/cache_key.go
deleted file mode 100644
index 5f274406ed37f3e31c4cf613c187e6fefd308d25..0000000000000000000000000000000000000000
--- a/new-api/constant/cache_key.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package constant
-
-// Cache keys
-const (
- UserGroupKeyFmt = "user_group:%d"
- UserQuotaKeyFmt = "user_quota:%d"
- UserEnabledKeyFmt = "user_enabled:%d"
- UserUsernameKeyFmt = "user_name:%d"
-)
-
-const (
- TokenFiledRemainQuota = "RemainQuota"
- TokenFieldGroup = "Group"
-)
diff --git a/new-api/constant/channel.go b/new-api/constant/channel.go
deleted file mode 100644
index 68ad8960d4677ce0b2ccf5253af2ac80f47fd42e..0000000000000000000000000000000000000000
--- a/new-api/constant/channel.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package constant
-
-const (
- ChannelTypeUnknown = 0
- ChannelTypeOpenAI = 1
- ChannelTypeMidjourney = 2
- ChannelTypeAzure = 3
- ChannelTypeOllama = 4
- ChannelTypeMidjourneyPlus = 5
- ChannelTypeOpenAIMax = 6
- ChannelTypeOhMyGPT = 7
- ChannelTypeCustom = 8
- ChannelTypeAILS = 9
- ChannelTypeAIProxy = 10
- ChannelTypePaLM = 11
- ChannelTypeAPI2GPT = 12
- ChannelTypeAIGC2D = 13
- ChannelTypeAnthropic = 14
- ChannelTypeBaidu = 15
- ChannelTypeZhipu = 16
- ChannelTypeAli = 17
- ChannelTypeXunfei = 18
- ChannelType360 = 19
- ChannelTypeOpenRouter = 20
- ChannelTypeAIProxyLibrary = 21
- ChannelTypeFastGPT = 22
- ChannelTypeTencent = 23
- ChannelTypeGemini = 24
- ChannelTypeMoonshot = 25
- ChannelTypeZhipu_v4 = 26
- ChannelTypePerplexity = 27
- ChannelTypeLingYiWanWu = 31
- ChannelTypeAws = 33
- ChannelTypeCohere = 34
- ChannelTypeMiniMax = 35
- ChannelTypeSunoAPI = 36
- ChannelTypeDify = 37
- ChannelTypeJina = 38
- ChannelCloudflare = 39
- ChannelTypeSiliconFlow = 40
- ChannelTypeVertexAi = 41
- ChannelTypeMistral = 42
- ChannelTypeDeepSeek = 43
- ChannelTypeMokaAI = 44
- ChannelTypeVolcEngine = 45
- ChannelTypeBaiduV2 = 46
- ChannelTypeXinference = 47
- ChannelTypeXai = 48
- ChannelTypeCoze = 49
- ChannelTypeKling = 50
- ChannelTypeJimeng = 51
- ChannelTypeVidu = 52
- ChannelTypeSubmodel = 53
- ChannelTypeDummy // this one is only for count, do not add any channel after this
-
-
-)
-
-var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "http://localhost:11434", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://api.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.tencentcloudapi.com", //23
- "https://generativelanguage.googleapis.com", //24
- "https://api.moonshot.cn", //25
- "https://open.bigmodel.cn", //26
- "https://api.perplexity.ai", //27
- "", //28
- "", //29
- "", //30
- "https://api.lingyiwanwu.com", //31
- "", //32
- "", //33
- "https://api.cohere.ai", //34
- "https://api.minimax.chat", //35
- "", //36
- "https://api.dify.ai", //37
- "https://api.jina.ai", //38
- "https://api.cloudflare.com", //39
- "https://api.siliconflow.cn", //40
- "", //41
- "https://api.mistral.ai", //42
- "https://api.deepseek.com", //43
- "https://api.moka.ai", //44
- "https://ark.cn-beijing.volces.com", //45
- "https://qianfan.baidubce.com", //46
- "", //47
- "https://api.x.ai", //48
- "https://api.coze.cn", //49
- "https://api.klingai.com", //50
- "https://visual.volcengineapi.com", //51
- "https://api.vidu.cn", //52
- "https://llm.submodel.ai", //53
-}
diff --git a/new-api/constant/context_key.go b/new-api/constant/context_key.go
deleted file mode 100644
index 7d766c850a9b4e7bf665310708aeab653e6fec5a..0000000000000000000000000000000000000000
--- a/new-api/constant/context_key.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package constant
-
-type ContextKey string
-
-const (
- ContextKeyTokenCountMeta ContextKey = "token_count_meta"
- ContextKeyPromptTokens ContextKey = "prompt_tokens"
-
- ContextKeyOriginalModel ContextKey = "original_model"
- ContextKeyRequestStartTime ContextKey = "request_start_time"
-
- /* token related keys */
- ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
- ContextKeyTokenKey ContextKey = "token_key"
- ContextKeyTokenId ContextKey = "token_id"
- ContextKeyTokenGroup ContextKey = "token_group"
- ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
- ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
- ContextKeyTokenModelLimit ContextKey = "token_model_limit"
-
- /* channel related keys */
- ContextKeyChannelId ContextKey = "channel_id"
- ContextKeyChannelName ContextKey = "channel_name"
- ContextKeyChannelCreateTime ContextKey = "channel_create_time"
- ContextKeyChannelBaseUrl ContextKey = "base_url"
- ContextKeyChannelType ContextKey = "channel_type"
- ContextKeyChannelSetting ContextKey = "channel_setting"
- ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
- ContextKeyChannelParamOverride ContextKey = "param_override"
- ContextKeyChannelHeaderOverride ContextKey = "header_override"
- ContextKeyChannelOrganization ContextKey = "channel_organization"
- ContextKeyChannelAutoBan ContextKey = "auto_ban"
- ContextKeyChannelModelMapping ContextKey = "model_mapping"
- ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
- ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
- ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
- ContextKeyChannelKey ContextKey = "channel_key"
-
- /* user related keys */
- ContextKeyUserId ContextKey = "id"
- ContextKeyUserSetting ContextKey = "user_setting"
- ContextKeyUserQuota ContextKey = "user_quota"
- ContextKeyUserStatus ContextKey = "user_status"
- ContextKeyUserEmail ContextKey = "user_email"
- ContextKeyUserGroup ContextKey = "user_group"
- ContextKeyUsingGroup ContextKey = "group"
- ContextKeyUserName ContextKey = "username"
-
- ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
-)
diff --git a/new-api/constant/endpoint_type.go b/new-api/constant/endpoint_type.go
deleted file mode 100644
index 67c0ebb4e115887ea3dac25f151e675806cf2c6b..0000000000000000000000000000000000000000
--- a/new-api/constant/endpoint_type.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package constant
-
-type EndpointType string
-
-const (
- EndpointTypeOpenAI EndpointType = "openai"
- EndpointTypeOpenAIResponse EndpointType = "openai-response"
- EndpointTypeAnthropic EndpointType = "anthropic"
- EndpointTypeGemini EndpointType = "gemini"
- EndpointTypeJinaRerank EndpointType = "jina-rerank"
- EndpointTypeImageGeneration EndpointType = "image-generation"
- EndpointTypeEmbeddings EndpointType = "embeddings"
- //EndpointTypeMidjourney EndpointType = "midjourney-proxy"
- //EndpointTypeSuno EndpointType = "suno-proxy"
- //EndpointTypeKling EndpointType = "kling"
- //EndpointTypeJimeng EndpointType = "jimeng"
-)
diff --git a/new-api/constant/env.go b/new-api/constant/env.go
deleted file mode 100644
index 20c44b2168b683cb13e291c3c6cf7a3d03d25509..0000000000000000000000000000000000000000
--- a/new-api/constant/env.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package constant
-
-var StreamingTimeout int
-var DifyDebug bool
-var MaxFileDownloadMB int
-var ForceStreamOption bool
-var GetMediaToken bool
-var GetMediaTokenNotStream bool
-var UpdateTask bool
-var AzureDefaultAPIVersion string
-var GeminiVisionMaxImageNum int
-var NotifyLimitCount int
-var NotificationLimitDurationMinute int
-var GenerateDefaultToken bool
-var ErrorLogEnabled bool
diff --git a/new-api/constant/finish_reason.go b/new-api/constant/finish_reason.go
deleted file mode 100644
index 72047538ec6e72e5097555463d2fa587f6f18728..0000000000000000000000000000000000000000
--- a/new-api/constant/finish_reason.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package constant
-
-var (
- FinishReasonStop = "stop"
- FinishReasonToolCalls = "tool_calls"
- FinishReasonLength = "length"
- FinishReasonFunctionCall = "function_call"
- FinishReasonContentFilter = "content_filter"
-)
diff --git a/new-api/constant/midjourney.go b/new-api/constant/midjourney.go
deleted file mode 100644
index 891499ae356d2e123dd8a3420357d5d71e0e698b..0000000000000000000000000000000000000000
--- a/new-api/constant/midjourney.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package constant
-
-const (
- MjErrorUnknown = 5
- MjRequestError = 4
-)
-
-const (
- MjActionImagine = "IMAGINE"
- MjActionDescribe = "DESCRIBE"
- MjActionBlend = "BLEND"
- MjActionUpscale = "UPSCALE"
- MjActionVariation = "VARIATION"
- MjActionReRoll = "REROLL"
- MjActionInPaint = "INPAINT"
- MjActionModal = "MODAL"
- MjActionZoom = "ZOOM"
- MjActionCustomZoom = "CUSTOM_ZOOM"
- MjActionShorten = "SHORTEN"
- MjActionHighVariation = "HIGH_VARIATION"
- MjActionLowVariation = "LOW_VARIATION"
- MjActionPan = "PAN"
- MjActionSwapFace = "SWAP_FACE"
- MjActionUpload = "UPLOAD"
- MjActionVideo = "VIDEO"
- MjActionEdits = "EDITS"
-)
-
-var MidjourneyModel2Action = map[string]string{
- "mj_imagine": MjActionImagine,
- "mj_describe": MjActionDescribe,
- "mj_blend": MjActionBlend,
- "mj_upscale": MjActionUpscale,
- "mj_variation": MjActionVariation,
- "mj_reroll": MjActionReRoll,
- "mj_modal": MjActionModal,
- "mj_inpaint": MjActionInPaint,
- "mj_zoom": MjActionZoom,
- "mj_custom_zoom": MjActionCustomZoom,
- "mj_shorten": MjActionShorten,
- "mj_high_variation": MjActionHighVariation,
- "mj_low_variation": MjActionLowVariation,
- "mj_pan": MjActionPan,
- "swap_face": MjActionSwapFace,
- "mj_upload": MjActionUpload,
- "mj_video": MjActionVideo,
- "mj_edits": MjActionEdits,
-}
diff --git a/new-api/constant/multi_key_mode.go b/new-api/constant/multi_key_mode.go
deleted file mode 100644
index 8419698c16a6ebc3d37b6db02d18c2ae555d70ff..0000000000000000000000000000000000000000
--- a/new-api/constant/multi_key_mode.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package constant
-
-type MultiKeyMode string
-
-const (
- MultiKeyModeRandom MultiKeyMode = "random" // 随机
- MultiKeyModePolling MultiKeyMode = "polling" // 轮询
-)
diff --git a/new-api/constant/setup.go b/new-api/constant/setup.go
deleted file mode 100644
index 5de6e789805bf1c807a91edc0cd4ed2169516a87..0000000000000000000000000000000000000000
--- a/new-api/constant/setup.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package constant
-
-var Setup = false
diff --git a/new-api/constant/task.go b/new-api/constant/task.go
deleted file mode 100644
index c0fc45376847ed56be4a60af0deff1d0e1394f1b..0000000000000000000000000000000000000000
--- a/new-api/constant/task.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package constant
-
-type TaskPlatform string
-
-const (
- TaskPlatformSuno TaskPlatform = "suno"
- TaskPlatformMidjourney = "mj"
-)
-
-const (
- SunoActionMusic = "MUSIC"
- SunoActionLyrics = "LYRICS"
-
- TaskActionGenerate = "generate"
- TaskActionTextGenerate = "textGenerate"
- TaskActionFirstTailGenerate = "firstTailGenerate"
- TaskActionReferenceGenerate = "referenceGenerate"
-)
-
-var SunoModel2Action = map[string]string{
- "suno_music": SunoActionMusic,
- "suno_lyrics": SunoActionLyrics,
-}
diff --git a/new-api/controller/billing.go b/new-api/controller/billing.go
deleted file mode 100644
index 24c667cdb3b4dd0c4e4cb0a658177a20610fe4ea..0000000000000000000000000000000000000000
--- a/new-api/controller/billing.go
+++ /dev/null
@@ -1,92 +0,0 @@
-package controller
-
-import (
- "github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/dto"
- "one-api/model"
-)
-
-func GetSubscription(c *gin.Context) {
- var remainQuota int
- var usedQuota int
- var err error
- var token *model.Token
- var expiredTime int64
- if common.DisplayTokenStatEnabled {
- tokenId := c.GetInt("token_id")
- token, err = model.GetTokenById(tokenId)
- expiredTime = token.ExpiredTime
- remainQuota = token.RemainQuota
- usedQuota = token.UsedQuota
- } else {
- userId := c.GetInt("id")
- remainQuota, err = model.GetUserQuota(userId, false)
- usedQuota, err = model.GetUserUsedQuota(userId)
- }
- if expiredTime <= 0 {
- expiredTime = 0
- }
- if err != nil {
- openAIError := dto.OpenAIError{
- Message: err.Error(),
- Type: "upstream_error",
- }
- c.JSON(200, gin.H{
- "error": openAIError,
- })
- return
- }
- quota := remainQuota + usedQuota
- amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
- }
- if token != nil && token.UnlimitedQuota {
- amount = 100000000
- }
- subscription := OpenAISubscriptionResponse{
- Object: "billing_subscription",
- HasPaymentMethod: true,
- SoftLimitUSD: amount,
- HardLimitUSD: amount,
- SystemHardLimitUSD: amount,
- AccessUntil: expiredTime,
- }
- c.JSON(200, subscription)
- return
-}
-
-func GetUsage(c *gin.Context) {
- var quota int
- var err error
- var token *model.Token
- if common.DisplayTokenStatEnabled {
- tokenId := c.GetInt("token_id")
- token, err = model.GetTokenById(tokenId)
- quota = token.UsedQuota
- } else {
- userId := c.GetInt("id")
- quota, err = model.GetUserUsedQuota(userId)
- }
- if err != nil {
- openAIError := dto.OpenAIError{
- Message: err.Error(),
- Type: "new_api_error",
- }
- c.JSON(200, gin.H{
- "error": openAIError,
- })
- return
- }
- amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
- }
- usage := OpenAIUsageResponse{
- Object: "list",
- TotalUsage: amount * 100,
- }
- c.JSON(200, usage)
- return
-}
diff --git a/new-api/controller/channel-billing.go b/new-api/controller/channel-billing.go
deleted file mode 100644
index c4a6ef1253c219e9d3cba34a7ceac161f00a7813..0000000000000000000000000000000000000000
--- a/new-api/controller/channel-billing.go
+++ /dev/null
@@ -1,496 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/model"
- "one-api/service"
- "one-api/setting/operation_setting"
- "one-api/types"
- "strconv"
- "time"
-
- "github.com/shopspring/decimal"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://github.com/songquanpeng/one-api/issues/79
-
-type OpenAISubscriptionResponse struct {
- Object string `json:"object"`
- HasPaymentMethod bool `json:"has_payment_method"`
- SoftLimitUSD float64 `json:"soft_limit_usd"`
- HardLimitUSD float64 `json:"hard_limit_usd"`
- SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
- AccessUntil int64 `json:"access_until"`
-}
-
-type OpenAIUsageDailyCost struct {
- Timestamp float64 `json:"timestamp"`
- LineItems []struct {
- Name string `json:"name"`
- Cost float64 `json:"cost"`
- }
-}
-
-type OpenAICreditGrants struct {
- Object string `json:"object"`
- TotalGranted float64 `json:"total_granted"`
- TotalUsed float64 `json:"total_used"`
- TotalAvailable float64 `json:"total_available"`
-}
-
-type OpenAIUsageResponse struct {
- Object string `json:"object"`
- //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
- TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
-}
-
-type OpenAISBUsageResponse struct {
- Msg string `json:"msg"`
- Data *struct {
- Credit string `json:"credit"`
- } `json:"data"`
-}
-
-type AIProxyUserOverviewResponse struct {
- Success bool `json:"success"`
- Message string `json:"message"`
- ErrorCode int `json:"error_code"`
- Data struct {
- TotalPoints float64 `json:"totalPoints"`
- } `json:"data"`
-}
-
-type API2GPTUsageResponse struct {
- Object string `json:"object"`
- TotalGranted float64 `json:"total_granted"`
- TotalUsed float64 `json:"total_used"`
- TotalRemaining float64 `json:"total_remaining"`
-}
-
-type APGC2DGPTUsageResponse struct {
- //Grants interface{} `json:"grants"`
- Object string `json:"object"`
- TotalAvailable float64 `json:"total_available"`
- TotalGranted float64 `json:"total_granted"`
- TotalUsed float64 `json:"total_used"`
-}
-
-type SiliconFlowUsageResponse struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status bool `json:"status"`
- Data struct {
- ID string `json:"id"`
- Name string `json:"name"`
- Image string `json:"image"`
- Email string `json:"email"`
- IsAdmin bool `json:"isAdmin"`
- Balance string `json:"balance"`
- Status string `json:"status"`
- Introduction string `json:"introduction"`
- Role string `json:"role"`
- ChargeBalance string `json:"chargeBalance"`
- TotalBalance string `json:"totalBalance"`
- Category string `json:"category"`
- } `json:"data"`
-}
-
-type DeepSeekUsageResponse struct {
- IsAvailable bool `json:"is_available"`
- BalanceInfos []struct {
- Currency string `json:"currency"`
- TotalBalance string `json:"total_balance"`
- GrantedBalance string `json:"granted_balance"`
- ToppedUpBalance string `json:"topped_up_balance"`
- } `json:"balance_infos"`
-}
-
-type OpenRouterCreditResponse struct {
- Data struct {
- TotalCredits float64 `json:"total_credits"`
- TotalUsage float64 `json:"total_usage"`
- } `json:"data"`
-}
-
-// GetAuthHeader get auth header
-func GetAuthHeader(token string) http.Header {
- h := http.Header{}
- h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
- return h
-}
-
-func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
- req, err := http.NewRequest(method, url, nil)
- if err != nil {
- return nil, err
- }
- for k := range headers {
- req.Header.Add(k, headers.Get(k))
- }
- client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
- if err != nil {
- return nil, err
- }
- res, err := client.Do(req)
- if err != nil {
- return nil, err
- }
- if res.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("status code: %d", res.StatusCode)
- }
- body, err := io.ReadAll(res.Body)
- if err != nil {
- return nil, err
- }
- err = res.Body.Close()
- if err != nil {
- return nil, err
- }
- return body, nil
-}
-
-func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
- url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
-
- if err != nil {
- return 0, err
- }
- response := OpenAICreditGrants{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(response.TotalAvailable)
- return response.TotalAvailable, nil
-}
-
-func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
- url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- response := OpenAISBUsageResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- if response.Data == nil {
- return 0, errors.New(response.Msg)
- }
- balance, err := strconv.ParseFloat(response.Data.Credit, 64)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(balance)
- return balance, nil
-}
-
-func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
- url := "https://aiproxy.io/api/report/getUserOverview"
- headers := http.Header{}
- headers.Add("Api-Key", channel.Key)
- body, err := GetResponseBody("GET", url, channel, headers)
- if err != nil {
- return 0, err
- }
- response := AIProxyUserOverviewResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- if !response.Success {
- return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
- }
- channel.UpdateBalance(response.Data.TotalPoints)
- return response.Data.TotalPoints, nil
-}
-
-func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
- url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
-
- if err != nil {
- return 0, err
- }
- response := API2GPTUsageResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(response.TotalRemaining)
- return response.TotalRemaining, nil
-}
-
-func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
- url := "https://api.siliconflow.cn/v1/user/info"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- response := SiliconFlowUsageResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- if response.Code != 20000 {
- return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
- }
- balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(balance)
- return balance, nil
-}
-
-func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
- url := "https://api.deepseek.com/user/balance"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- response := DeepSeekUsageResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- index := -1
- for i, balanceInfo := range response.BalanceInfos {
- if balanceInfo.Currency == "CNY" {
- index = i
- break
- }
- }
- if index == -1 {
- return 0, errors.New("currency CNY not found")
- }
- balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(balance)
- return balance, nil
-}
-
-func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
- url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- response := APGC2DGPTUsageResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- channel.UpdateBalance(response.TotalAvailable)
- return response.TotalAvailable, nil
-}
-
-func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
- url := "https://openrouter.ai/api/v1/credits"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- response := OpenRouterCreditResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- balance := response.Data.TotalCredits - response.Data.TotalUsage
- channel.UpdateBalance(balance)
- return balance, nil
-}
-
-func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
- url := "https://api.moonshot.cn/v1/users/me/balance"
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
-
- type MoonshotBalanceData struct {
- AvailableBalance float64 `json:"available_balance"`
- VoucherBalance float64 `json:"voucher_balance"`
- CashBalance float64 `json:"cash_balance"`
- }
-
- type MoonshotBalanceResponse struct {
- Code int `json:"code"`
- Data MoonshotBalanceData `json:"data"`
- Scode string `json:"scode"`
- Status bool `json:"status"`
- }
-
- response := MoonshotBalanceResponse{}
- err = json.Unmarshal(body, &response)
- if err != nil {
- return 0, err
- }
- if !response.Status || response.Code != 0 {
- return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
- }
- availableBalanceCny := response.Data.AvailableBalance
- availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
- channel.UpdateBalance(availableBalanceUsd)
- return availableBalanceUsd, nil
-}
-
-func updateChannelBalance(channel *model.Channel) (float64, error) {
- baseURL := constant.ChannelBaseURLs[channel.Type]
- if channel.GetBaseURL() == "" {
- channel.BaseURL = &baseURL
- }
- switch channel.Type {
- case constant.ChannelTypeOpenAI:
- if channel.GetBaseURL() != "" {
- baseURL = channel.GetBaseURL()
- }
- case constant.ChannelTypeAzure:
- return 0, errors.New("尚未实现")
- case constant.ChannelTypeCustom:
- baseURL = channel.GetBaseURL()
- //case common.ChannelTypeOpenAISB:
- // return updateChannelOpenAISBBalance(channel)
- case constant.ChannelTypeAIProxy:
- return updateChannelAIProxyBalance(channel)
- case constant.ChannelTypeAPI2GPT:
- return updateChannelAPI2GPTBalance(channel)
- case constant.ChannelTypeAIGC2D:
- return updateChannelAIGC2DBalance(channel)
- case constant.ChannelTypeSiliconFlow:
- return updateChannelSiliconFlowBalance(channel)
- case constant.ChannelTypeDeepSeek:
- return updateChannelDeepSeekBalance(channel)
- case constant.ChannelTypeOpenRouter:
- return updateChannelOpenRouterBalance(channel)
- case constant.ChannelTypeMoonshot:
- return updateChannelMoonshotBalance(channel)
- default:
- return 0, errors.New("尚未实现")
- }
- url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
-
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- subscription := OpenAISubscriptionResponse{}
- err = json.Unmarshal(body, &subscription)
- if err != nil {
- return 0, err
- }
- now := time.Now()
- startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
- endDate := now.Format("2006-01-02")
- if !subscription.HasPaymentMethod {
- startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
- }
- url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
- body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
- if err != nil {
- return 0, err
- }
- usage := OpenAIUsageResponse{}
- err = json.Unmarshal(body, &usage)
- if err != nil {
- return 0, err
- }
- balance := subscription.HardLimitUSD - usage.TotalUsage/100
- channel.UpdateBalance(balance)
- return balance, nil
-}
-
-func UpdateChannelBalance(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- channel, err := model.CacheGetChannel(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if channel.ChannelInfo.IsMultiKey {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "多密钥渠道不支持余额查询",
- })
- return
- }
- balance, err := updateChannelBalance(channel)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "balance": balance,
- })
-}
-
-func updateAllChannelsBalance() error {
- channels, err := model.GetAllChannels(0, 0, true, false)
- if err != nil {
- return err
- }
- for _, channel := range channels {
- if channel.Status != common.ChannelStatusEnabled {
- continue
- }
- if channel.ChannelInfo.IsMultiKey {
- continue // skip multi-key channels
- }
- // TODO: support Azure
- //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
- // continue
- //}
- balance, err := updateChannelBalance(channel)
- if err != nil {
- continue
- } else {
- // err is nil & balance <= 0 means quota is used up
- if balance <= 0 {
- service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
- }
- }
- time.Sleep(common.RequestInterval)
- }
- return nil
-}
-
-func UpdateAllChannelsBalance(c *gin.Context) {
- // TODO: make it async
- err := updateAllChannelsBalance()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func AutomaticallyUpdateChannels(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("updating all channels")
- _ = updateAllChannelsBalance()
- common.SysLog("channels update done")
- }
-}
diff --git a/new-api/controller/channel-test.go b/new-api/controller/channel-test.go
deleted file mode 100644
index 65261c9d85a87cebdae2d69eb82aafb41a3453ee..0000000000000000000000000000000000000000
--- a/new-api/controller/channel-test.go
+++ /dev/null
@@ -1,655 +0,0 @@
-package controller
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "math"
- "net/http"
- "net/http/httptest"
- "net/url"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/middleware"
- "one-api/model"
- "one-api/relay"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/operation_setting"
- "one-api/types"
- "strconv"
- "strings"
- "sync"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
-
- "github.com/gin-gonic/gin"
-)
-
-type testResult struct {
- context *gin.Context
- localErr error
- newAPIError *types.NewAPIError
-}
-
-func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
- tik := time.Now()
- if channel.Type == constant.ChannelTypeMidjourney {
- return testResult{
- localErr: errors.New("midjourney channel test is not supported"),
- newAPIError: nil,
- }
- }
- if channel.Type == constant.ChannelTypeMidjourneyPlus {
- return testResult{
- localErr: errors.New("midjourney plus channel test is not supported"),
- newAPIError: nil,
- }
- }
- if channel.Type == constant.ChannelTypeSunoAPI {
- return testResult{
- localErr: errors.New("suno channel test is not supported"),
- newAPIError: nil,
- }
- }
- if channel.Type == constant.ChannelTypeKling {
- return testResult{
- localErr: errors.New("kling channel test is not supported"),
- newAPIError: nil,
- }
- }
- if channel.Type == constant.ChannelTypeJimeng {
- return testResult{
- localErr: errors.New("jimeng channel test is not supported"),
- newAPIError: nil,
- }
- }
- if channel.Type == constant.ChannelTypeVidu {
- return testResult{
- localErr: errors.New("vidu channel test is not supported"),
- newAPIError: nil,
- }
- }
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
-
- requestPath := "/v1/chat/completions"
-
- // 如果指定了端点类型,使用指定的端点类型
- if endpointType != "" {
- if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
- requestPath = endpointInfo.Path
- }
- } else {
- // 如果没有指定端点类型,使用原有的自动检测逻辑
- // 先判断是否为 Embedding 模型
- if strings.Contains(strings.ToLower(testModel), "embedding") ||
- strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
- strings.Contains(testModel, "bge-") || // bge 系列模型
- strings.Contains(testModel, "embed") ||
- channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
- requestPath = "/v1/embeddings" // 修改请求路径
- }
-
- // VolcEngine 图像生成模型
- if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
- requestPath = "/v1/images/generations"
- }
- }
-
- c.Request = &http.Request{
- Method: "POST",
- URL: &url.URL{Path: requestPath}, // 使用动态路径
- Body: nil,
- Header: make(http.Header),
- }
-
- if testModel == "" {
- if channel.TestModel != nil && *channel.TestModel != "" {
- testModel = *channel.TestModel
- } else {
- if len(channel.GetModels()) > 0 {
- testModel = channel.GetModels()[0]
- } else {
- testModel = "gpt-4o-mini"
- }
- }
- }
-
- cache, err := model.GetUserCache(1)
- if err != nil {
- return testResult{
- localErr: err,
- newAPIError: nil,
- }
- }
- cache.WriteContext(c)
-
- //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
- c.Request.Header.Set("Content-Type", "application/json")
- c.Set("channel", channel.Type)
- c.Set("base_url", channel.GetBaseURL())
- group, _ := model.GetUserGroup(1, false)
- c.Set("group", group)
-
- newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
- if newAPIError != nil {
- return testResult{
- context: c,
- localErr: newAPIError,
- newAPIError: newAPIError,
- }
- }
-
- // Determine relay format based on endpoint type or request path
- var relayFormat types.RelayFormat
- if endpointType != "" {
- // 根据指定的端点类型设置 relayFormat
- switch constant.EndpointType(endpointType) {
- case constant.EndpointTypeOpenAI:
- relayFormat = types.RelayFormatOpenAI
- case constant.EndpointTypeOpenAIResponse:
- relayFormat = types.RelayFormatOpenAIResponses
- case constant.EndpointTypeAnthropic:
- relayFormat = types.RelayFormatClaude
- case constant.EndpointTypeGemini:
- relayFormat = types.RelayFormatGemini
- case constant.EndpointTypeJinaRerank:
- relayFormat = types.RelayFormatRerank
- case constant.EndpointTypeImageGeneration:
- relayFormat = types.RelayFormatOpenAIImage
- case constant.EndpointTypeEmbeddings:
- relayFormat = types.RelayFormatEmbedding
- default:
- relayFormat = types.RelayFormatOpenAI
- }
- } else {
- // 根据请求路径自动检测
- relayFormat = types.RelayFormatOpenAI
- if c.Request.URL.Path == "/v1/embeddings" {
- relayFormat = types.RelayFormatEmbedding
- }
- if c.Request.URL.Path == "/v1/images/generations" {
- relayFormat = types.RelayFormatOpenAIImage
- }
- if c.Request.URL.Path == "/v1/messages" {
- relayFormat = types.RelayFormatClaude
- }
- if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
- relayFormat = types.RelayFormatGemini
- }
- if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
- relayFormat = types.RelayFormatRerank
- }
- if c.Request.URL.Path == "/v1/responses" {
- relayFormat = types.RelayFormatOpenAIResponses
- }
- }
-
- request := buildTestRequest(testModel, endpointType)
-
- info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
-
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
- }
- }
-
- info.InitChannelMeta(c)
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
- }
- }
-
- testModel = info.UpstreamModelName
- // 更新请求中的模型名称
- request.SetModelName(testModel)
-
- apiType, _ := common.ChannelType2APIType(channel.Type)
- adaptor := relay.GetAdaptor(apiType)
- if adaptor == nil {
- return testResult{
- context: c,
- localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
- newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
- }
- }
-
- //// 创建一个用于日志的 info 副本,移除 ApiKey
- //logInfo := info
- //logInfo.ApiKey = ""
- common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
-
- priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
- }
- }
-
- adaptor.Init(info)
-
- var convertedRequest any
- // 根据 RelayMode 选择正确的转换函数
- switch info.RelayMode {
- case relayconstant.RelayModeEmbeddings:
- // Embedding 请求 - request 已经是正确的类型
- if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
- convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
- } else {
- return testResult{
- context: c,
- localErr: errors.New("invalid embedding request type"),
- newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
- }
- }
- case relayconstant.RelayModeImagesGenerations:
- // 图像生成请求 - request 已经是正确的类型
- if imageReq, ok := request.(*dto.ImageRequest); ok {
- convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
- } else {
- return testResult{
- context: c,
- localErr: errors.New("invalid image request type"),
- newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
- }
- }
- case relayconstant.RelayModeRerank:
- // Rerank 请求 - request 已经是正确的类型
- if rerankReq, ok := request.(*dto.RerankRequest); ok {
- convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
- } else {
- return testResult{
- context: c,
- localErr: errors.New("invalid rerank request type"),
- newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
- }
- }
- case relayconstant.RelayModeResponses:
- // Response 请求 - request 已经是正确的类型
- if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
- convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
- } else {
- return testResult{
- context: c,
- localErr: errors.New("invalid response request type"),
- newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
- }
- }
- default:
- // Chat/Completion 等其他请求类型
- if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
- convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
- } else {
- return testResult{
- context: c,
- localErr: errors.New("invalid general request type"),
- newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
- }
- }
- }
-
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
- }
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
- }
- }
- requestBody := bytes.NewBuffer(jsonData)
- c.Request.Body = io.NopCloser(requestBody)
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
- }
- }
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
- }
- }
- }
- usageA, respErr := adaptor.DoResponse(c, httpResp, info)
- if respErr != nil {
- return testResult{
- context: c,
- localErr: respErr,
- newAPIError: respErr,
- }
- }
- if usageA == nil {
- return testResult{
- context: c,
- localErr: errors.New("usage is nil"),
- newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
- }
- }
- usage := usageA.(*dto.Usage)
- result := w.Result()
- respBody, err := io.ReadAll(result.Body)
- if err != nil {
- return testResult{
- context: c,
- localErr: err,
- newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
- }
- }
- info.PromptTokens = usage.PromptTokens
-
- quota := 0
- if !priceData.UsePrice {
- quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
- quota = int(math.Round(float64(quota) * priceData.ModelRatio))
- if priceData.ModelRatio != 0 && quota <= 0 {
- quota = 1
- }
- } else {
- quota = int(priceData.ModelPrice * common.QuotaPerUnit)
- }
- tok := time.Now()
- milliseconds := tok.Sub(tik).Milliseconds()
- consumedTime := float64(milliseconds) / 1000.0
- other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
- usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
- model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
- ChannelId: channel.Id,
- PromptTokens: usage.PromptTokens,
- CompletionTokens: usage.CompletionTokens,
- ModelName: info.OriginModelName,
- TokenName: "模型测试",
- Quota: quota,
- Content: "模型测试",
- UseTimeSeconds: int(consumedTime),
- IsStream: info.IsStream,
- Group: info.UsingGroup,
- Other: other,
- })
- common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
- return testResult{
- context: c,
- localErr: nil,
- newAPIError: nil,
- }
-}
-
-func buildTestRequest(model string, endpointType string) dto.Request {
- // 根据端点类型构建不同的测试请求
- if endpointType != "" {
- switch constant.EndpointType(endpointType) {
- case constant.EndpointTypeEmbeddings:
- // 返回 EmbeddingRequest
- return &dto.EmbeddingRequest{
- Model: model,
- Input: []any{"hello world"},
- }
- case constant.EndpointTypeImageGeneration:
- // 返回 ImageRequest
- return &dto.ImageRequest{
- Model: model,
- Prompt: "a cute cat",
- N: 1,
- Size: "1024x1024",
- }
- case constant.EndpointTypeJinaRerank:
- // 返回 RerankRequest
- return &dto.RerankRequest{
- Model: model,
- Query: "What is Deep Learning?",
- Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
- TopN: 2,
- }
- case constant.EndpointTypeOpenAIResponse:
- // 返回 OpenAIResponsesRequest
- return &dto.OpenAIResponsesRequest{
- Model: model,
- Input: json.RawMessage("\"hi\""),
- }
- case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
- // 返回 GeneralOpenAIRequest
- maxTokens := uint(10)
- if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
- maxTokens = 3000
- }
- return &dto.GeneralOpenAIRequest{
- Model: model,
- Stream: false,
- Messages: []dto.Message{
- {
- Role: "user",
- Content: "hi",
- },
- },
- MaxTokens: maxTokens,
- }
- }
- }
-
- // 自动检测逻辑(保持原有行为)
- // 先判断是否为 Embedding 模型
- if strings.Contains(strings.ToLower(model), "embedding") ||
- strings.HasPrefix(model, "m3e") ||
- strings.Contains(model, "bge-") {
- // 返回 EmbeddingRequest
- return &dto.EmbeddingRequest{
- Model: model,
- Input: []any{"hello world"},
- }
- }
-
- // Chat/Completion 请求 - 返回 GeneralOpenAIRequest
- testRequest := &dto.GeneralOpenAIRequest{
- Model: model,
- Stream: false,
- Messages: []dto.Message{
- {
- Role: "user",
- Content: "hi",
- },
- },
- }
-
- if strings.HasPrefix(model, "o") {
- testRequest.MaxCompletionTokens = 10
- } else if strings.Contains(model, "thinking") {
- if !strings.Contains(model, "claude") {
- testRequest.MaxTokens = 50
- }
- } else if strings.Contains(model, "gemini") {
- testRequest.MaxTokens = 3000
- } else {
- testRequest.MaxTokens = 10
- }
-
- return testRequest
-}
-
-func TestChannel(c *gin.Context) {
- channelId, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- channel, err := model.CacheGetChannel(channelId)
- if err != nil {
- channel, err = model.GetChannelById(channelId, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- }
- //defer func() {
- // if channel.ChannelInfo.IsMultiKey {
- // go func() { _ = channel.SaveChannelInfo() }()
- // }
- //}()
- testModel := c.Query("model")
- endpointType := c.Query("endpoint_type")
- tik := time.Now()
- result := testChannel(channel, testModel, endpointType)
- if result.localErr != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": result.localErr.Error(),
- "time": 0.0,
- })
- return
- }
- tok := time.Now()
- milliseconds := tok.Sub(tik).Milliseconds()
- go channel.UpdateResponseTime(milliseconds)
- consumedTime := float64(milliseconds) / 1000.0
- if result.newAPIError != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": result.newAPIError.Error(),
- "time": consumedTime,
- })
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "time": consumedTime,
- })
-}
-
-var testAllChannelsLock sync.Mutex
-var testAllChannelsRunning bool = false
-
-func testAllChannels(notify bool) error {
-
- testAllChannelsLock.Lock()
- if testAllChannelsRunning {
- testAllChannelsLock.Unlock()
- return errors.New("测试已在运行中")
- }
- testAllChannelsRunning = true
- testAllChannelsLock.Unlock()
- channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
- if getChannelErr != nil {
- return getChannelErr
- }
- var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
- if disableThreshold == 0 {
- disableThreshold = 10000000 // a impossible value
- }
- gopool.Go(func() {
- // 使用 defer 确保无论如何都会重置运行状态,防止死锁
- defer func() {
- testAllChannelsLock.Lock()
- testAllChannelsRunning = false
- testAllChannelsLock.Unlock()
- }()
-
- for _, channel := range channels {
- isChannelEnabled := channel.Status == common.ChannelStatusEnabled
- tik := time.Now()
- result := testChannel(channel, "", "")
- tok := time.Now()
- milliseconds := tok.Sub(tik).Milliseconds()
-
- shouldBanChannel := false
- newAPIError := result.newAPIError
- // request error disables the channel
- if newAPIError != nil {
- shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
- }
-
- // 当错误检查通过,才检查响应时间
- if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
- if milliseconds > disableThreshold {
- err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
- newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
- shouldBanChannel = true
- }
- }
-
- // disable channel
- if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
- processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
- }
-
- // enable channel
- if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
- service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
- }
-
- channel.UpdateResponseTime(milliseconds)
- time.Sleep(common.RequestInterval)
- }
-
- if notify {
- service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
- }
- })
- return nil
-}
-
-func TestAllChannels(c *gin.Context) {
- err := testAllChannels(true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
-}
-
-var autoTestChannelsOnce sync.Once
-
-func AutomaticallyTestChannels() {
- autoTestChannelsOnce.Do(func() {
- for {
- if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
- time.Sleep(10 * time.Minute)
- continue
- }
- frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
- common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
- for {
- time.Sleep(time.Duration(frequency) * time.Minute)
- common.SysLog("automatically testing all channels")
- _ = testAllChannels(false)
- common.SysLog("automatically channel test finished")
- if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
- break
- }
- }
- }
- })
-}
diff --git a/new-api/controller/channel.go b/new-api/controller/channel.go
deleted file mode 100644
index d0a13c4874d415e02a4216417b7012a7d31e9b39..0000000000000000000000000000000000000000
--- a/new-api/controller/channel.go
+++ /dev/null
@@ -1,1567 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- "one-api/service"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type OpenAIModel struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- OwnedBy string `json:"owned_by"`
- Permission []struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- AllowCreateEngine bool `json:"allow_create_engine"`
- AllowSampling bool `json:"allow_sampling"`
- AllowLogprobs bool `json:"allow_logprobs"`
- AllowSearchIndices bool `json:"allow_search_indices"`
- AllowView bool `json:"allow_view"`
- AllowFineTuning bool `json:"allow_fine_tuning"`
- Organization string `json:"organization"`
- Group string `json:"group"`
- IsBlocking bool `json:"is_blocking"`
- } `json:"permission"`
- Root string `json:"root"`
- Parent string `json:"parent"`
-}
-
-type OpenAIModelsResponse struct {
- Data []OpenAIModel `json:"data"`
- Success bool `json:"success"`
-}
-
-func parseStatusFilter(statusParam string) int {
- switch strings.ToLower(statusParam) {
- case "enabled", "1":
- return common.ChannelStatusEnabled
- case "disabled", "0":
- return 0
- default:
- return -1
- }
-}
-
-func clearChannelInfo(channel *model.Channel) {
- if channel.ChannelInfo.IsMultiKey {
- channel.ChannelInfo.MultiKeyDisabledReason = nil
- channel.ChannelInfo.MultiKeyDisabledTime = nil
- }
-}
-
-func GetAllChannels(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- channelData := make([]*model.Channel, 0)
- idSort, _ := strconv.ParseBool(c.Query("id_sort"))
- enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
- statusParam := c.Query("status")
- // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
- statusFilter := parseStatusFilter(statusParam)
- // type filter
- typeStr := c.Query("type")
- typeFilter := -1
- if typeStr != "" {
- if t, err := strconv.Atoi(typeStr); err == nil {
- typeFilter = t
- }
- }
-
- var total int64
-
- if enableTagMode {
- tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- for _, tag := range tags {
- if tag == nil || *tag == "" {
- continue
- }
- tagChannels, err := model.GetChannelsByTag(*tag, idSort)
- if err != nil {
- continue
- }
- filtered := make([]*model.Channel, 0)
- for _, ch := range tagChannels {
- if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
- continue
- }
- if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
- continue
- }
- if typeFilter >= 0 && ch.Type != typeFilter {
- continue
- }
- filtered = append(filtered, ch)
- }
- channelData = append(channelData, filtered...)
- }
- total, _ = model.CountAllTags()
- } else {
- baseQuery := model.DB.Model(&model.Channel{})
- if typeFilter >= 0 {
- baseQuery = baseQuery.Where("type = ?", typeFilter)
- }
- if statusFilter == common.ChannelStatusEnabled {
- baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
- } else if statusFilter == 0 {
- baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
- }
-
- baseQuery.Count(&total)
-
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
-
- err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- }
-
- for _, datum := range channelData {
- clearChannelInfo(datum)
- }
-
- countQuery := model.DB.Model(&model.Channel{})
- if statusFilter == common.ChannelStatusEnabled {
- countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
- } else if statusFilter == 0 {
- countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
- }
- var results []struct {
- Type int64
- Count int64
- }
- _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
- typeCounts := make(map[int64]int64)
- for _, r := range results {
- typeCounts[r.Type] = r.Count
- }
- common.ApiSuccess(c, gin.H{
- "items": channelData,
- "total": total,
- "page": pageInfo.GetPage(),
- "page_size": pageInfo.GetPageSize(),
- "type_counts": typeCounts,
- })
- return
-}
-
-func FetchUpstreamModels(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- channel, err := model.GetChannelById(id, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- baseURL := constant.ChannelBaseURLs[channel.Type]
- if channel.GetBaseURL() != "" {
- baseURL = channel.GetBaseURL()
- }
-
- var url string
- switch channel.Type {
- case constant.ChannelTypeGemini:
- // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
- url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
- case constant.ChannelTypeAli:
- url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
- case constant.ChannelTypeZhipu_v4:
- url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
- default:
- url = fmt.Sprintf("%s/v1/models", baseURL)
- }
-
- // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
- var body []byte
- key := strings.Split(channel.Key, "\n")[0]
- if channel.Type == constant.ChannelTypeGemini {
- body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
- } else {
- body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
- }
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- var result OpenAIModelsResponse
- if err = json.Unmarshal(body, &result); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
- })
- return
- }
-
- var ids []string
- for _, model := range result.Data {
- id := model.ID
- if channel.Type == constant.ChannelTypeGemini {
- id = strings.TrimPrefix(id, "models/")
- }
- ids = append(ids, id)
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": ids,
- })
-}
-
-func FixChannelsAbilities(c *gin.Context) {
- success, fails, err := model.FixAbility()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "success": success,
- "fails": fails,
- },
- })
-}
-
-func SearchChannels(c *gin.Context) {
- keyword := c.Query("keyword")
- group := c.Query("group")
- modelKeyword := c.Query("model")
- statusParam := c.Query("status")
- statusFilter := parseStatusFilter(statusParam)
- idSort, _ := strconv.ParseBool(c.Query("id_sort"))
- enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
- channelData := make([]*model.Channel, 0)
- if enableTagMode {
- tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- for _, tag := range tags {
- if tag != nil && *tag != "" {
- tagChannel, err := model.GetChannelsByTag(*tag, idSort)
- if err == nil {
- channelData = append(channelData, tagChannel...)
- }
- }
- }
- } else {
- channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- channelData = channels
- }
-
- if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
- filtered := make([]*model.Channel, 0, len(channelData))
- for _, ch := range channelData {
- if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
- continue
- }
- if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
- continue
- }
- filtered = append(filtered, ch)
- }
- channelData = filtered
- }
-
- // calculate type counts for search results
- typeCounts := make(map[int64]int64)
- for _, channel := range channelData {
- typeCounts[int64(channel.Type)]++
- }
-
- typeParam := c.Query("type")
- typeFilter := -1
- if typeParam != "" {
- if tp, err := strconv.Atoi(typeParam); err == nil {
- typeFilter = tp
- }
- }
-
- if typeFilter >= 0 {
- filtered := make([]*model.Channel, 0, len(channelData))
- for _, ch := range channelData {
- if ch.Type == typeFilter {
- filtered = append(filtered, ch)
- }
- }
- channelData = filtered
- }
-
- page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
- if page < 1 {
- page = 1
- }
- if pageSize <= 0 {
- pageSize = 20
- }
-
- total := len(channelData)
- startIdx := (page - 1) * pageSize
- if startIdx > total {
- startIdx = total
- }
- endIdx := startIdx + pageSize
- if endIdx > total {
- endIdx = total
- }
-
- pagedData := channelData[startIdx:endIdx]
-
- for _, datum := range pagedData {
- clearChannelInfo(datum)
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "items": pagedData,
- "total": total,
- "type_counts": typeCounts,
- },
- })
- return
-}
-
-func GetChannel(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- channel, err := model.GetChannelById(id, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if channel != nil {
- clearChannelInfo(channel)
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": channel,
- })
- return
-}
-
-// GetChannelKey 获取渠道密钥(需要通过安全验证中间件)
-// 此函数依赖 SecureVerificationRequired 中间件,确保用户已通过安全验证
-func GetChannelKey(c *gin.Context) {
- userId := c.GetInt("id")
- channelId, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
- return
- }
-
- // 获取渠道信息(包含密钥)
- channel, err := model.GetChannelById(channelId, true)
- if err != nil {
- common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
- return
- }
-
- if channel == nil {
- common.ApiError(c, fmt.Errorf("渠道不存在"))
- return
- }
-
- // 记录操作日志
- model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
-
- // 返回渠道密钥
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "获取成功",
- "data": map[string]interface{}{
- "key": channel.Key,
- },
- })
-}
-
-// validateTwoFactorAuth 统一的2FA验证函数
-func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
- // 尝试验证TOTP
- if cleanCode, err := common.ValidateNumericCode(code); err == nil {
- if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
- return true
- }
- }
-
- // 尝试验证备用码
- if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
- return true
- }
-
- return false
-}
-
-// validateChannel 通用的渠道校验函数
-func validateChannel(channel *model.Channel, isAdd bool) error {
- // 校验 channel settings
- if err := channel.ValidateSettings(); err != nil {
- return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
- }
-
- // 如果是添加操作,检查 channel 和 key 是否为空
- if isAdd {
- if channel == nil || channel.Key == "" {
- return fmt.Errorf("channel cannot be empty")
- }
-
- // 检查模型名称长度是否超过 255
- for _, m := range channel.GetModels() {
- if len(m) > 255 {
- return fmt.Errorf("模型名称过长: %s", m)
- }
- }
- }
-
- // VertexAI 特殊校验
- if channel.Type == constant.ChannelTypeVertexAi {
- if channel.Other == "" {
- return fmt.Errorf("部署地区不能为空")
- }
-
- regionMap, err := common.StrToMap(channel.Other)
- if err != nil {
- return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
- }
-
- if regionMap["default"] == nil {
- return fmt.Errorf("部署地区必须包含default字段")
- }
- }
-
- return nil
-}
-
-type AddChannelRequest struct {
- Mode string `json:"mode"`
- MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
- BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
- Channel *model.Channel `json:"channel"`
-}
-
-func getVertexArrayKeys(keys string) ([]string, error) {
- if keys == "" {
- return nil, nil
- }
- var keyArray []interface{}
- err := common.Unmarshal([]byte(keys), &keyArray)
- if err != nil {
- return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
- }
- cleanKeys := make([]string, 0, len(keyArray))
- for _, key := range keyArray {
- var keyStr string
- switch v := key.(type) {
- case string:
- keyStr = strings.TrimSpace(v)
- default:
- bytes, err := json.Marshal(v)
- if err != nil {
- return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
- }
- keyStr = string(bytes)
- }
- if keyStr != "" {
- cleanKeys = append(cleanKeys, keyStr)
- }
- }
- if len(cleanKeys) == 0 {
- return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
- }
- return cleanKeys, nil
-}
-
-func AddChannel(c *gin.Context) {
- addChannelRequest := AddChannelRequest{}
- err := c.ShouldBindJSON(&addChannelRequest)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 使用统一的校验函数
- if err := validateChannel(addChannelRequest.Channel, true); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
- keys := make([]string, 0)
- switch addChannelRequest.Mode {
- case "multi_to_single":
- addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
- addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
- if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
- array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
- addChannelRequest.Channel.Key = strings.Join(array, "\n")
- } else {
- cleanKeys := make([]string, 0)
- for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
- if key == "" {
- continue
- }
- key = strings.TrimSpace(key)
- cleanKeys = append(cleanKeys, key)
- }
- addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
- addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
- }
- keys = []string{addChannelRequest.Channel.Key}
- case "batch":
- if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
- // multi json
- keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- keys = strings.Split(addChannelRequest.Channel.Key, "\n")
- }
- case "single":
- keys = []string{addChannelRequest.Channel.Key}
- default:
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "不支持的添加模式",
- })
- return
- }
-
- channels := make([]model.Channel, 0, len(keys))
- for _, key := range keys {
- if key == "" {
- continue
- }
- localChannel := addChannelRequest.Channel
- localChannel.Key = key
- if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
- keyPrefix := localChannel.Key
- if len(localChannel.Key) > 8 {
- keyPrefix = localChannel.Key[:8]
- }
- localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
- }
- channels = append(channels, *localChannel)
- }
- err = model.BatchInsertChannels(channels)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- service.ResetProxyClientCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func DeleteChannel(c *gin.Context) {
- id, _ := strconv.Atoi(c.Param("id"))
- channel := model.Channel{Id: id}
- err := channel.Delete()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func DeleteDisabledChannel(c *gin.Context) {
- rows, err := model.DeleteDisabledChannel()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": rows,
- })
- return
-}
-
-type ChannelTag struct {
- Tag string `json:"tag"`
- NewTag *string `json:"new_tag"`
- Priority *int64 `json:"priority"`
- Weight *uint `json:"weight"`
- ModelMapping *string `json:"model_mapping"`
- Models *string `json:"models"`
- Groups *string `json:"groups"`
-}
-
-func DisableTagChannels(c *gin.Context) {
- channelTag := ChannelTag{}
- err := c.ShouldBindJSON(&channelTag)
- if err != nil || channelTag.Tag == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- err = model.DisableChannelByTag(channelTag.Tag)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func EnableTagChannels(c *gin.Context) {
- channelTag := ChannelTag{}
- err := c.ShouldBindJSON(&channelTag)
- if err != nil || channelTag.Tag == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- err = model.EnableChannelByTag(channelTag.Tag)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func EditTagChannels(c *gin.Context) {
- channelTag := ChannelTag{}
- err := c.ShouldBindJSON(&channelTag)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- if channelTag.Tag == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "tag不能为空",
- })
- return
- }
- err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-type ChannelBatch struct {
- Ids []int `json:"ids"`
- Tag *string `json:"tag"`
-}
-
-func DeleteChannelBatch(c *gin.Context) {
- channelBatch := ChannelBatch{}
- err := c.ShouldBindJSON(&channelBatch)
- if err != nil || len(channelBatch.Ids) == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- err = model.BatchDeleteChannels(channelBatch.Ids)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": len(channelBatch.Ids),
- })
- return
-}
-
-type PatchChannel struct {
- model.Channel
- MultiKeyMode *string `json:"multi_key_mode"`
- KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
-}
-
-func UpdateChannel(c *gin.Context) {
- channel := PatchChannel{}
- err := c.ShouldBindJSON(&channel)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 使用统一的校验函数
- if err := validateChannel(&channel.Channel, false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
- originChannel, err := model.GetChannelById(channel.Id, true)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
- channel.ChannelInfo = originChannel.ChannelInfo
-
- // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
- if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
- channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
- }
-
- // 处理多key模式下的密钥追加/覆盖逻辑
- if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
- switch *channel.KeyMode {
- case "append":
- // 追加模式:将新密钥添加到现有密钥列表
- if originChannel.Key != "" {
- var newKeys []string
- var existingKeys []string
-
- // 解析现有密钥
- if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
- // JSON数组格式
- var arr []json.RawMessage
- if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
- existingKeys = make([]string, len(arr))
- for i, v := range arr {
- existingKeys[i] = string(v)
- }
- }
- } else {
- // 换行分隔格式
- existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
- }
-
- // 处理 Vertex AI 的特殊情况
- if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
- // 尝试解析新密钥为JSON数组
- if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
- array, err := getVertexArrayKeys(channel.Key)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "追加密钥解析失败: " + err.Error(),
- })
- return
- }
- newKeys = array
- } else {
- // 单个JSON密钥
- newKeys = []string{channel.Key}
- }
- // 合并密钥
- allKeys := append(existingKeys, newKeys...)
- channel.Key = strings.Join(allKeys, "\n")
- } else {
- // 普通渠道的处理
- inputKeys := strings.Split(channel.Key, "\n")
- for _, key := range inputKeys {
- key = strings.TrimSpace(key)
- if key != "" {
- newKeys = append(newKeys, key)
- }
- }
- // 合并密钥
- allKeys := append(existingKeys, newKeys...)
- channel.Key = strings.Join(allKeys, "\n")
- }
- }
- case "replace":
- // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
- }
- }
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- service.ResetProxyClientCache()
- channel.Key = ""
- clearChannelInfo(&channel.Channel)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": channel,
- })
- return
-}
-
-func FetchModels(c *gin.Context) {
- var req struct {
- BaseURL string `json:"base_url"`
- Type int `json:"type"`
- Key string `json:"key"`
- }
-
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": "Invalid request",
- })
- return
- }
-
- baseURL := req.BaseURL
- if baseURL == "" {
- baseURL = constant.ChannelBaseURLs[req.Type]
- }
-
- client := &http.Client{}
- url := fmt.Sprintf("%s/v1/models", baseURL)
-
- request, err := http.NewRequest("GET", url, nil)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- // remove line breaks and extra spaces.
- key := strings.TrimSpace(req.Key)
- // If the key contains a line break, only take the first part.
- key = strings.Split(key, "\n")[0]
- request.Header.Set("Authorization", "Bearer "+key)
-
- response, err := client.Do(request)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- //check status code
- if response.StatusCode != http.StatusOK {
- c.JSON(http.StatusInternalServerError, gin.H{
- "success": false,
- "message": "Failed to fetch models",
- })
- return
- }
- defer response.Body.Close()
-
- var result struct {
- Data []struct {
- ID string `json:"id"`
- } `json:"data"`
- }
-
- if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- var models []string
- for _, model := range result.Data {
- models = append(models, model.ID)
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": models,
- })
-}
-
-func BatchSetChannelTag(c *gin.Context) {
- channelBatch := ChannelBatch{}
- err := c.ShouldBindJSON(&channelBatch)
- if err != nil || len(channelBatch.Ids) == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": len(channelBatch.Ids),
- })
- return
-}
-
-func GetTagModels(c *gin.Context) {
- tag := c.Query("tag")
- if tag == "" {
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": "tag不能为空",
- })
- return
- }
-
- channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- var longestModels string
- maxLength := 0
-
- // Find the longest models string among all channels with the given tag
- for _, channel := range channels {
- if channel.Models != "" {
- currentModels := strings.Split(channel.Models, ",")
- if len(currentModels) > maxLength {
- maxLength = len(currentModels)
- longestModels = channel.Models
- }
- }
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": longestModels,
- })
- return
-}
-
-// CopyChannel handles cloning an existing channel with its key.
-// POST /api/channel/copy/:id
-// Optional query params:
-//
-// suffix - string appended to the original name (default "_复制")
-// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
-func CopyChannel(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
- return
- }
-
- suffix := c.DefaultQuery("suffix", "_复制")
- resetBalance := true
- if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
- if v, err := strconv.ParseBool(rbStr); err == nil {
- resetBalance = v
- }
- }
-
- // fetch original channel with key
- origin, err := model.GetChannelById(id, true)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
-
- // clone channel
- clone := *origin // shallow copy is sufficient as we will overwrite primitives
- clone.Id = 0 // let DB auto-generate
- clone.CreatedTime = common.GetTimestamp()
- clone.Name = origin.Name + suffix
- clone.TestTime = 0
- clone.ResponseTime = 0
- if resetBalance {
- clone.Balance = 0
- clone.UsedQuota = 0
- }
-
- // insert
- if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- model.InitChannelCache()
- // success
- c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
-}
-
-// MultiKeyManageRequest represents the request for multi-key management operations
-type MultiKeyManageRequest struct {
- ChannelId int `json:"channel_id"`
- Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
- KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
- Page int `json:"page,omitempty"` // for get_key_status pagination
- PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
- Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
-}
-
-// MultiKeyStatusResponse represents the response for key status query
-type MultiKeyStatusResponse struct {
- Keys []KeyStatus `json:"keys"`
- Total int `json:"total"`
- Page int `json:"page"`
- PageSize int `json:"page_size"`
- TotalPages int `json:"total_pages"`
- // Statistics
- EnabledCount int `json:"enabled_count"`
- ManualDisabledCount int `json:"manual_disabled_count"`
- AutoDisabledCount int `json:"auto_disabled_count"`
-}
-
-type KeyStatus struct {
- Index int `json:"index"`
- Status int `json:"status"` // 1: enabled, 2: disabled
- DisabledTime int64 `json:"disabled_time,omitempty"`
- Reason string `json:"reason,omitempty"`
- KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
-}
-
-// ManageMultiKeys handles multi-key management operations
-func ManageMultiKeys(c *gin.Context) {
- request := MultiKeyManageRequest{}
- err := c.ShouldBindJSON(&request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- channel, err := model.GetChannelById(request.ChannelId, true)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "渠道不存在",
- })
- return
- }
-
- if !channel.ChannelInfo.IsMultiKey {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该渠道不是多密钥模式",
- })
- return
- }
-
- lock := model.GetChannelPollingLock(channel.Id)
- lock.Lock()
- defer lock.Unlock()
-
- switch request.Action {
- case "get_key_status":
- keys := channel.GetKeys()
-
- // Default pagination parameters
- page := request.Page
- pageSize := request.PageSize
- if page <= 0 {
- page = 1
- }
- if pageSize <= 0 {
- pageSize = 50 // Default page size
- }
-
- // Statistics for all keys (unchanged by filtering)
- var enabledCount, manualDisabledCount, autoDisabledCount int
-
- // Build all key status data first
- var allKeyStatusList []KeyStatus
- for i, key := range keys {
- status := 1 // default enabled
- var disabledTime int64
- var reason string
-
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
- status = s
- }
- }
-
- // Count for statistics (all keys)
- switch status {
- case 1:
- enabledCount++
- case 2:
- manualDisabledCount++
- case 3:
- autoDisabledCount++
- }
-
- if status != 1 {
- if channel.ChannelInfo.MultiKeyDisabledTime != nil {
- disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
- }
- if channel.ChannelInfo.MultiKeyDisabledReason != nil {
- reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
- }
- }
-
- // Create key preview (first 10 chars)
- keyPreview := key
- if len(key) > 10 {
- keyPreview = key[:10] + "..."
- }
-
- allKeyStatusList = append(allKeyStatusList, KeyStatus{
- Index: i,
- Status: status,
- DisabledTime: disabledTime,
- Reason: reason,
- KeyPreview: keyPreview,
- })
- }
-
- // Apply status filter if specified
- var filteredKeyStatusList []KeyStatus
- if request.Status != nil {
- for _, keyStatus := range allKeyStatusList {
- if keyStatus.Status == *request.Status {
- filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
- }
- }
- } else {
- filteredKeyStatusList = allKeyStatusList
- }
-
- // Calculate pagination based on filtered results
- filteredTotal := len(filteredKeyStatusList)
- totalPages := (filteredTotal + pageSize - 1) / pageSize
- if totalPages == 0 {
- totalPages = 1
- }
- if page > totalPages {
- page = totalPages
- }
-
- // Calculate range for current page
- start := (page - 1) * pageSize
- end := start + pageSize
- if end > filteredTotal {
- end = filteredTotal
- }
-
- // Get the page data
- var pageKeyStatusList []KeyStatus
- if start < filteredTotal {
- pageKeyStatusList = filteredKeyStatusList[start:end]
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": MultiKeyStatusResponse{
- Keys: pageKeyStatusList,
- Total: filteredTotal, // Total of filtered results
- Page: page,
- PageSize: pageSize,
- TotalPages: totalPages,
- EnabledCount: enabledCount, // Overall statistics
- ManualDisabledCount: manualDisabledCount, // Overall statistics
- AutoDisabledCount: autoDisabledCount, // Overall statistics
- },
- })
- return
-
- case "disable_key":
- if request.KeyIndex == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "未指定要禁用的密钥索引",
- })
- return
- }
-
- keyIndex := *request.KeyIndex
- if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "密钥索引超出范围",
- })
- return
- }
-
- if channel.ChannelInfo.MultiKeyStatusList == nil {
- channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
- }
- if channel.ChannelInfo.MultiKeyDisabledTime == nil {
- channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
- }
- if channel.ChannelInfo.MultiKeyDisabledReason == nil {
- channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
- }
-
- channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "密钥已禁用",
- })
- return
-
- case "enable_key":
- if request.KeyIndex == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "未指定要启用的密钥索引",
- })
- return
- }
-
- keyIndex := *request.KeyIndex
- if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "密钥索引超出范围",
- })
- return
- }
-
- // 从状态列表中删除该密钥的记录,使其回到默认启用状态
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
- }
- if channel.ChannelInfo.MultiKeyDisabledTime != nil {
- delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
- }
- if channel.ChannelInfo.MultiKeyDisabledReason != nil {
- delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
- }
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "密钥已启用",
- })
- return
-
- case "enable_all_keys":
- // 清空所有禁用状态,使所有密钥回到默认启用状态
- var enabledCount int
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
- }
-
- channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
- channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
- channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
- })
- return
-
- case "disable_all_keys":
- // 禁用所有启用的密钥
- if channel.ChannelInfo.MultiKeyStatusList == nil {
- channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
- }
- if channel.ChannelInfo.MultiKeyDisabledTime == nil {
- channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
- }
- if channel.ChannelInfo.MultiKeyDisabledReason == nil {
- channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
- }
-
- var disabledCount int
- for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
- status := 1 // default enabled
- if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
- status = s
- }
-
- // 只禁用当前启用的密钥
- if status == 1 {
- channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
- disabledCount++
- }
- }
-
- if disabledCount == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "没有可禁用的密钥",
- })
- return
- }
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
- })
- return
-
- case "delete_key":
- if request.KeyIndex == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "未指定要删除的密钥索引",
- })
- return
- }
-
- keyIndex := *request.KeyIndex
- if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "密钥索引超出范围",
- })
- return
- }
-
- keys := channel.GetKeys()
- var remainingKeys []string
- var newStatusList = make(map[int]int)
- var newDisabledTime = make(map[int]int64)
- var newDisabledReason = make(map[int]string)
-
- newIndex := 0
- for i, key := range keys {
- // 跳过要删除的密钥
- if i == keyIndex {
- continue
- }
-
- remainingKeys = append(remainingKeys, key)
-
- // 保留其他密钥的状态信息,重新索引
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
- newStatusList[newIndex] = status
- }
- }
- if channel.ChannelInfo.MultiKeyDisabledTime != nil {
- if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
- newDisabledTime[newIndex] = t
- }
- }
- if channel.ChannelInfo.MultiKeyDisabledReason != nil {
- if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
- newDisabledReason[newIndex] = r
- }
- }
- newIndex++
- }
-
- if len(remainingKeys) == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "不能删除最后一个密钥",
- })
- return
- }
-
- // Update channel with remaining keys
- channel.Key = strings.Join(remainingKeys, "\n")
- channel.ChannelInfo.MultiKeySize = len(remainingKeys)
- channel.ChannelInfo.MultiKeyStatusList = newStatusList
- channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
- channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "密钥已删除",
- })
- return
-
- case "delete_disabled_keys":
- keys := channel.GetKeys()
- var remainingKeys []string
- var deletedCount int
- var newStatusList = make(map[int]int)
- var newDisabledTime = make(map[int]int64)
- var newDisabledReason = make(map[int]string)
-
- newIndex := 0
- for i, key := range keys {
- status := 1 // default enabled
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
- status = s
- }
- }
-
- // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
- if status == 3 {
- deletedCount++
- } else {
- remainingKeys = append(remainingKeys, key)
- // 保留非自动禁用密钥的状态信息,重新索引
- if status != 1 {
- newStatusList[newIndex] = status
- if channel.ChannelInfo.MultiKeyDisabledTime != nil {
- if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
- newDisabledTime[newIndex] = t
- }
- }
- if channel.ChannelInfo.MultiKeyDisabledReason != nil {
- if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
- newDisabledReason[newIndex] = r
- }
- }
- }
- newIndex++
- }
- }
-
- if deletedCount == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "没有需要删除的自动禁用密钥",
- })
- return
- }
-
- // Update channel with remaining keys
- channel.Key = strings.Join(remainingKeys, "\n")
- channel.ChannelInfo.MultiKeySize = len(remainingKeys)
- channel.ChannelInfo.MultiKeyStatusList = newStatusList
- channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
- channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
-
- err = channel.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- model.InitChannelCache()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
- "data": deletedCount,
- })
- return
-
- default:
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "不支持的操作",
- })
- return
- }
-}
diff --git a/new-api/controller/console_migrate.go b/new-api/controller/console_migrate.go
deleted file mode 100644
index d79181c08b470671ddd6adbbaa9424efb3273d6d..0000000000000000000000000000000000000000
--- a/new-api/controller/console_migrate.go
+++ /dev/null
@@ -1,104 +0,0 @@
-// 用于迁移检测的旧键,该文件下个版本会删除
-
-package controller
-
-import (
- "encoding/json"
- "net/http"
- "one-api/common"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
-)
-
-// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
-func MigrateConsoleSetting(c *gin.Context) {
- // 读取全部 option
- opts, err := model.AllOption()
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
- return
- }
- // 建立 map
- valMap := map[string]string{}
- for _, o := range opts {
- valMap[o.Key] = o.Value
- }
-
- // 处理 APIInfo
- if v := valMap["ApiInfo"]; v != "" {
- var arr []map[string]interface{}
- if err := json.Unmarshal([]byte(v), &arr); err == nil {
- if len(arr) > 50 {
- arr = arr[:50]
- }
- bytes, _ := json.Marshal(arr)
- model.UpdateOption("console_setting.api_info", string(bytes))
- }
- model.UpdateOption("ApiInfo", "")
- }
- // Announcements 直接搬
- if v := valMap["Announcements"]; v != "" {
- model.UpdateOption("console_setting.announcements", v)
- model.UpdateOption("Announcements", "")
- }
- // FAQ 转换
- if v := valMap["FAQ"]; v != "" {
- var arr []map[string]interface{}
- if err := json.Unmarshal([]byte(v), &arr); err == nil {
- out := []map[string]interface{}{}
- for _, item := range arr {
- q, _ := item["question"].(string)
- if q == "" {
- q, _ = item["title"].(string)
- }
- a, _ := item["answer"].(string)
- if a == "" {
- a, _ = item["content"].(string)
- }
- if q != "" && a != "" {
- out = append(out, map[string]interface{}{"question": q, "answer": a})
- }
- }
- if len(out) > 50 {
- out = out[:50]
- }
- bytes, _ := json.Marshal(out)
- model.UpdateOption("console_setting.faq", string(bytes))
- }
- model.UpdateOption("FAQ", "")
- }
- // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
- url := valMap["UptimeKumaUrl"]
- slug := valMap["UptimeKumaSlug"]
- if url != "" && slug != "" {
- // 仅当同时存在 URL 与 Slug 时才进行迁移
- groups := []map[string]interface{}{
- {
- "id": 1,
- "categoryName": "old",
- "url": url,
- "slug": slug,
- "description": "",
- },
- }
- bytes, _ := json.Marshal(groups)
- model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
- }
- // 清空旧键内容
- if url != "" {
- model.UpdateOption("UptimeKumaUrl", "")
- }
- if slug != "" {
- model.UpdateOption("UptimeKumaSlug", "")
- }
-
- // 删除旧键记录
- oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
- model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
-
- // 重新加载 OptionMap
- model.InitOptionMap()
- common.SysLog("console setting migrated")
- c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
-}
diff --git a/new-api/controller/github.go b/new-api/controller/github.go
deleted file mode 100644
index f18095e96d7818a5e221c7f988b327c76c64dcc0..0000000000000000000000000000000000000000
--- a/new-api/controller/github.go
+++ /dev/null
@@ -1,239 +0,0 @@
-package controller
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-type GitHubOAuthResponse struct {
- AccessToken string `json:"access_token"`
- Scope string `json:"scope"`
- TokenType string `json:"token_type"`
-}
-
-type GitHubUser struct {
- Login string `json:"login"`
- Name string `json:"name"`
- Email string `json:"email"`
-}
-
-func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
- if code == "" {
- return nil, errors.New("无效的参数")
- }
- values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
- jsonData, err := json.Marshal(values)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- res, err := client.Do(req)
- if err != nil {
- common.SysLog(err.Error())
- return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
- }
- defer res.Body.Close()
- var oAuthResponse GitHubOAuthResponse
- err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
- if err != nil {
- return nil, err
- }
- req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
- res2, err := client.Do(req)
- if err != nil {
- common.SysLog(err.Error())
- return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
- }
- defer res2.Body.Close()
- var githubUser GitHubUser
- err = json.NewDecoder(res2.Body).Decode(&githubUser)
- if err != nil {
- return nil, err
- }
- if githubUser.Login == "" {
- return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
- }
- return &githubUser, nil
-}
-
-func GitHubOAuth(c *gin.Context) {
- session := sessions.Default(c)
- state := c.Query("state")
- if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "state is empty or not same",
- })
- return
- }
- username := session.Get("username")
- if username != nil {
- GitHubBind(c)
- return
- }
-
- if !common.GitHubOAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 GitHub 登录以及注册",
- })
- return
- }
- code := c.Query("code")
- githubUser, err := getGitHubUserInfoByCode(code)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user := model.User{
- GitHubId: githubUser.Login,
- }
- // IsGitHubIdAlreadyTaken is unscoped
- if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
- // FillUserByGitHubId is scoped
- err := user.FillUserByGitHubId()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- // if user.Id == 0 , user has been deleted
- if user.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已注销",
- })
- return
- }
- } else {
- if common.RegisterEnabled {
- user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
- if githubUser.Name != "" {
- user.DisplayName = githubUser.Name
- } else {
- user.DisplayName = "GitHub User"
- }
- user.Email = githubUser.Email
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
- affCode := session.Get("aff")
- inviterId := 0
- if affCode != nil {
- inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
- }
-
- if err := user.Insert(inviterId); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员关闭了新用户注册",
- })
- return
- }
- }
-
- if user.Status != common.UserStatusEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "用户已被封禁",
- "success": false,
- })
- return
- }
- setupLogin(&user, c)
-}
-
-func GitHubBind(c *gin.Context) {
- if !common.GitHubOAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 GitHub 登录以及注册",
- })
- return
- }
- code := c.Query("code")
- githubUser, err := getGitHubUserInfoByCode(code)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user := model.User{
- GitHubId: githubUser.Login,
- }
- if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该 GitHub 账户已被绑定",
- })
- return
- }
- session := sessions.Default(c)
- id := session.Get("id")
- // id := c.GetInt("id") // critical bug!
- user.Id = id.(int)
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user.GitHubId = githubUser.Login
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "bind",
- })
- return
-}
-
-func GenerateOAuthCode(c *gin.Context) {
- session := sessions.Default(c)
- state := common.GetRandomString(12)
- affCode := c.Query("aff")
- if affCode != "" {
- session.Set("aff", affCode)
- }
- session.Set("oauth_state", state)
- err := session.Save()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": state,
- })
-}
diff --git a/new-api/controller/group.go b/new-api/controller/group.go
deleted file mode 100644
index 7ee7a637ef0d1844d41ab84b49f08df61513d5a7..0000000000000000000000000000000000000000
--- a/new-api/controller/group.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/model"
- "one-api/setting"
- "one-api/setting/ratio_setting"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetGroups(c *gin.Context) {
- groupNames := make([]string, 0)
- for groupName := range ratio_setting.GetGroupRatioCopy() {
- groupNames = append(groupNames, groupName)
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": groupNames,
- })
-}
-
-func GetUserGroups(c *gin.Context) {
- usableGroups := make(map[string]map[string]interface{})
- userGroup := ""
- userId := c.GetInt("id")
- userGroup, _ = model.GetUserGroup(userId, false)
- for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
- // UserUsableGroups contains the groups that the user can use
- userUsableGroups := setting.GetUserUsableGroups(userGroup)
- if desc, ok := userUsableGroups[groupName]; ok {
- usableGroups[groupName] = map[string]interface{}{
- "ratio": ratio,
- "desc": desc,
- }
- }
- }
- if setting.GroupInUserUsableGroups("auto") {
- usableGroups["auto"] = map[string]interface{}{
- "ratio": "自动",
- "desc": setting.GetUsableGroupDescription("auto"),
- }
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": usableGroups,
- })
-}
diff --git a/new-api/controller/image.go b/new-api/controller/image.go
deleted file mode 100644
index 9d66047e8ee164315ecc1bd8dda5a8919ac41f94..0000000000000000000000000000000000000000
--- a/new-api/controller/image.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package controller
-
-import (
- "github.com/gin-gonic/gin"
-)
-
-func GetImage(c *gin.Context) {
-
-}
diff --git a/new-api/controller/linuxdo.go b/new-api/controller/linuxdo.go
deleted file mode 100644
index d3f9667950add4e8ad01274bd451ae62108a664f..0000000000000000000000000000000000000000
--- a/new-api/controller/linuxdo.go
+++ /dev/null
@@ -1,267 +0,0 @@
-package controller
-
-import (
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "net/url"
- "one-api/common"
- "one-api/model"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-type LinuxdoUser struct {
- Id int `json:"id"`
- Username string `json:"username"`
- Name string `json:"name"`
- Active bool `json:"active"`
- TrustLevel int `json:"trust_level"`
- Silenced bool `json:"silenced"`
-}
-
-func LinuxDoBind(c *gin.Context) {
- if !common.LinuxDOOAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 Linux DO 登录以及注册",
- })
- return
- }
-
- code := c.Query("code")
- linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- user := model.User{
- LinuxDOId: strconv.Itoa(linuxdoUser.Id),
- }
-
- if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该 Linux DO 账户已被绑定",
- })
- return
- }
-
- session := sessions.Default(c)
- id := session.Get("id")
- user.Id = id.(int)
-
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "bind",
- })
-}
-
-func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
- if code == "" {
- return nil, errors.New("invalid code")
- }
-
- // Get access token using Basic auth
- tokenEndpoint := "https://connect.linux.do/oauth2/token"
- credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
- basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
-
- // Get redirect URI from request
- scheme := "http"
- if c.Request.TLS != nil {
- scheme = "https"
- }
- redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
-
- data := url.Values{}
- data.Set("grant_type", "authorization_code")
- data.Set("code", code)
- data.Set("redirect_uri", redirectURI)
-
- req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
- if err != nil {
- return nil, err
- }
-
- req.Header.Set("Authorization", basicAuth)
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.Header.Set("Accept", "application/json")
-
- client := http.Client{Timeout: 5 * time.Second}
- res, err := client.Do(req)
- if err != nil {
- return nil, errors.New("failed to connect to Linux DO server")
- }
- defer res.Body.Close()
-
- var tokenRes struct {
- AccessToken string `json:"access_token"`
- Message string `json:"message"`
- }
- if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
- return nil, err
- }
-
- if tokenRes.AccessToken == "" {
- return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
- }
-
- // Get user info
- userEndpoint := "https://connect.linux.do/api/user"
- req, err = http.NewRequest("GET", userEndpoint, nil)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
- req.Header.Set("Accept", "application/json")
-
- res2, err := client.Do(req)
- if err != nil {
- return nil, errors.New("failed to get user info from Linux DO")
- }
- defer res2.Body.Close()
-
- var linuxdoUser LinuxdoUser
- if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
- return nil, err
- }
-
- if linuxdoUser.Id == 0 {
- return nil, errors.New("invalid user info returned")
- }
-
- return &linuxdoUser, nil
-}
-
-func LinuxdoOAuth(c *gin.Context) {
- session := sessions.Default(c)
-
- errorCode := c.Query("error")
- if errorCode != "" {
- errorDescription := c.Query("error_description")
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": errorDescription,
- })
- return
- }
-
- state := c.Query("state")
- if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "state is empty or not same",
- })
- return
- }
-
- username := session.Get("username")
- if username != nil {
- LinuxDoBind(c)
- return
- }
-
- if !common.LinuxDOOAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 Linux DO 登录以及注册",
- })
- return
- }
-
- code := c.Query("code")
- linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- user := model.User{
- LinuxDOId: strconv.Itoa(linuxdoUser.Id),
- }
-
- // Check if user exists
- if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
- err := user.FillUserByLinuxDOId()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- if user.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已注销",
- })
- return
- }
- } else {
- if common.RegisterEnabled {
- if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
- user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
- user.DisplayName = linuxdoUser.Name
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
-
- affCode := session.Get("aff")
- inviterId := 0
- if affCode != nil {
- inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
- }
-
- if err := user.Insert(inviterId); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
- })
- return
- }
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员关闭了新用户注册",
- })
- return
- }
- }
-
- if user.Status != common.UserStatusEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "用户已被封禁",
- "success": false,
- })
- return
- }
-
- setupLogin(&user, c)
-}
diff --git a/new-api/controller/log.go b/new-api/controller/log.go
deleted file mode 100644
index d8529d2d715374bdf93bebe01653d97a5bd1ac11..0000000000000000000000000000000000000000
--- a/new-api/controller/log.go
+++ /dev/null
@@ -1,168 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetAllLogs(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- logType, _ := strconv.Atoi(c.Query("type"))
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- username := c.Query("username")
- tokenName := c.Query("token_name")
- modelName := c.Query("model_name")
- channel, _ := strconv.Atoi(c.Query("channel"))
- group := c.Query("group")
- logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(logs)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func GetUserLogs(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- userId := c.GetInt("id")
- logType, _ := strconv.Atoi(c.Query("type"))
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- tokenName := c.Query("token_name")
- modelName := c.Query("model_name")
- group := c.Query("group")
- logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(logs)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func SearchAllLogs(c *gin.Context) {
- keyword := c.Query("keyword")
- logs, err := model.SearchAllLogs(keyword)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
- return
-}
-
-func SearchUserLogs(c *gin.Context) {
- keyword := c.Query("keyword")
- userId := c.GetInt("id")
- logs, err := model.SearchUserLogs(userId, keyword)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
- return
-}
-
-func GetLogByKey(c *gin.Context) {
- key := c.Query("key")
- logs, err := model.GetLogByKey(key)
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
-}
-
-func GetLogsStat(c *gin.Context) {
- logType, _ := strconv.Atoi(c.Query("type"))
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- tokenName := c.Query("token_name")
- username := c.Query("username")
- modelName := c.Query("model_name")
- channel, _ := strconv.Atoi(c.Query("channel"))
- group := c.Query("group")
- stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
- //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "quota": stat.Quota,
- "rpm": stat.Rpm,
- "tpm": stat.Tpm,
- },
- })
- return
-}
-
-func GetLogsSelfStat(c *gin.Context) {
- username := c.GetString("username")
- logType, _ := strconv.Atoi(c.Query("type"))
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- tokenName := c.Query("token_name")
- modelName := c.Query("model_name")
- channel, _ := strconv.Atoi(c.Query("channel"))
- group := c.Query("group")
- quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
- //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "quota": quotaNum.Quota,
- "rpm": quotaNum.Rpm,
- "tpm": quotaNum.Tpm,
- //"token": tokenNum,
- },
- })
- return
-}
-
-func DeleteHistoryLogs(c *gin.Context) {
- targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
- if targetTimestamp == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "target timestamp is required",
- })
- return
- }
- count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": count,
- })
- return
-}
diff --git a/new-api/controller/midjourney.go b/new-api/controller/midjourney.go
deleted file mode 100644
index ecb570cf1d43c183d6a810446a36fa0487f9bdb9..0000000000000000000000000000000000000000
--- a/new-api/controller/midjourney.go
+++ /dev/null
@@ -1,295 +0,0 @@
-package controller
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/system_setting"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func UpdateMidjourneyTaskBulk() {
- //imageModel := "midjourney"
- ctx := context.TODO()
- for {
- time.Sleep(time.Duration(15) * time.Second)
-
- tasks := model.GetAllUnFinishTasks()
- if len(tasks) == 0 {
- continue
- }
-
- logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
- taskChannelM := make(map[int][]string)
- taskM := make(map[string]*model.Midjourney)
- nullTaskIds := make([]int, 0)
- for _, task := range tasks {
- if task.MjId == "" {
- // 统计失败的未完成任务
- nullTaskIds = append(nullTaskIds, task.Id)
- continue
- }
- taskM[task.MjId] = task
- taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
- }
- if len(nullTaskIds) > 0 {
- err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
- } else {
- logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
- }
- }
- if len(taskChannelM) == 0 {
- continue
- }
-
- for channelId, taskIds := range taskChannelM {
- logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- continue
- }
- midjourneyChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
- err := model.MjBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
- }
- continue
- }
- requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
-
- body, _ := json.Marshal(map[string]any{
- "ids": taskIds,
- })
- req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
- continue
- }
- // 设置超时时间
- timeout := time.Second * 15
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("mj-api-secret", midjourneyChannel.Key)
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
- continue
- }
- if resp.StatusCode != http.StatusOK {
- logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- continue
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
- continue
- }
- var responseItems []dto.MidjourneyDto
- err = json.Unmarshal(responseBody, &responseItems)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
- continue
- }
- resp.Body.Close()
- req.Body.Close()
- cancel()
-
- for _, responseItem := range responseItems {
- task := taskM[responseItem.MjId]
-
- useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
- // 如果时间超过一小时,且进度不是100%,则认为任务失败
- if useTime > 3600000 && task.Progress != "100%" {
- responseItem.FailReason = "上游任务超时(超过1小时)"
- responseItem.Status = "FAILURE"
- }
- if !checkMjTaskNeedUpdate(task, responseItem) {
- continue
- }
- task.Code = 1
- task.Progress = responseItem.Progress
- task.PromptEn = responseItem.PromptEn
- task.State = responseItem.State
- task.SubmitTime = responseItem.SubmitTime
- task.StartTime = responseItem.StartTime
- task.FinishTime = responseItem.FinishTime
- task.ImageUrl = responseItem.ImageUrl
- task.Status = responseItem.Status
- task.FailReason = responseItem.FailReason
- if responseItem.Properties != nil {
- propertiesStr, _ := json.Marshal(responseItem.Properties)
- task.Properties = string(propertiesStr)
- }
- if responseItem.Buttons != nil {
- buttonStr, _ := json.Marshal(responseItem.Buttons)
- task.Buttons = string(buttonStr)
- }
- // 映射 VideoUrl
- task.VideoUrl = responseItem.VideoUrl
-
- // 映射 VideoUrls - 将数组序列化为 JSON 字符串
- if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
- videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
- task.VideoUrls = "[]" // 失败时设置为空数组
- } else {
- task.VideoUrls = string(videoUrlsStr)
- }
- } else {
- task.VideoUrls = "" // 空值时清空字段
- }
-
- shouldReturnQuota := false
- if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
- logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
- task.Progress = "100%"
- if task.Quota != 0 {
- shouldReturnQuota = true
- }
- }
- err = task.Update()
- if err != nil {
- logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
- } else {
- if shouldReturnQuota {
- err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
- if err != nil {
- logger.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
- }
- }
-}
-
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
- if oldTask.Code != 1 {
- return true
- }
- if oldTask.Progress != newTask.Progress {
- return true
- }
- if oldTask.PromptEn != newTask.PromptEn {
- return true
- }
- if oldTask.State != newTask.State {
- return true
- }
- if oldTask.SubmitTime != newTask.SubmitTime {
- return true
- }
- if oldTask.StartTime != newTask.StartTime {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if oldTask.ImageUrl != newTask.ImageUrl {
- return true
- }
- if oldTask.Status != newTask.Status {
- return true
- }
- if oldTask.FailReason != newTask.FailReason {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if oldTask.Progress != "100%" && newTask.FailReason != "" {
- return true
- }
- // 检查 VideoUrl 是否需要更新
- if oldTask.VideoUrl != newTask.VideoUrl {
- return true
- }
- // 检查 VideoUrls 是否需要更新
- if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
- newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
- if oldTask.VideoUrls != string(newVideoUrlsStr) {
- return true
- }
- } else if oldTask.VideoUrls != "" {
- // 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
- return true
- }
-
- return false
-}
-
-func GetAllMidjourney(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
-
- // 解析其他查询参数
- queryParams := model.TaskQueryParams{
- ChannelID: c.Query("channel_id"),
- MjID: c.Query("mj_id"),
- StartTimestamp: c.Query("start_timestamp"),
- EndTimestamp: c.Query("end_timestamp"),
- }
-
- items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.CountAllTasks(queryParams)
-
- if setting.MjForwardUrlEnabled {
- for i, midjourney := range items {
- midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
- items[i] = midjourney
- }
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
-}
-
-func GetUserMidjourney(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
-
- userId := c.GetInt("id")
-
- queryParams := model.TaskQueryParams{
- MjID: c.Query("mj_id"),
- StartTimestamp: c.Query("start_timestamp"),
- EndTimestamp: c.Query("end_timestamp"),
- }
-
- items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.CountAllUserTask(userId, queryParams)
-
- if setting.MjForwardUrlEnabled {
- for i, midjourney := range items {
- midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
- items[i] = midjourney
- }
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
-}
diff --git a/new-api/controller/misc.go b/new-api/controller/misc.go
deleted file mode 100644
index fb1e1b6b83eb4b2c1df51d7112732284add13093..0000000000000000000000000000000000000000
--- a/new-api/controller/misc.go
+++ /dev/null
@@ -1,314 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/middleware"
- "one-api/model"
- "one-api/setting"
- "one-api/setting/console_setting"
- "one-api/setting/operation_setting"
- "one-api/setting/system_setting"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func TestStatus(c *gin.Context) {
- err := model.PingDB()
- if err != nil {
- c.JSON(http.StatusServiceUnavailable, gin.H{
- "success": false,
- "message": "数据库连接失败",
- })
- return
- }
- // 获取HTTP统计信息
- httpStats := middleware.GetStats()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Server is running",
- "http_stats": httpStats,
- })
- return
-}
-
-func GetStatus(c *gin.Context) {
-
- cs := console_setting.GetConsoleSetting()
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
-
- passkeySetting := system_setting.GetPasskeySettings()
-
- data := gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": system_setting.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
- "default_use_auto_group": setting.DefaultUseAutoGroup,
-
- "usd_exchange_rate": operation_setting.USDExchangeRate,
- "price": operation_setting.Price,
- "stripe_unit_price": setting.StripeUnitPrice,
-
- // 面板启用开关
- "api_info_enabled": cs.ApiInfoEnabled,
- "uptime_kuma_enabled": cs.UptimeKumaEnabled,
- "announcements_enabled": cs.AnnouncementsEnabled,
- "faq_enabled": cs.FAQEnabled,
-
- // 模块管理配置
- "HeaderNavModules": common.OptionMap["HeaderNavModules"],
- "SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
-
- "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
- "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
- "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
- "passkey_login": passkeySetting.Enabled,
- "passkey_display_name": passkeySetting.RPDisplayName,
- "passkey_rp_id": passkeySetting.RPID,
- "passkey_origins": passkeySetting.Origins,
- "passkey_allow_insecure": passkeySetting.AllowInsecureOrigin,
- "passkey_user_verification": passkeySetting.UserVerification,
- "passkey_attachment": passkeySetting.AttachmentPreference,
- "setup": constant.Setup,
- }
-
- // 根据启用状态注入可选内容
- if cs.ApiInfoEnabled {
- data["api_info"] = console_setting.GetApiInfo()
- }
- if cs.AnnouncementsEnabled {
- data["announcements"] = console_setting.GetAnnouncements()
- }
- if cs.FAQEnabled {
- data["faq"] = console_setting.GetFAQ()
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": data,
- })
- return
-}
-
-func GetNotice(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": common.OptionMap["Notice"],
- })
- return
-}
-
-func GetAbout(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": common.OptionMap["About"],
- })
- return
-}
-
-func GetMidjourney(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": common.OptionMap["Midjourney"],
- })
- return
-}
-
-func GetHomePageContent(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": common.OptionMap["HomePageContent"],
- })
- return
-}
-
-func SendEmailVerification(c *gin.Context) {
- email := c.Query("email")
- if err := common.Validate.Var(email, "required,email"); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- parts := strings.Split(email, "@")
- if len(parts) != 2 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的邮箱地址",
- })
- return
- }
- localPart := parts[0]
- domainPart := parts[1]
- if common.EmailDomainRestrictionEnabled {
- allowed := false
- for _, domain := range common.EmailDomainWhitelist {
- if domainPart == domain {
- allowed = true
- break
- }
- }
- if !allowed {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
- })
- return
- }
- }
- if common.EmailAliasRestrictionEnabled {
- containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
- if containsSpecialSymbols {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
- })
- return
- }
- }
-
- if model.IsEmailAlreadyTaken(email) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "邮箱地址已被占用",
- })
- return
- }
- code := common.GenerateVerificationCode(6)
- common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
- subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
- content := fmt.Sprintf("您好,你正在进行%s邮箱验证。
"+
- "您的验证码为: %s
"+
- "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes)
- err := common.SendEmail(subject, email, content)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func SendPasswordResetEmail(c *gin.Context) {
- email := c.Query("email")
- if err := common.Validate.Var(email, "required,email"); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- if !model.IsEmailAlreadyTaken(email) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该邮箱地址未注册",
- })
- return
- }
- code := common.GenerateVerificationCode(0)
- common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
- subject := fmt.Sprintf("%s密码重置", common.SystemName)
- content := fmt.Sprintf("您好,你正在进行%s密码重置。
"+
- "点击 此处 进行密码重置。
"+
- "如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s
"+
- "重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes)
- err := common.SendEmail(subject, email, content)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-type PasswordResetRequest struct {
- Email string `json:"email"`
- Token string `json:"token"`
-}
-
-func ResetPassword(c *gin.Context) {
- var req PasswordResetRequest
- err := json.NewDecoder(c.Request.Body).Decode(&req)
- if req.Email == "" || req.Token == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "重置链接非法或已过期",
- })
- return
- }
- password := common.GenerateVerificationCode(12)
- err = model.ResetUserPasswordByEmail(req.Email, password)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- common.DeleteKey(req.Email, common.PasswordResetPurpose)
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": password,
- })
- return
-}
diff --git a/new-api/controller/missing_models.go b/new-api/controller/missing_models.go
deleted file mode 100644
index c18afba25e7934d9779fd057faf72411963b6092..0000000000000000000000000000000000000000
--- a/new-api/controller/missing_models.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
-)
-
-// GetMissingModels returns the list of model names that are referenced by channels
-// but do not have corresponding records in the models meta table.
-// This helps administrators quickly discover models that need configuration.
-func GetMissingModels(c *gin.Context) {
- missing, err := model.GetMissingModels()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": missing,
- })
-}
diff --git a/new-api/controller/model.go b/new-api/controller/model.go
deleted file mode 100644
index 07a77fac76ff6525669d2ca886f8666407dfbafe..0000000000000000000000000000000000000000
--- a/new-api/controller/model.go
+++ /dev/null
@@ -1,261 +0,0 @@
-package controller
-
-import (
- "fmt"
- "github.com/gin-gonic/gin"
- "github.com/samber/lo"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- "one-api/relay"
- "one-api/relay/channel/ai360"
- "one-api/relay/channel/lingyiwanwu"
- "one-api/relay/channel/minimax"
- "one-api/relay/channel/moonshot"
- relaycommon "one-api/relay/common"
- "one-api/setting"
- "time"
-)
-
-// https://platform.openai.com/docs/api-reference/models/list
-
-var openAIModels []dto.OpenAIModels
-var openAIModelsMap map[string]dto.OpenAIModels
-var channelId2Models map[int][]string
-
-func init() {
- // https://platform.openai.com/docs/models/model-endpoint-compatibility
- for i := 0; i < constant.APITypeDummy; i++ {
- if i == constant.APITypeAIProxyLibrary {
- continue
- }
- adaptor := relay.GetAdaptor(i)
- channelName := adaptor.GetChannelName()
- modelNames := adaptor.GetModelList()
- for _, modelName := range modelNames {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: channelName,
- })
- }
- }
- for _, modelName := range ai360.ModelList {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: ai360.ChannelName,
- })
- }
- for _, modelName := range moonshot.ModelList {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: moonshot.ChannelName,
- })
- }
- for _, modelName := range lingyiwanwu.ModelList {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: lingyiwanwu.ChannelName,
- })
- }
- for _, modelName := range minimax.ModelList {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: minimax.ChannelName,
- })
- }
- for modelName, _ := range constant.MidjourneyModel2Action {
- openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "midjourney",
- })
- }
- openAIModelsMap = make(map[string]dto.OpenAIModels)
- for _, aiModel := range openAIModels {
- openAIModelsMap[aiModel.Id] = aiModel
- }
- channelId2Models = make(map[int][]string)
- for i := 1; i <= constant.ChannelTypeDummy; i++ {
- apiType, success := common.ChannelType2APIType(i)
- if !success || apiType == constant.APITypeAIProxyLibrary {
- continue
- }
- meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
- ChannelType: i,
- }}
- adaptor := relay.GetAdaptor(apiType)
- adaptor.Init(meta)
- channelId2Models[i] = adaptor.GetModelList()
- }
- openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
- return m.Id
- })
-}
-
-func ListModels(c *gin.Context, modelType int) {
- userOpenAiModels := make([]dto.OpenAIModels, 0)
-
- modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
- if modelLimitEnable {
- s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
- var tokenModelLimit map[string]bool
- if ok {
- tokenModelLimit = s.(map[string]bool)
- } else {
- tokenModelLimit = map[string]bool{}
- }
- for allowModel, _ := range tokenModelLimit {
- if oaiModel, ok := openAIModelsMap[allowModel]; ok {
- oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
- userOpenAiModels = append(userOpenAiModels, oaiModel)
- } else {
- userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
- Id: allowModel,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "custom",
- SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
- })
- }
- }
- } else {
- userId := c.GetInt("id")
- userGroup, err := model.GetUserGroup(userId, false)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "get user group failed",
- })
- return
- }
- group := userGroup
- tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
- if tokenGroup != "" {
- group = tokenGroup
- }
- var models []string
- if tokenGroup == "auto" {
- for _, autoGroup := range setting.AutoGroups {
- groupModels := model.GetGroupEnabledModels(autoGroup)
- for _, g := range groupModels {
- if !common.StringsContains(models, g) {
- models = append(models, g)
- }
- }
- }
- } else {
- models = model.GetGroupEnabledModels(group)
- }
- for _, modelName := range models {
- if oaiModel, ok := openAIModelsMap[modelName]; ok {
- oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
- userOpenAiModels = append(userOpenAiModels, oaiModel)
- } else {
- userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "custom",
- SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
- })
- }
- }
- }
- switch modelType {
- case constant.ChannelTypeAnthropic:
- useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
- for i, model := range userOpenAiModels {
- useranthropicModels[i] = dto.AnthropicModel{
- ID: model.Id,
- CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
- DisplayName: model.Id,
- Type: "model",
- }
- }
- c.JSON(200, gin.H{
- "data": useranthropicModels,
- "first_id": useranthropicModels[0].ID,
- "has_more": false,
- "last_id": useranthropicModels[len(useranthropicModels)-1].ID,
- })
- case constant.ChannelTypeGemini:
- userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
- for i, model := range userOpenAiModels {
- userGeminiModels[i] = dto.GeminiModel{
- Name: model.Id,
- DisplayName: model.Id,
- }
- }
- c.JSON(200, gin.H{
- "models": userGeminiModels,
- "nextPageToken": nil,
- })
- default:
- c.JSON(200, gin.H{
- "success": true,
- "data": userOpenAiModels,
- "object": "list",
- })
- }
-}
-
-func ChannelListModels(c *gin.Context) {
- c.JSON(200, gin.H{
- "success": true,
- "data": openAIModels,
- })
-}
-
-func DashboardListModels(c *gin.Context) {
- c.JSON(200, gin.H{
- "success": true,
- "data": channelId2Models,
- })
-}
-
-func EnabledListModels(c *gin.Context) {
- c.JSON(200, gin.H{
- "success": true,
- "data": model.GetEnabledModels(),
- })
-}
-
-func RetrieveModel(c *gin.Context, modelType int) {
- modelId := c.Param("model")
- if aiModel, ok := openAIModelsMap[modelId]; ok {
- switch modelType {
- case constant.ChannelTypeAnthropic:
- c.JSON(200, dto.AnthropicModel{
- ID: aiModel.Id,
- CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
- DisplayName: aiModel.Id,
- Type: "model",
- })
- default:
- c.JSON(200, aiModel)
- }
- } else {
- openAIError := dto.OpenAIError{
- Message: fmt.Sprintf("The model '%s' does not exist", modelId),
- Type: "invalid_request_error",
- Param: "model",
- Code: "model_not_found",
- }
- c.JSON(200, gin.H{
- "error": openAIError,
- })
- }
-}
diff --git a/new-api/controller/model_meta.go b/new-api/controller/model_meta.go
deleted file mode 100644
index bf879d432306eaf1e67f1d74732d7cb2d4a0f5a8..0000000000000000000000000000000000000000
--- a/new-api/controller/model_meta.go
+++ /dev/null
@@ -1,330 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "sort"
- "strconv"
- "strings"
-
- "one-api/common"
- "one-api/constant"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
-)
-
-// GetAllModelsMeta 获取模型列表(分页)
-func GetAllModelsMeta(c *gin.Context) {
-
- pageInfo := common.GetPageQuery(c)
- modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // 批量填充附加字段,提升列表接口性能
- enrichModels(modelsMeta)
- var total int64
- model.DB.Model(&model.Model{}).Count(&total)
-
- // 统计供应商计数(全部数据,不受分页影响)
- vendorCounts, _ := model.GetVendorModelCounts()
-
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(modelsMeta)
- common.ApiSuccess(c, gin.H{
- "items": modelsMeta,
- "total": total,
- "page": pageInfo.GetPage(),
- "page_size": pageInfo.GetPageSize(),
- "vendor_counts": vendorCounts,
- })
-}
-
-// SearchModelsMeta 搜索模型列表
-func SearchModelsMeta(c *gin.Context) {
-
- keyword := c.Query("keyword")
- vendor := c.Query("vendor")
- pageInfo := common.GetPageQuery(c)
-
- modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // 批量填充附加字段,提升列表接口性能
- enrichModels(modelsMeta)
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(modelsMeta)
- common.ApiSuccess(c, pageInfo)
-}
-
-// GetModelMeta 根据 ID 获取单条模型信息
-func GetModelMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- var m model.Model
- if err := model.DB.First(&m, id).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- enrichModels([]*model.Model{&m})
- common.ApiSuccess(c, &m)
-}
-
-// CreateModelMeta 新建模型
-func CreateModelMeta(c *gin.Context) {
- var m model.Model
- if err := c.ShouldBindJSON(&m); err != nil {
- common.ApiError(c, err)
- return
- }
- if m.ModelName == "" {
- common.ApiErrorMsg(c, "模型名称不能为空")
- return
- }
- // 名称冲突检查
- if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "模型名称已存在")
- return
- }
-
- if err := m.Insert(); err != nil {
- common.ApiError(c, err)
- return
- }
- model.RefreshPricing()
- common.ApiSuccess(c, &m)
-}
-
-// UpdateModelMeta 更新模型
-func UpdateModelMeta(c *gin.Context) {
- statusOnly := c.Query("status_only") == "true"
-
- var m model.Model
- if err := c.ShouldBindJSON(&m); err != nil {
- common.ApiError(c, err)
- return
- }
- if m.Id == 0 {
- common.ApiErrorMsg(c, "缺少模型 ID")
- return
- }
-
- if statusOnly {
- // 只更新状态,防止误清空其他字段
- if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- } else {
- // 名称冲突检查
- if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "模型名称已存在")
- return
- }
-
- if err := m.Update(); err != nil {
- common.ApiError(c, err)
- return
- }
- }
- model.RefreshPricing()
- common.ApiSuccess(c, &m)
-}
-
-// DeleteModelMeta 删除模型
-func DeleteModelMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- model.RefreshPricing()
- common.ApiSuccess(c, nil)
-}
-
-// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
-func enrichModels(models []*model.Model) {
- if len(models) == 0 {
- return
- }
-
- // 1) 拆分精确与规则匹配
- exactNames := make([]string, 0)
- exactIdx := make(map[string][]int) // modelName -> indices in models
- ruleIndices := make([]int, 0)
- for i, m := range models {
- if m == nil {
- continue
- }
- if m.NameRule == model.NameRuleExact {
- exactNames = append(exactNames, m.ModelName)
- exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
- } else {
- ruleIndices = append(ruleIndices, i)
- }
- }
-
- // 2) 批量查询精确模型的绑定渠道
- channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
-
- // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
- for name, indices := range exactIdx {
- chs := channelsByModel[name]
- for _, idx := range indices {
- mm := models[idx]
- if mm.Endpoints == "" {
- eps := model.GetModelSupportEndpointTypes(mm.ModelName)
- if b, err := json.Marshal(eps); err == nil {
- mm.Endpoints = string(b)
- }
- }
- mm.BoundChannels = chs
- mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
- mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
- }
- }
-
- if len(ruleIndices) == 0 {
- return
- }
-
- // 4) 一次性读取定价缓存,内存匹配所有规则模型
- pricings := model.GetPricing()
-
- // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
- matchedNamesByIdx := make(map[int][]string)
- endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
- groupSetByIdx := make(map[int]map[string]struct{})
- quotaSetByIdx := make(map[int]map[int]struct{})
-
- for _, p := range pricings {
- for _, idx := range ruleIndices {
- mm := models[idx]
- var matched bool
- switch mm.NameRule {
- case model.NameRulePrefix:
- matched = strings.HasPrefix(p.ModelName, mm.ModelName)
- case model.NameRuleSuffix:
- matched = strings.HasSuffix(p.ModelName, mm.ModelName)
- case model.NameRuleContains:
- matched = strings.Contains(p.ModelName, mm.ModelName)
- }
- if !matched {
- continue
- }
- matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
-
- es := endpointSetByIdx[idx]
- if es == nil {
- es = make(map[constant.EndpointType]struct{})
- endpointSetByIdx[idx] = es
- }
- for _, et := range p.SupportedEndpointTypes {
- es[et] = struct{}{}
- }
-
- gs := groupSetByIdx[idx]
- if gs == nil {
- gs = make(map[string]struct{})
- groupSetByIdx[idx] = gs
- }
- for _, g := range p.EnableGroup {
- gs[g] = struct{}{}
- }
-
- qs := quotaSetByIdx[idx]
- if qs == nil {
- qs = make(map[int]struct{})
- quotaSetByIdx[idx] = qs
- }
- qs[p.QuotaType] = struct{}{}
- }
- }
-
- // 5) 汇总所有匹配到的模型名称,批量查询一次渠道
- allMatchedSet := make(map[string]struct{})
- for _, names := range matchedNamesByIdx {
- for _, n := range names {
- allMatchedSet[n] = struct{}{}
- }
- }
- allMatched := make([]string, 0, len(allMatchedSet))
- for n := range allMatchedSet {
- allMatched = append(allMatched, n)
- }
- matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
-
- // 6) 回填每个规则模型的并集信息
- for _, idx := range ruleIndices {
- mm := models[idx]
-
- // 端点并集 -> 序列化
- if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
- eps := make([]constant.EndpointType, 0, len(es))
- for et := range es {
- eps = append(eps, et)
- }
- if b, err := json.Marshal(eps); err == nil {
- mm.Endpoints = string(b)
- }
- }
-
- // 分组并集
- if gs, ok := groupSetByIdx[idx]; ok {
- groups := make([]string, 0, len(gs))
- for g := range gs {
- groups = append(groups, g)
- }
- mm.EnableGroups = groups
- }
-
- // 配额类型集合(保持去重并排序)
- if qs, ok := quotaSetByIdx[idx]; ok {
- arr := make([]int, 0, len(qs))
- for k := range qs {
- arr = append(arr, k)
- }
- sort.Ints(arr)
- mm.QuotaTypes = arr
- }
-
- // 渠道并集
- names := matchedNamesByIdx[idx]
- channelSet := make(map[string]model.BoundChannel)
- for _, n := range names {
- for _, ch := range matchedChannelsByModel[n] {
- key := ch.Name + "_" + strconv.Itoa(ch.Type)
- channelSet[key] = ch
- }
- }
- if len(channelSet) > 0 {
- chs := make([]model.BoundChannel, 0, len(channelSet))
- for _, ch := range channelSet {
- chs = append(chs, ch)
- }
- mm.BoundChannels = chs
- }
-
- // 匹配信息
- mm.MatchedModels = names
- mm.MatchedCount = len(names)
- }
-}
diff --git a/new-api/controller/model_sync.go b/new-api/controller/model_sync.go
deleted file mode 100644
index d539e9c17c3e5024e2c1ee2de6dd2cf44e304900..0000000000000000000000000000000000000000
--- a/new-api/controller/model_sync.go
+++ /dev/null
@@ -1,604 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "math/rand"
- "net"
- "net/http"
- "strings"
- "sync"
- "time"
-
- "one-api/common"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
- "gorm.io/gorm"
-)
-
-// 上游地址
-const (
- upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
- upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
-)
-
-func normalizeLocale(locale string) (string, bool) {
- l := strings.ToLower(strings.TrimSpace(locale))
- switch l {
- case "en", "zh", "ja":
- return l, true
- default:
- return "", false
- }
-}
-
-func getUpstreamBase() string {
- return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
-}
-
-func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
- base := strings.TrimRight(getUpstreamBase(), "/")
- if l, ok := normalizeLocale(locale); ok && l != "" {
- return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
- fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
- }
- return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
-}
-
-type upstreamEnvelope[T any] struct {
- Success bool `json:"success"`
- Message string `json:"message"`
- Data []T `json:"data"`
-}
-
-type upstreamModel struct {
- Description string `json:"description"`
- Endpoints json.RawMessage `json:"endpoints"`
- Icon string `json:"icon"`
- ModelName string `json:"model_name"`
- NameRule int `json:"name_rule"`
- Status int `json:"status"`
- Tags string `json:"tags"`
- VendorName string `json:"vendor_name"`
-}
-
-type upstreamVendor struct {
- Description string `json:"description"`
- Icon string `json:"icon"`
- Name string `json:"name"`
- Status int `json:"status"`
-}
-
-var (
- etagCache = make(map[string]string)
- bodyCache = make(map[string][]byte)
- cacheMutex sync.RWMutex
-)
-
-type overwriteField struct {
- ModelName string `json:"model_name"`
- Fields []string `json:"fields"`
-}
-
-type syncRequest struct {
- Overwrite []overwriteField `json:"overwrite"`
- Locale string `json:"locale"`
-}
-
-func newHTTPClient() *http.Client {
- timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
- dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
- transport := &http.Transport{
- MaxIdleConns: 100,
- IdleConnTimeout: 90 * time.Second,
- TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
- }
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- host = addr
- }
- if strings.HasSuffix(host, "github.io") {
- if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
- return conn, nil
- }
- return dialer.DialContext(ctx, "tcp6", addr)
- }
- return dialer.DialContext(ctx, network, addr)
- }
- return &http.Client{Transport: transport}
-}
-
-var httpClient = newHTTPClient()
-
-func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
- var lastErr error
- attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
- if attempts < 1 {
- attempts = 1
- }
- baseDelay := 200 * time.Millisecond
- maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
- maxBytes := int64(maxMB) << 20
- for attempt := 0; attempt < attempts; attempt++ {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
- if err != nil {
- return err
- }
- // ETag conditional request
- cacheMutex.RLock()
- if et := etagCache[url]; et != "" {
- req.Header.Set("If-None-Match", et)
- }
- cacheMutex.RUnlock()
-
- resp, err := httpClient.Do(req)
- if err != nil {
- lastErr = err
- // backoff with jitter
- sleep := baseDelay * time.Duration(1< id
- vendorIDCache := make(map[string]int)
-
- for _, name := range missing {
- up, ok := modelByName[name]
- if !ok {
- skipped = append(skipped, name)
- continue
- }
-
- // 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
- var existing model.Model
- if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
- if existing.SyncOfficial == 0 {
- skipped = append(skipped, name)
- continue
- }
- }
-
- // 确保 vendor 存在
- vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
-
- // 创建模型
- mi := &model.Model{
- ModelName: name,
- Description: up.Description,
- Icon: up.Icon,
- Tags: up.Tags,
- VendorID: vendorID,
- Status: chooseStatus(up.Status, 1),
- NameRule: up.NameRule,
- }
- if err := mi.Insert(); err == nil {
- createdModels++
- createdList = append(createdList, name)
- } else {
- skipped = append(skipped, name)
- }
- }
-
- // 4) 处理可选覆盖(更新本地已有模型的差异字段)
- if len(req.Overwrite) > 0 {
- // vendorIDCache 已用于创建阶段,可复用
- for _, ow := range req.Overwrite {
- up, ok := modelByName[ow.ModelName]
- if !ok {
- continue
- }
- var local model.Model
- if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
- continue
- }
-
- // 跳过被禁用官方同步的模型
- if local.SyncOfficial == 0 {
- continue
- }
-
- // 映射 vendor
- newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
-
- // 应用字段覆盖(事务)
- _ = model.DB.Transaction(func(tx *gorm.DB) error {
- needUpdate := false
- if containsField(ow.Fields, "description") {
- local.Description = up.Description
- needUpdate = true
- }
- if containsField(ow.Fields, "icon") {
- local.Icon = up.Icon
- needUpdate = true
- }
- if containsField(ow.Fields, "tags") {
- local.Tags = up.Tags
- needUpdate = true
- }
- if containsField(ow.Fields, "vendor") {
- local.VendorID = newVendorID
- needUpdate = true
- }
- if containsField(ow.Fields, "name_rule") {
- local.NameRule = up.NameRule
- needUpdate = true
- }
- if containsField(ow.Fields, "status") {
- local.Status = chooseStatus(up.Status, local.Status)
- needUpdate = true
- }
- if !needUpdate {
- return nil
- }
- if err := tx.Save(&local).Error; err != nil {
- return err
- }
- updatedModels++
- updatedList = append(updatedList, ow.ModelName)
- return nil
- })
- }
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": gin.H{
- "created_models": createdModels,
- "created_vendors": createdVendors,
- "updated_models": updatedModels,
- "skipped_models": skipped,
- "created_list": createdList,
- "updated_list": updatedList,
- "source": gin.H{
- "locale": req.Locale,
- "models_url": modelsURL,
- "vendors_url": vendorsURL,
- },
- },
- })
-}
-
-func containsField(fields []string, key string) bool {
- key = strings.ToLower(strings.TrimSpace(key))
- for _, f := range fields {
- if strings.ToLower(strings.TrimSpace(f)) == key {
- return true
- }
- }
- return false
-}
-
-func coalesce(a, b string) string {
- if strings.TrimSpace(a) != "" {
- return a
- }
- return b
-}
-
-func chooseStatus(primary, fallback int) int {
- if primary == 0 && fallback != 0 {
- return fallback
- }
- if primary != 0 {
- return primary
- }
- return 1
-}
-
-// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
-func SyncUpstreamPreview(c *gin.Context) {
- // 1) 拉取上游数据
- timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
- ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
- defer cancel()
-
- locale := c.Query("locale")
- modelsURL, vendorsURL := getUpstreamURLs(locale)
-
- var vendorsEnv upstreamEnvelope[upstreamVendor]
- var modelsEnv upstreamEnvelope[upstreamModel]
- var fetchErr error
- var wg sync.WaitGroup
- wg.Add(2)
- go func() {
- defer wg.Done()
- _ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
- }()
- go func() {
- defer wg.Done()
- if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
- fetchErr = err
- }
- }()
- wg.Wait()
- if fetchErr != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
- return
- }
-
- vendorByName := make(map[string]upstreamVendor)
- for _, v := range vendorsEnv.Data {
- if v.Name != "" {
- vendorByName[v.Name] = v
- }
- }
- modelByName := make(map[string]upstreamModel)
- upstreamNames := make([]string, 0, len(modelsEnv.Data))
- for _, m := range modelsEnv.Data {
- if m.ModelName != "" {
- modelByName[m.ModelName] = m
- upstreamNames = append(upstreamNames, m.ModelName)
- }
- }
-
- // 2) 本地已有模型
- var locals []model.Model
- if len(upstreamNames) > 0 {
- _ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
- }
-
- // 本地 vendor 名称映射
- vendorIdSet := make(map[int]struct{})
- for _, m := range locals {
- if m.VendorID != 0 {
- vendorIdSet[m.VendorID] = struct{}{}
- }
- }
- vendorIDs := make([]int, 0, len(vendorIdSet))
- for id := range vendorIdSet {
- vendorIDs = append(vendorIDs, id)
- }
- idToVendorName := make(map[int]string)
- if len(vendorIDs) > 0 {
- var dbVendors []model.Vendor
- _ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
- for _, v := range dbVendors {
- idToVendorName[v.Id] = v.Name
- }
- }
-
- // 3) 缺失且上游存在的模型
- missingList, _ := model.GetMissingModels()
- var missing []string
- for _, name := range missingList {
- if _, ok := modelByName[name]; ok {
- missing = append(missing, name)
- }
- }
-
- // 4) 计算冲突字段
- type conflictField struct {
- Field string `json:"field"`
- Local interface{} `json:"local"`
- Upstream interface{} `json:"upstream"`
- }
- type conflictItem struct {
- ModelName string `json:"model_name"`
- Fields []conflictField `json:"fields"`
- }
-
- var conflicts []conflictItem
- for _, local := range locals {
- up, ok := modelByName[local.ModelName]
- if !ok {
- continue
- }
- fields := make([]conflictField, 0, 6)
- if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
- fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
- }
- if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
- fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
- }
- if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
- fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
- }
- // vendor 对比使用名称
- localVendor := idToVendorName[local.VendorID]
- if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
- fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
- }
- if local.NameRule != up.NameRule {
- fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
- }
- if local.Status != chooseStatus(up.Status, local.Status) {
- fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
- }
- if len(fields) > 0 {
- conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
- }
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": gin.H{
- "missing": missing,
- "conflicts": conflicts,
- "source": gin.H{
- "locale": locale,
- "models_url": modelsURL,
- "vendors_url": vendorsURL,
- },
- },
- })
-}
diff --git a/new-api/controller/oidc.go b/new-api/controller/oidc.go
deleted file mode 100644
index 1684da7eca758bbe4251fc6747369c87ba3ddc09..0000000000000000000000000000000000000000
--- a/new-api/controller/oidc.go
+++ /dev/null
@@ -1,227 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "net/url"
- "one-api/common"
- "one-api/model"
- "one-api/setting/system_setting"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-type OidcResponse struct {
- AccessToken string `json:"access_token"`
- IDToken string `json:"id_token"`
- RefreshToken string `json:"refresh_token"`
- TokenType string `json:"token_type"`
- ExpiresIn int `json:"expires_in"`
- Scope string `json:"scope"`
-}
-
-type OidcUser struct {
- OpenID string `json:"sub"`
- Email string `json:"email"`
- Name string `json:"name"`
- PreferredUsername string `json:"preferred_username"`
- Picture string `json:"picture"`
-}
-
-func getOidcUserInfoByCode(code string) (*OidcUser, error) {
- if code == "" {
- return nil, errors.New("无效的参数")
- }
-
- values := url.Values{}
- values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
- values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
- values.Set("code", code)
- values.Set("grant_type", "authorization_code")
- values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
- formData := values.Encode()
- req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- req.Header.Set("Accept", "application/json")
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- res, err := client.Do(req)
- if err != nil {
- common.SysLog(err.Error())
- return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
- }
- defer res.Body.Close()
- var oidcResponse OidcResponse
- err = json.NewDecoder(res.Body).Decode(&oidcResponse)
- if err != nil {
- return nil, err
- }
-
- if oidcResponse.AccessToken == "" {
- common.SysLog("OIDC 获取 Token 失败,请检查设置!")
- return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
- }
-
- req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
- res2, err := client.Do(req)
- if err != nil {
- common.SysLog(err.Error())
- return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
- }
- defer res2.Body.Close()
- if res2.StatusCode != http.StatusOK {
- common.SysLog("OIDC 获取用户信息失败!请检查设置!")
- return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
- }
-
- var oidcUser OidcUser
- err = json.NewDecoder(res2.Body).Decode(&oidcUser)
- if err != nil {
- return nil, err
- }
- if oidcUser.OpenID == "" || oidcUser.Email == "" {
- common.SysLog("OIDC 获取用户信息为空!请检查设置!")
- return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
- }
- return &oidcUser, nil
-}
-
-func OidcAuth(c *gin.Context) {
- session := sessions.Default(c)
- state := c.Query("state")
- if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "state is empty or not same",
- })
- return
- }
- username := session.Get("username")
- if username != nil {
- OidcBind(c)
- return
- }
- if !system_setting.GetOIDCSettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 OIDC 登录以及注册",
- })
- return
- }
- code := c.Query("code")
- oidcUser, err := getOidcUserInfoByCode(code)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user := model.User{
- OidcId: oidcUser.OpenID,
- }
- if model.IsOidcIdAlreadyTaken(user.OidcId) {
- err := user.FillUserByOidcId()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- if common.RegisterEnabled {
- user.Email = oidcUser.Email
- if oidcUser.PreferredUsername != "" {
- user.Username = oidcUser.PreferredUsername
- } else {
- user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
- }
- if oidcUser.Name != "" {
- user.DisplayName = oidcUser.Name
- } else {
- user.DisplayName = "OIDC User"
- }
- err := user.Insert(0)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员关闭了新用户注册",
- })
- return
- }
- }
-
- if user.Status != common.UserStatusEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "用户已被封禁",
- "success": false,
- })
- return
- }
- setupLogin(&user, c)
-}
-
-func OidcBind(c *gin.Context) {
- if !system_setting.GetOIDCSettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未开启通过 OIDC 登录以及注册",
- })
- return
- }
- code := c.Query("code")
- oidcUser, err := getOidcUserInfoByCode(code)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user := model.User{
- OidcId: oidcUser.OpenID,
- }
- if model.IsOidcIdAlreadyTaken(user.OidcId) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该 OIDC 账户已被绑定",
- })
- return
- }
- session := sessions.Default(c)
- id := session.Get("id")
- // id := c.GetInt("id") // critical bug!
- user.Id = id.(int)
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user.OidcId = oidcUser.OpenID
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "bind",
- })
- return
-}
diff --git a/new-api/controller/option.go b/new-api/controller/option.go
deleted file mode 100644
index 345a4626c977ff8daa12c6668dfebef10beb1bdb..0000000000000000000000000000000000000000
--- a/new-api/controller/option.go
+++ /dev/null
@@ -1,214 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/model"
- "one-api/setting"
- "one-api/setting/console_setting"
- "one-api/setting/ratio_setting"
- "one-api/setting/system_setting"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetOptions(c *gin.Context) {
- var options []*model.Option
- common.OptionMapRWMutex.Lock()
- for k, v := range common.OptionMap {
- if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
- continue
- }
- options = append(options, &model.Option{
- Key: k,
- Value: common.Interface2String(v),
- })
- }
- common.OptionMapRWMutex.Unlock()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": options,
- })
- return
-}
-
-type OptionUpdateRequest struct {
- Key string `json:"key"`
- Value any `json:"value"`
-}
-
-func UpdateOption(c *gin.Context) {
- var option OptionUpdateRequest
- err := json.NewDecoder(c.Request.Body).Decode(&option)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- switch option.Value.(type) {
- case bool:
- option.Value = common.Interface2String(option.Value.(bool))
- case float64:
- option.Value = common.Interface2String(option.Value.(float64))
- case int:
- option.Value = common.Interface2String(option.Value.(int))
- default:
- option.Value = fmt.Sprintf("%v", option.Value)
- }
- switch option.Key {
- case "GitHubOAuthEnabled":
- if option.Value == "true" && common.GitHubClientId == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
- })
- return
- }
- case "oidc.enabled":
- if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
- })
- return
- }
- case "LinuxDOOAuthEnabled":
- if option.Value == "true" && common.LinuxDOClientId == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用 LinuxDO OAuth,请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret!",
- })
- return
- }
- case "EmailDomainRestrictionEnabled":
- if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
- })
- return
- }
- case "WeChatAuthEnabled":
- if option.Value == "true" && common.WeChatServerAddress == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
- })
- return
- }
- case "TurnstileCheckEnabled":
- if option.Value == "true" && common.TurnstileSiteKey == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
- })
-
- return
- }
- case "TelegramOAuthEnabled":
- if option.Value == "true" && common.TelegramBotToken == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
- })
- return
- }
- case "GroupRatio":
- err = ratio_setting.CheckGroupRatio(option.Value.(string))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "ImageRatio":
- err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "图片倍率设置失败: " + err.Error(),
- })
- return
- }
- case "AudioRatio":
- err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "音频倍率设置失败: " + err.Error(),
- })
- return
- }
- case "AudioCompletionRatio":
- err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "音频补全倍率设置失败: " + err.Error(),
- })
- return
- }
- case "ModelRequestRateLimitGroup":
- err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "console_setting.api_info":
- err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "console_setting.announcements":
- err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements")
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "console_setting.faq":
- err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ")
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "console_setting.uptime_kuma_groups":
- err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups")
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- }
- err = model.UpdateOption(option.Key, option.Value.(string))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
diff --git a/new-api/controller/passkey.go b/new-api/controller/passkey.go
deleted file mode 100644
index 54d07a39a58d8009dee6d7d36e8fe8a61c919dd3..0000000000000000000000000000000000000000
--- a/new-api/controller/passkey.go
+++ /dev/null
@@ -1,497 +0,0 @@
-package controller
-
-import (
- "errors"
- "fmt"
- "net/http"
- "strconv"
- "time"
-
- "one-api/common"
- "one-api/model"
- passkeysvc "one-api/service/passkey"
- "one-api/setting/system_setting"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- "github.com/go-webauthn/webauthn/protocol"
- webauthnlib "github.com/go-webauthn/webauthn/webauthn"
-)
-
-func PasskeyRegisterBegin(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- credential, err := model.GetPasskeyByUserID(user.Id)
- if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
- common.ApiError(c, err)
- return
- }
- if errors.Is(err, model.ErrPasskeyNotFound) {
- credential = nil
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- waUser := passkeysvc.NewWebAuthnUser(user, credential)
- var options []webauthnlib.RegistrationOption
- if credential != nil {
- descriptor := credential.ToWebAuthnCredential().Descriptor()
- options = append(options, webauthnlib.WithExclusions([]protocol.CredentialDescriptor{descriptor}))
- }
-
- creation, sessionData, err := wa.BeginRegistration(waUser, options...)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- if err := passkeysvc.SaveSessionData(c, passkeysvc.RegistrationSessionKey, sessionData); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "options": creation,
- },
- })
-}
-
-func PasskeyRegisterFinish(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- credentialRecord, err := model.GetPasskeyByUserID(user.Id)
- if err != nil && !errors.Is(err, model.ErrPasskeyNotFound) {
- common.ApiError(c, err)
- return
- }
- if errors.Is(err, model.ErrPasskeyNotFound) {
- credentialRecord = nil
- }
-
- sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.RegistrationSessionKey)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- waUser := passkeysvc.NewWebAuthnUser(user, credentialRecord)
- credential, err := wa.FinishRegistration(waUser, *sessionData, c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- passkeyCredential := model.NewPasskeyCredentialFromWebAuthn(user.Id, credential)
- if passkeyCredential == nil {
- common.ApiErrorMsg(c, "无法创建 Passkey 凭证")
- return
- }
-
- if err := model.UpsertPasskeyCredential(passkeyCredential); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Passkey 注册成功",
- })
-}
-
-func PasskeyDelete(c *gin.Context) {
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- if err := model.DeletePasskeyByUserID(user.Id); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Passkey 已解绑",
- })
-}
-
-func PasskeyStatus(c *gin.Context) {
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- credential, err := model.GetPasskeyByUserID(user.Id)
- if errors.Is(err, model.ErrPasskeyNotFound) {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "enabled": false,
- },
- })
- return
- }
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- data := gin.H{
- "enabled": true,
- "last_used_at": credential.LastUsedAt,
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": data,
- })
-}
-
-func PasskeyLoginBegin(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- assertion, sessionData, err := wa.BeginDiscoverableLogin()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- if err := passkeysvc.SaveSessionData(c, passkeysvc.LoginSessionKey, sessionData); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "options": assertion,
- },
- })
-}
-
-func PasskeyLoginFinish(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.LoginSessionKey)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- handler := func(rawID, userHandle []byte) (webauthnlib.User, error) {
- // 首先通过凭证ID查找用户
- credential, err := model.GetPasskeyByCredentialID(rawID)
- if err != nil {
- return nil, fmt.Errorf("未找到 Passkey 凭证: %w", err)
- }
-
- // 通过凭证获取用户
- user := &model.User{Id: credential.UserID}
- if err := user.FillUserById(); err != nil {
- return nil, fmt.Errorf("用户信息获取失败: %w", err)
- }
-
- if user.Status != common.UserStatusEnabled {
- return nil, errors.New("该用户已被禁用")
- }
-
- if len(userHandle) > 0 {
- userID, parseErr := strconv.Atoi(string(userHandle))
- if parseErr != nil {
- // 记录异常但继续验证,因为某些客户端可能使用非数字格式
- common.SysLog(fmt.Sprintf("PasskeyLogin: userHandle parse error for credential, length: %d", len(userHandle)))
- } else if userID != user.Id {
- return nil, errors.New("用户句柄与凭证不匹配")
- }
- }
-
- return passkeysvc.NewWebAuthnUser(user, credential), nil
- }
-
- waUser, credential, err := wa.FinishPasskeyLogin(handler, *sessionData, c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- userWrapper, ok := waUser.(*passkeysvc.WebAuthnUser)
- if !ok {
- common.ApiErrorMsg(c, "Passkey 登录状态异常")
- return
- }
-
- modelUser := userWrapper.ModelUser()
- if modelUser == nil {
- common.ApiErrorMsg(c, "Passkey 登录状态异常")
- return
- }
-
- if modelUser.Status != common.UserStatusEnabled {
- common.ApiErrorMsg(c, "该用户已被禁用")
- return
- }
-
- // 更新凭证信息
- updatedCredential := model.NewPasskeyCredentialFromWebAuthn(modelUser.Id, credential)
- if updatedCredential == nil {
- common.ApiErrorMsg(c, "Passkey 凭证更新失败")
- return
- }
- now := time.Now()
- updatedCredential.LastUsedAt = &now
- if err := model.UpsertPasskeyCredential(updatedCredential); err != nil {
- common.ApiError(c, err)
- return
- }
-
- setupLogin(modelUser, c)
- return
-}
-
-func AdminResetPasskey(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiErrorMsg(c, "无效的用户 ID")
- return
- }
-
- user := &model.User{Id: id}
- if err := user.FillUserById(); err != nil {
- common.ApiError(c, err)
- return
- }
-
- if _, err := model.GetPasskeyByUserID(user.Id); err != nil {
- if errors.Is(err, model.ErrPasskeyNotFound) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户尚未绑定 Passkey",
- })
- return
- }
- common.ApiError(c, err)
- return
- }
-
- if err := model.DeletePasskeyByUserID(user.Id); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Passkey 已重置",
- })
-}
-
-func PasskeyVerifyBegin(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- credential, err := model.GetPasskeyByUserID(user.Id)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户尚未绑定 Passkey",
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- waUser := passkeysvc.NewWebAuthnUser(user, credential)
- assertion, sessionData, err := wa.BeginLogin(waUser)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- if err := passkeysvc.SaveSessionData(c, passkeysvc.VerifySessionKey, sessionData); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "options": assertion,
- },
- })
-}
-
-func PasskeyVerifyFinish(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- user, err := getSessionUser(c)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- credential, err := model.GetPasskeyByUserID(user.Id)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户尚未绑定 Passkey",
- })
- return
- }
-
- sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- waUser := passkeysvc.NewWebAuthnUser(user, credential)
- _, err = wa.FinishLogin(waUser, *sessionData, c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 更新凭证的最后使用时间
- now := time.Now()
- credential.LastUsedAt = &now
- if err := model.UpsertPasskeyCredential(credential); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Passkey 验证成功",
- })
-}
-
-func getSessionUser(c *gin.Context) (*model.User, error) {
- session := sessions.Default(c)
- idRaw := session.Get("id")
- if idRaw == nil {
- return nil, errors.New("未登录")
- }
- id, ok := idRaw.(int)
- if !ok {
- return nil, errors.New("无效的会话信息")
- }
- user := &model.User{Id: id}
- if err := user.FillUserById(); err != nil {
- return nil, err
- }
- if user.Status != common.UserStatusEnabled {
- return nil, errors.New("该用户已被禁用")
- }
- return user, nil
-}
diff --git a/new-api/controller/playground.go b/new-api/controller/playground.go
deleted file mode 100644
index 8f509b89ec5e1c9b48b6b430a643b9fb7f86110e..0000000000000000000000000000000000000000
--- a/new-api/controller/playground.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package controller
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "one-api/constant"
- "one-api/middleware"
- "one-api/model"
- "one-api/types"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func Playground(c *gin.Context) {
- var newAPIError *types.NewAPIError
-
- defer func() {
- if newAPIError != nil {
- c.JSON(newAPIError.StatusCode, gin.H{
- "error": newAPIError.ToOpenAIError(),
- })
- }
- }()
-
- useAccessToken := c.GetBool("use_access_token")
- if useAccessToken {
- newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
- return
- }
-
- group := c.GetString("group")
- modelName := c.GetString("original_model")
-
- userId := c.GetInt("id")
-
- // Write user context to ensure acceptUnsetRatio is available
- userCache, err := model.GetUserCache(userId)
- if err != nil {
- newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
- return
- }
- userCache.WriteContext(c)
-
- tempToken := &model.Token{
- UserId: userId,
- Name: fmt.Sprintf("playground-%s", group),
- Group: group,
- }
- _ = middleware.SetupContextForToken(c, tempToken)
- _, newAPIError = getChannel(c, group, modelName, 0)
- if newAPIError != nil {
- return
- }
- //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
- common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
-
- Relay(c, types.RelayFormatOpenAI)
-}
diff --git a/new-api/controller/prefill_group.go b/new-api/controller/prefill_group.go
deleted file mode 100644
index 296e19fb35af880e0c307fa5537a0ed1e99bd287..0000000000000000000000000000000000000000
--- a/new-api/controller/prefill_group.go
+++ /dev/null
@@ -1,90 +0,0 @@
-package controller
-
-import (
- "strconv"
-
- "one-api/common"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
-)
-
-// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
-func GetPrefillGroups(c *gin.Context) {
- groupType := c.Query("type")
- groups, err := model.GetAllPrefillGroups(groupType)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, groups)
-}
-
-// CreatePrefillGroup 创建新的预填组
-func CreatePrefillGroup(c *gin.Context) {
- var g model.PrefillGroup
- if err := c.ShouldBindJSON(&g); err != nil {
- common.ApiError(c, err)
- return
- }
- if g.Name == "" || g.Type == "" {
- common.ApiErrorMsg(c, "组名称和类型不能为空")
- return
- }
- // 创建前检查名称
- if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "组名称已存在")
- return
- }
-
- if err := g.Insert(); err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, &g)
-}
-
-// UpdatePrefillGroup 更新预填组
-func UpdatePrefillGroup(c *gin.Context) {
- var g model.PrefillGroup
- if err := c.ShouldBindJSON(&g); err != nil {
- common.ApiError(c, err)
- return
- }
- if g.Id == 0 {
- common.ApiErrorMsg(c, "缺少组 ID")
- return
- }
- // 名称冲突检查
- if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "组名称已存在")
- return
- }
-
- if err := g.Update(); err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, &g)
-}
-
-// DeletePrefillGroup 删除预填组
-func DeletePrefillGroup(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if err := model.DeletePrefillGroupByID(id); err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, nil)
-}
diff --git a/new-api/controller/pricing.go b/new-api/controller/pricing.go
deleted file mode 100644
index e8d80416ed59571445e486fa9864105abfc40b33..0000000000000000000000000000000000000000
--- a/new-api/controller/pricing.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package controller
-
-import (
- "one-api/model"
- "one-api/setting"
- "one-api/setting/ratio_setting"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetPricing(c *gin.Context) {
- pricing := model.GetPricing()
- userId, exists := c.Get("id")
- usableGroup := map[string]string{}
- groupRatio := map[string]float64{}
- for s, f := range ratio_setting.GetGroupRatioCopy() {
- groupRatio[s] = f
- }
- var group string
- if exists {
- user, err := model.GetUserCache(userId.(int))
- if err == nil {
- group = user.Group
- for g := range groupRatio {
- ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
- if ok {
- groupRatio[g] = ratio
- }
- }
- }
- }
-
- usableGroup = setting.GetUserUsableGroups(group)
- // check groupRatio contains usableGroup
- for group := range ratio_setting.GetGroupRatioCopy() {
- if _, ok := usableGroup[group]; !ok {
- delete(groupRatio, group)
- }
- }
-
- c.JSON(200, gin.H{
- "success": true,
- "data": pricing,
- "vendors": model.GetVendors(),
- "group_ratio": groupRatio,
- "usable_group": usableGroup,
- "supported_endpoint": model.GetSupportedEndpointMap(),
- "auto_groups": setting.AutoGroups,
- })
-}
-
-func ResetModelRatio(c *gin.Context) {
- defaultStr := ratio_setting.DefaultModelRatio2JSONString()
- err := model.UpdateOption("ModelRatio", defaultStr)
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- c.JSON(200, gin.H{
- "success": true,
- "message": "重置模型倍率成功",
- })
-}
diff --git a/new-api/controller/ratio_config.go b/new-api/controller/ratio_config.go
deleted file mode 100644
index b72cb638257923670670b24b110c46997957d165..0000000000000000000000000000000000000000
--- a/new-api/controller/ratio_config.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/setting/ratio_setting"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetRatioConfig(c *gin.Context) {
- if !ratio_setting.IsExposeRatioEnabled() {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "倍率配置接口未启用",
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": ratio_setting.GetExposedData(),
- })
-}
diff --git a/new-api/controller/ratio_sync.go b/new-api/controller/ratio_sync.go
deleted file mode 100644
index 2f7786666a5cb1652b9c78c000bee7c2eca66843..0000000000000000000000000000000000000000
--- a/new-api/controller/ratio_sync.go
+++ /dev/null
@@ -1,539 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "net/http"
- "one-api/logger"
- "strings"
- "sync"
- "time"
-
- "one-api/dto"
- "one-api/model"
- "one-api/setting/ratio_setting"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- defaultTimeoutSeconds = 10
- defaultEndpoint = "/api/ratio_config"
- maxConcurrentFetches = 8
- maxRatioConfigBytes = 10 << 20 // 10MB
- floatEpsilon = 1e-9
-)
-
-func nearlyEqual(a, b float64) bool {
- if a > b {
- return a-b < floatEpsilon
- }
- return b-a < floatEpsilon
-}
-
-func valuesEqual(a, b interface{}) bool {
- af, aok := a.(float64)
- bf, bok := b.(float64)
- if aok && bok {
- return nearlyEqual(af, bf)
- }
- return a == b
-}
-
-var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
-
-type upstreamResult struct {
- Name string `json:"name"`
- Data map[string]any `json:"data,omitempty"`
- Err string `json:"err,omitempty"`
-}
-
-func FetchUpstreamRatios(c *gin.Context) {
- var req dto.UpstreamRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
- return
- }
-
- if req.Timeout <= 0 {
- req.Timeout = defaultTimeoutSeconds
- }
-
- var upstreams []dto.UpstreamDTO
-
- if len(req.Upstreams) > 0 {
- for _, u := range req.Upstreams {
- if strings.HasPrefix(u.BaseURL, "http") {
- if u.Endpoint == "" {
- u.Endpoint = defaultEndpoint
- }
- u.BaseURL = strings.TrimRight(u.BaseURL, "/")
- upstreams = append(upstreams, u)
- }
- }
- } else if len(req.ChannelIDs) > 0 {
- intIds := make([]int, 0, len(req.ChannelIDs))
- for _, id64 := range req.ChannelIDs {
- intIds = append(intIds, int(id64))
- }
- dbChannels, err := model.GetChannelsByIds(intIds)
- if err != nil {
- logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
- c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
- return
- }
- for _, ch := range dbChannels {
- if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
- upstreams = append(upstreams, dto.UpstreamDTO{
- ID: ch.Id,
- Name: ch.Name,
- BaseURL: strings.TrimRight(base, "/"),
- Endpoint: "",
- })
- }
- }
- }
-
- if len(upstreams) == 0 {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
- return
- }
-
- var wg sync.WaitGroup
- ch := make(chan upstreamResult, len(upstreams))
-
- sem := make(chan struct{}, maxConcurrentFetches)
-
- dialer := &net.Dialer{Timeout: 10 * time.Second}
- transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
- transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- host = addr
- }
- // 对 github.io 优先尝试 IPv4,失败则回退 IPv6
- if strings.HasSuffix(host, "github.io") {
- if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
- return conn, nil
- }
- return dialer.DialContext(ctx, "tcp6", addr)
- }
- return dialer.DialContext(ctx, network, addr)
- }
- client := &http.Client{Transport: transport}
-
- for _, chn := range upstreams {
- wg.Add(1)
- go func(chItem dto.UpstreamDTO) {
- defer wg.Done()
-
- sem <- struct{}{}
- defer func() { <-sem }()
-
- endpoint := chItem.Endpoint
- var fullURL string
- if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
- fullURL = endpoint
- } else {
- if endpoint == "" {
- endpoint = defaultEndpoint
- } else if !strings.HasPrefix(endpoint, "/") {
- endpoint = "/" + endpoint
- }
- fullURL = chItem.BaseURL + endpoint
- }
-
- uniqueName := chItem.Name
- if chItem.ID != 0 {
- uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
- }
-
- ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
- defer cancel()
-
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
- if err != nil {
- logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
- return
- }
-
- // 简单重试:最多 3 次,指数退避
- var resp *http.Response
- var lastErr error
- for attempt := 0; attempt < 3; attempt++ {
- resp, lastErr = client.Do(httpReq)
- if lastErr == nil {
- break
- }
- time.Sleep(time.Duration(200*(1< data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
- // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
- var body struct {
- Success bool `json:"success"`
- Data json.RawMessage `json:"data"`
- Message string `json:"message"`
- }
-
- if err := json.NewDecoder(limited).Decode(&body); err != nil {
- logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
- return
- }
-
- if !body.Success {
- ch <- upstreamResult{Name: uniqueName, Err: body.Message}
- return
- }
-
- // 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
-
- // 尝试按 type1 解析
- var type1Data map[string]any
- if err := json.Unmarshal(body.Data, &type1Data); err == nil {
- // 如果包含至少一个 ratioTypes 字段,则认为是 type1
- isType1 := false
- for _, rt := range ratioTypes {
- if _, ok := type1Data[rt]; ok {
- isType1 = true
- break
- }
- }
- if isType1 {
- ch <- upstreamResult{Name: uniqueName, Data: type1Data}
- return
- }
- }
-
- // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
- var pricingItems []struct {
- ModelName string `json:"model_name"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- CompletionRatio float64 `json:"completion_ratio"`
- }
- if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
- logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
- ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
- return
- }
-
- modelRatioMap := make(map[string]float64)
- completionRatioMap := make(map[string]float64)
- modelPriceMap := make(map[string]float64)
-
- for _, item := range pricingItems {
- if item.QuotaType == 1 {
- modelPriceMap[item.ModelName] = item.ModelPrice
- } else {
- modelRatioMap[item.ModelName] = item.ModelRatio
- // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
- completionRatioMap[item.ModelName] = item.CompletionRatio
- }
- }
-
- converted := make(map[string]any)
-
- if len(modelRatioMap) > 0 {
- ratioAny := make(map[string]any, len(modelRatioMap))
- for k, v := range modelRatioMap {
- ratioAny[k] = v
- }
- converted["model_ratio"] = ratioAny
- }
-
- if len(completionRatioMap) > 0 {
- compAny := make(map[string]any, len(completionRatioMap))
- for k, v := range completionRatioMap {
- compAny[k] = v
- }
- converted["completion_ratio"] = compAny
- }
-
- if len(modelPriceMap) > 0 {
- priceAny := make(map[string]any, len(modelPriceMap))
- for k, v := range modelPriceMap {
- priceAny[k] = v
- }
- converted["model_price"] = priceAny
- }
-
- ch <- upstreamResult{Name: uniqueName, Data: converted}
- }(chn)
- }
-
- wg.Wait()
- close(ch)
-
- localData := ratio_setting.GetExposedData()
-
- var testResults []dto.TestResult
- var successfulChannels []struct {
- name string
- data map[string]any
- }
-
- for r := range ch {
- if r.Err != "" {
- testResults = append(testResults, dto.TestResult{
- Name: r.Name,
- Status: "error",
- Error: r.Err,
- })
- } else {
- testResults = append(testResults, dto.TestResult{
- Name: r.Name,
- Status: "success",
- })
- successfulChannels = append(successfulChannels, struct {
- name string
- data map[string]any
- }{name: r.Name, data: r.Data})
- }
- }
-
- differences := buildDifferences(localData, successfulChannels)
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "data": gin.H{
- "differences": differences,
- "test_results": testResults,
- },
- })
-}
-
-func buildDifferences(localData map[string]any, successfulChannels []struct {
- name string
- data map[string]any
-}) map[string]map[string]dto.DifferenceItem {
- differences := make(map[string]map[string]dto.DifferenceItem)
-
- allModels := make(map[string]struct{})
-
- for _, ratioType := range ratioTypes {
- if localRatioAny, ok := localData[ratioType]; ok {
- if localRatio, ok := localRatioAny.(map[string]float64); ok {
- for modelName := range localRatio {
- allModels[modelName] = struct{}{}
- }
- }
- }
- }
-
- for _, channel := range successfulChannels {
- for _, ratioType := range ratioTypes {
- if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
- for modelName := range upstreamRatio {
- allModels[modelName] = struct{}{}
- }
- }
- }
- }
-
- confidenceMap := make(map[string]map[string]bool)
-
- // 预处理阶段:检查pricing接口的可信度
- for _, channel := range successfulChannels {
- confidenceMap[channel.name] = make(map[string]bool)
-
- modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
- completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
-
- if hasModelRatio && hasCompletionRatio {
- // 遍历所有模型,检查是否满足不可信条件
- for modelName := range allModels {
- // 默认为可信
- confidenceMap[channel.name][modelName] = true
-
- // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
- if modelRatioVal, ok := modelRatios[modelName]; ok {
- if completionRatioVal, ok := completionRatios[modelName]; ok {
- // 转换为float64进行比较
- if modelRatioFloat, ok := modelRatioVal.(float64); ok {
- if completionRatioFloat, ok := completionRatioVal.(float64); ok {
- if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
- confidenceMap[channel.name][modelName] = false
- }
- }
- }
- }
- }
- }
- } else {
- // 如果不是从pricing接口获取的数据,则全部标记为可信
- for modelName := range allModels {
- confidenceMap[channel.name][modelName] = true
- }
- }
- }
-
- for modelName := range allModels {
- for _, ratioType := range ratioTypes {
- var localValue interface{} = nil
- if localRatioAny, ok := localData[ratioType]; ok {
- if localRatio, ok := localRatioAny.(map[string]float64); ok {
- if val, exists := localRatio[modelName]; exists {
- localValue = val
- }
- }
- }
-
- upstreamValues := make(map[string]interface{})
- confidenceValues := make(map[string]bool)
- hasUpstreamValue := false
- hasDifference := false
-
- for _, channel := range successfulChannels {
- var upstreamValue interface{} = nil
-
- if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
- if val, exists := upstreamRatio[modelName]; exists {
- upstreamValue = val
- hasUpstreamValue = true
-
- if localValue != nil && !valuesEqual(localValue, val) {
- hasDifference = true
- } else if valuesEqual(localValue, val) {
- upstreamValue = "same"
- }
- }
- }
- if upstreamValue == nil && localValue == nil {
- upstreamValue = "same"
- }
-
- if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
- hasDifference = true
- }
-
- upstreamValues[channel.name] = upstreamValue
-
- confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
- }
-
- shouldInclude := false
-
- if localValue != nil {
- if hasDifference {
- shouldInclude = true
- }
- } else {
- if hasUpstreamValue {
- shouldInclude = true
- }
- }
-
- if shouldInclude {
- if differences[modelName] == nil {
- differences[modelName] = make(map[string]dto.DifferenceItem)
- }
- differences[modelName][ratioType] = dto.DifferenceItem{
- Current: localValue,
- Upstreams: upstreamValues,
- Confidence: confidenceValues,
- }
- }
- }
- }
-
- channelHasDiff := make(map[string]bool)
- for _, ratioMap := range differences {
- for _, item := range ratioMap {
- for chName, val := range item.Upstreams {
- if val != nil && val != "same" {
- channelHasDiff[chName] = true
- }
- }
- }
- }
-
- for modelName, ratioMap := range differences {
- for ratioType, item := range ratioMap {
- for chName := range item.Upstreams {
- if !channelHasDiff[chName] {
- delete(item.Upstreams, chName)
- delete(item.Confidence, chName)
- }
- }
-
- allSame := true
- for _, v := range item.Upstreams {
- if v != "same" {
- allSame = false
- break
- }
- }
- if len(item.Upstreams) == 0 || allSame {
- delete(ratioMap, ratioType)
- } else {
- differences[modelName][ratioType] = item
- }
- }
-
- if len(ratioMap) == 0 {
- delete(differences, modelName)
- }
- }
-
- return differences
-}
-
-func GetSyncableChannels(c *gin.Context) {
- channels, err := model.GetAllChannels(0, 0, true, false)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- var syncableChannels []dto.SyncableChannel
- for _, channel := range channels {
- if channel.GetBaseURL() != "" {
- syncableChannels = append(syncableChannels, dto.SyncableChannel{
- ID: channel.Id,
- Name: channel.Name,
- BaseURL: channel.GetBaseURL(),
- Status: channel.Status,
- })
- }
- }
-
- syncableChannels = append(syncableChannels, dto.SyncableChannel{
- ID: -100,
- Name: "官方倍率预设",
- BaseURL: "https://basellm.github.io",
- Status: 1,
- })
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": syncableChannels,
- })
-}
diff --git a/new-api/controller/redemption.go b/new-api/controller/redemption.go
deleted file mode 100644
index 081a934ac2b63c37dc7e9dc0f6308ef42aa09ea8..0000000000000000000000000000000000000000
--- a/new-api/controller/redemption.go
+++ /dev/null
@@ -1,194 +0,0 @@
-package controller
-
-import (
- "errors"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
- "unicode/utf8"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetAllRedemptions(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(redemptions)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func SearchRedemptions(c *gin.Context) {
- keyword := c.Query("keyword")
- pageInfo := common.GetPageQuery(c)
- redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(redemptions)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func GetRedemption(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- redemption, err := model.GetRedemptionById(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": redemption,
- })
- return
-}
-
-func AddRedemption(c *gin.Context) {
- redemption := model.Redemption{}
- err := c.ShouldBindJSON(&redemption)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "兑换码名称长度必须在1-20之间",
- })
- return
- }
- if redemption.Count <= 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "兑换码个数必须大于0",
- })
- return
- }
- if redemption.Count > 100 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "一次兑换码批量生成的个数不能大于 100",
- })
- return
- }
- if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- var keys []string
- for i := 0; i < redemption.Count; i++ {
- key := common.GetUUID()
- cleanRedemption := model.Redemption{
- UserId: c.GetInt("id"),
- Name: redemption.Name,
- Key: key,
- CreatedTime: common.GetTimestamp(),
- Quota: redemption.Quota,
- ExpiredTime: redemption.ExpiredTime,
- }
- err = cleanRedemption.Insert()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- "data": keys,
- })
- return
- }
- keys = append(keys, key)
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": keys,
- })
- return
-}
-
-func DeleteRedemption(c *gin.Context) {
- id, _ := strconv.Atoi(c.Param("id"))
- err := model.DeleteRedemptionById(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func UpdateRedemption(c *gin.Context) {
- statusOnly := c.Query("status_only")
- redemption := model.Redemption{}
- err := c.ShouldBindJSON(&redemption)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- cleanRedemption, err := model.GetRedemptionById(redemption.Id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if statusOnly == "" {
- if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
- c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
- return
- }
- // If you add more fields, please also update redemption.Update()
- cleanRedemption.Name = redemption.Name
- cleanRedemption.Quota = redemption.Quota
- cleanRedemption.ExpiredTime = redemption.ExpiredTime
- }
- if statusOnly != "" {
- cleanRedemption.Status = redemption.Status
- }
- err = cleanRedemption.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": cleanRedemption,
- })
- return
-}
-
-func DeleteInvalidRedemption(c *gin.Context) {
- rows, err := model.DeleteInvalidRedemptions()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": rows,
- })
- return
-}
-
-func validateExpiredTime(expired int64) error {
- if expired != 0 && expired < common.GetTimestamp() {
- return errors.New("过期时间不能早于当前时间")
- }
- return nil
-}
diff --git a/new-api/controller/relay.go b/new-api/controller/relay.go
deleted file mode 100644
index 918a0c24b512286dca96b4f39af925d297b4807f..0000000000000000000000000000000000000000
--- a/new-api/controller/relay.go
+++ /dev/null
@@ -1,476 +0,0 @@
-package controller
-
-import (
- "bytes"
- "fmt"
- "io"
- "log"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/middleware"
- "one-api/model"
- "one-api/relay"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "one-api/types"
- "strings"
-
- "github.com/bytedance/gopkg/util/gopool"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
- var err *types.NewAPIError
- switch info.RelayMode {
- case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
- err = relay.ImageHelper(c, info)
- case relayconstant.RelayModeAudioSpeech:
- fallthrough
- case relayconstant.RelayModeAudioTranslation:
- fallthrough
- case relayconstant.RelayModeAudioTranscription:
- err = relay.AudioHelper(c, info)
- case relayconstant.RelayModeRerank:
- err = relay.RerankHelper(c, info)
- case relayconstant.RelayModeEmbeddings:
- err = relay.EmbeddingHelper(c, info)
- case relayconstant.RelayModeResponses:
- err = relay.ResponsesHelper(c, info)
- default:
- err = relay.TextHelper(c, info)
- }
- return err
-}
-
-func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
- var err *types.NewAPIError
- if strings.Contains(c.Request.URL.Path, "embed") {
- err = relay.GeminiEmbeddingHandler(c, info)
- } else {
- err = relay.GeminiHelper(c, info)
- }
- return err
-}
-
-func Relay(c *gin.Context, relayFormat types.RelayFormat) {
-
- requestId := c.GetString(common.RequestIdKey)
- group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
- originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
-
- var (
- newAPIError *types.NewAPIError
- ws *websocket.Conn
- )
-
- if relayFormat == types.RelayFormatOpenAIRealtime {
- var err error
- ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
- if err != nil {
- helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
- return
- }
- defer ws.Close()
- }
-
- defer func() {
- if newAPIError != nil {
- newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
- switch relayFormat {
- case types.RelayFormatOpenAIRealtime:
- helper.WssError(c, ws, newAPIError.ToOpenAIError())
- case types.RelayFormatClaude:
- c.JSON(newAPIError.StatusCode, gin.H{
- "type": "error",
- "error": newAPIError.ToClaudeError(),
- })
- default:
- c.JSON(newAPIError.StatusCode, gin.H{
- "error": newAPIError.ToOpenAIError(),
- })
- }
- }
- }()
-
- request, err := helper.GetAndValidateRequest(c, relayFormat)
- if err != nil {
- newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
- return
- }
-
- relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
- if err != nil {
- newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
- return
- }
-
- meta := request.GetTokenCountMeta()
-
- if setting.ShouldCheckPromptSensitive() {
- contains, words := service.CheckSensitiveText(meta.CombineText)
- if contains {
- logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
- newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
- return
- }
- }
-
- tokens, err := service.CountRequestToken(c, meta, relayInfo)
- if err != nil {
- newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
- return
- }
-
- relayInfo.SetPromptTokens(tokens)
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
- if err != nil {
- newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
- return
- }
-
- // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
-
- newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if newAPIError != nil {
- return
- }
-
- defer func() {
- // Only return quota if downstream failed and quota was actually pre-consumed
- if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
- service.ReturnPreConsumedQuota(c, relayInfo)
- }
- }()
-
- for i := 0; i <= common.RetryTimes; i++ {
- channel, err := getChannel(c, group, originalModel, i)
- if err != nil {
- logger.LogError(c, err.Error())
- newAPIError = err
- break
- }
-
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-
- switch relayFormat {
- case types.RelayFormatOpenAIRealtime:
- newAPIError = relay.WssHelper(c, relayInfo)
- case types.RelayFormatClaude:
- newAPIError = relay.ClaudeHelper(c, relayInfo)
- case types.RelayFormatGemini:
- newAPIError = geminiRelayHandler(c, relayInfo)
- default:
- newAPIError = relayHandler(c, relayInfo)
- }
-
- if newAPIError == nil {
- return
- }
-
- processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
-
- if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
- break
- }
- }
-
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- logger.LogInfo(c, retryLogStr)
- }
-}
-
-var upgrader = websocket.Upgrader{
- Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
- CheckOrigin: func(r *http.Request) bool {
- return true // 允许跨域
- },
-}
-
-func addUsedChannel(c *gin.Context, channelId int) {
- useChannel := c.GetStringSlice("use_channel")
- useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
- c.Set("use_channel", useChannel)
-}
-
-func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
- if retryCount == 0 {
- autoBan := c.GetBool("auto_ban")
- autoBanInt := 1
- if !autoBan {
- autoBanInt = 0
- }
- return &model.Channel{
- Id: c.GetInt("channel_id"),
- Type: c.GetInt("channel_type"),
- Name: c.GetString("channel_name"),
- AutoBan: &autoBanInt,
- }, nil
- }
- channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
- if err != nil {
- return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
- }
- if channel == nil {
- return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
- }
- newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
- if newAPIError != nil {
- return nil, newAPIError
- }
- return channel, nil
-}
-
-func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
- if openaiErr == nil {
- return false
- }
- if types.IsChannelError(openaiErr) {
- return true
- }
- if types.IsSkipRetryError(openaiErr) {
- return false
- }
- if retryTimes <= 0 {
- return false
- }
- if _, ok := c.Get("specific_channel_id"); ok {
- return false
- }
- if openaiErr.StatusCode == http.StatusTooManyRequests {
- return true
- }
- if openaiErr.StatusCode == 307 {
- return true
- }
- if openaiErr.StatusCode/100 == 5 {
- // 超时不重试
- if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
- return false
- }
- return true
- }
- if openaiErr.StatusCode == http.StatusBadRequest {
- return false
- }
- if openaiErr.StatusCode == 408 {
- // azure处理超时不重试
- return false
- }
- if openaiErr.StatusCode/100 == 2 {
- return false
- }
- return true
-}
-
-func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
- logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
- // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
- // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
- if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
- gopool.Go(func() {
- service.DisableChannel(channelError, err.Error())
- })
- }
-
- if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
- // 保存错误日志到mysql中
- userId := c.GetInt("id")
- tokenName := c.GetString("token_name")
- modelName := c.GetString("original_model")
- tokenId := c.GetInt("token_id")
- userGroup := c.GetString("group")
- channelId := c.GetInt("channel_id")
- other := make(map[string]interface{})
- other["error_type"] = err.GetErrorType()
- other["error_code"] = err.GetErrorCode()
- other["status_code"] = err.StatusCode
- other["channel_id"] = channelId
- other["channel_name"] = c.GetString("channel_name")
- other["channel_type"] = c.GetInt("channel_type")
- adminInfo := make(map[string]interface{})
- adminInfo["use_channel"] = c.GetStringSlice("use_channel")
- isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
- if isMultiKey {
- adminInfo["is_multi_key"] = true
- adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
- }
- other["admin_info"] = adminInfo
- model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
- }
-
-}
-
-func RelayMidjourney(c *gin.Context) {
- relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
-
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
- "type": "upstream_error",
- "code": 4,
- })
- return
- }
-
- var mjErr *dto.MidjourneyResponse
- switch relayInfo.RelayMode {
- case relayconstant.RelayModeMidjourneyNotify:
- mjErr = relay.RelayMidjourneyNotify(c)
- case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
- mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
- case relayconstant.RelayModeMidjourneyTaskImageSeed:
- mjErr = relay.RelayMidjourneyTaskImageSeed(c)
- case relayconstant.RelayModeSwapFace:
- mjErr = relay.RelaySwapFace(c, relayInfo)
- default:
- mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
- }
- //err = relayMidjourneySubmit(c, relayMode)
- log.Println(mjErr)
- if mjErr != nil {
- statusCode := http.StatusBadRequest
- if mjErr.Code == 30 {
- mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
- statusCode = http.StatusTooManyRequests
- }
- c.JSON(statusCode, gin.H{
- "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
- "type": "upstream_error",
- "code": mjErr.Code,
- })
- channelId := c.GetInt("channel_id")
- logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
- }
-}
-
-func RelayNotImplemented(c *gin.Context) {
- err := dto.OpenAIError{
- Message: "API not implemented",
- Type: "new_api_error",
- Param: "",
- Code: "api_not_implemented",
- }
- c.JSON(http.StatusNotImplemented, gin.H{
- "error": err,
- })
-}
-
-func RelayNotFound(c *gin.Context) {
- err := dto.OpenAIError{
- Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
- Type: "invalid_request_error",
- Param: "",
- Code: "",
- }
- c.JSON(http.StatusNotFound, gin.H{
- "error": err,
- })
-}
-
-func RelayTask(c *gin.Context) {
- retryTimes := common.RetryTimes
- channelId := c.GetInt("channel_id")
- group := c.GetString("group")
- originalModel := c.GetString("original_model")
- c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
- relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
- if err != nil {
- return
- }
- taskErr := taskRelayHandler(c, relayInfo)
- if taskErr == nil {
- retryTimes = 0
- }
- for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, newAPIError := getChannel(c, group, originalModel, i)
- if newAPIError != nil {
- logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
- taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
- break
- }
- channelId = channel.Id
- useChannel := c.GetStringSlice("use_channel")
- useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
- c.Set("use_channel", useChannel)
- logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
- //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
-
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- taskErr = taskRelayHandler(c, relayInfo)
- }
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- logger.LogInfo(c, retryLogStr)
- }
- if taskErr != nil {
- if taskErr.StatusCode == http.StatusTooManyRequests {
- taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- c.JSON(taskErr.StatusCode, taskErr)
- }
-}
-
-func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
- var err *dto.TaskError
- switch relayInfo.RelayMode {
- case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
- err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
- default:
- err = relay.RelayTaskSubmit(c, relayInfo)
- }
- return err
-}
-
-func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
- if taskErr == nil {
- return false
- }
- if retryTimes <= 0 {
- return false
- }
- if _, ok := c.Get("specific_channel_id"); ok {
- return false
- }
- if taskErr.StatusCode == http.StatusTooManyRequests {
- return true
- }
- if taskErr.StatusCode == 307 {
- return true
- }
- if taskErr.StatusCode/100 == 5 {
- // 超时不重试
- if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
- return false
- }
- return true
- }
- if taskErr.StatusCode == http.StatusBadRequest {
- return false
- }
- if taskErr.StatusCode == 408 {
- // azure处理超时不重试
- return false
- }
- if taskErr.LocalError {
- return false
- }
- if taskErr.StatusCode/100 == 2 {
- return false
- }
- return true
-}
diff --git a/new-api/controller/secure_verification.go b/new-api/controller/secure_verification.go
deleted file mode 100644
index 8fefe7723dc132f0150b7498f0c2a377e549b40d..0000000000000000000000000000000000000000
--- a/new-api/controller/secure_verification.go
+++ /dev/null
@@ -1,313 +0,0 @@
-package controller
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/model"
- passkeysvc "one-api/service/passkey"
- "one-api/setting/system_setting"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-const (
- // SecureVerificationSessionKey 安全验证的 session key
- SecureVerificationSessionKey = "secure_verified_at"
- // SecureVerificationTimeout 验证有效期(秒)
- SecureVerificationTimeout = 300 // 5分钟
-)
-
-type UniversalVerifyRequest struct {
- Method string `json:"method"` // "2fa" 或 "passkey"
- Code string `json:"code,omitempty"`
-}
-
-type VerificationStatusResponse struct {
- Verified bool `json:"verified"`
- ExpiresAt int64 `json:"expires_at,omitempty"`
-}
-
-// UniversalVerify 通用验证接口
-// 支持 2FA 和 Passkey 验证,验证成功后在 session 中记录时间戳
-func UniversalVerify(c *gin.Context) {
- userId := c.GetInt("id")
- if userId == 0 {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "未登录",
- })
- return
- }
-
- var req UniversalVerifyRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- common.ApiError(c, fmt.Errorf("参数错误: %v", err))
- return
- }
-
- // 获取用户信息
- user := &model.User{Id: userId}
- if err := user.FillUserById(); err != nil {
- common.ApiError(c, fmt.Errorf("获取用户信息失败: %v", err))
- return
- }
-
- if user.Status != common.UserStatusEnabled {
- common.ApiError(c, fmt.Errorf("该用户已被禁用"))
- return
- }
-
- // 检查用户的验证方式
- twoFA, _ := model.GetTwoFAByUserId(userId)
- has2FA := twoFA != nil && twoFA.IsEnabled
-
- passkey, passkeyErr := model.GetPasskeyByUserID(userId)
- hasPasskey := passkeyErr == nil && passkey != nil
-
- if !has2FA && !hasPasskey {
- common.ApiError(c, fmt.Errorf("用户未启用2FA或Passkey"))
- return
- }
-
- // 根据验证方式进行验证
- var verified bool
- var verifyMethod string
-
- switch req.Method {
- case "2fa":
- if !has2FA {
- common.ApiError(c, fmt.Errorf("用户未启用2FA"))
- return
- }
- if req.Code == "" {
- common.ApiError(c, fmt.Errorf("验证码不能为空"))
- return
- }
- verified = validateTwoFactorAuth(twoFA, req.Code)
- verifyMethod = "2FA"
-
- case "passkey":
- if !hasPasskey {
- common.ApiError(c, fmt.Errorf("用户未启用Passkey"))
- return
- }
- // Passkey 验证需要先调用 PasskeyVerifyBegin 和 PasskeyVerifyFinish
- // 这里只是验证 Passkey 验证流程是否已经完成
- // 实际上,前端应该先调用这两个接口,然后再调用本接口
- verified = true // Passkey 验证逻辑已在 PasskeyVerifyFinish 中完成
- verifyMethod = "Passkey"
-
- default:
- common.ApiError(c, fmt.Errorf("不支持的验证方式: %s", req.Method))
- return
- }
-
- if !verified {
- common.ApiError(c, fmt.Errorf("验证失败,请检查验证码"))
- return
- }
-
- // 验证成功,在 session 中记录时间戳
- session := sessions.Default(c)
- now := time.Now().Unix()
- session.Set(SecureVerificationSessionKey, now)
- if err := session.Save(); err != nil {
- common.ApiError(c, fmt.Errorf("保存验证状态失败: %v", err))
- return
- }
-
- // 记录日志
- model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("通用安全验证成功 (验证方式: %s)", verifyMethod))
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "验证成功",
- "data": gin.H{
- "verified": true,
- "expires_at": now + SecureVerificationTimeout,
- },
- })
-}
-
-// GetVerificationStatus 获取验证状态
-func GetVerificationStatus(c *gin.Context) {
- userId := c.GetInt("id")
- if userId == 0 {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "未登录",
- })
- return
- }
-
- session := sessions.Default(c)
- verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
- if verifiedAtRaw == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": VerificationStatusResponse{
- Verified: false,
- },
- })
- return
- }
-
- verifiedAt, ok := verifiedAtRaw.(int64)
- if !ok {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": VerificationStatusResponse{
- Verified: false,
- },
- })
- return
- }
-
- elapsed := time.Now().Unix() - verifiedAt
- if elapsed >= SecureVerificationTimeout {
- // 验证已过期
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": VerificationStatusResponse{
- Verified: false,
- },
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": VerificationStatusResponse{
- Verified: true,
- ExpiresAt: verifiedAt + SecureVerificationTimeout,
- },
- })
-}
-
-// CheckSecureVerification 检查是否已通过安全验证
-// 返回 true 表示验证有效,false 表示需要重新验证
-func CheckSecureVerification(c *gin.Context) bool {
- session := sessions.Default(c)
- verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
- if verifiedAtRaw == nil {
- return false
- }
-
- verifiedAt, ok := verifiedAtRaw.(int64)
- if !ok {
- return false
- }
-
- elapsed := time.Now().Unix() - verifiedAt
- if elapsed >= SecureVerificationTimeout {
- // 验证已过期,清除 session
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
- return false
- }
-
- return true
-}
-
-// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
-// 这是一个辅助函数,供 PasskeyVerifyFinish 调用
-func PasskeyVerifyAndSetSession(c *gin.Context) {
- session := sessions.Default(c)
- now := time.Now().Unix()
- session.Set(SecureVerificationSessionKey, now)
- _ = session.Save()
-}
-
-// PasskeyVerifyForSecure 用于安全验证的 Passkey 验证流程
-// 整合了 begin 和 finish 流程
-func PasskeyVerifyForSecure(c *gin.Context) {
- if !system_setting.GetPasskeySettings().Enabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员未启用 Passkey 登录",
- })
- return
- }
-
- userId := c.GetInt("id")
- if userId == 0 {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "未登录",
- })
- return
- }
-
- user := &model.User{Id: userId}
- if err := user.FillUserById(); err != nil {
- common.ApiError(c, fmt.Errorf("获取用户信息失败: %v", err))
- return
- }
-
- if user.Status != common.UserStatusEnabled {
- common.ApiError(c, fmt.Errorf("该用户已被禁用"))
- return
- }
-
- credential, err := model.GetPasskeyByUserID(userId)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户尚未绑定 Passkey",
- })
- return
- }
-
- wa, err := passkeysvc.BuildWebAuthn(c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- waUser := passkeysvc.NewWebAuthnUser(user, credential)
- sessionData, err := passkeysvc.PopSessionData(c, passkeysvc.VerifySessionKey)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- _, err = wa.FinishLogin(waUser, *sessionData, c.Request)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 更新凭证的最后使用时间
- now := time.Now()
- credential.LastUsedAt = &now
- if err := model.UpsertPasskeyCredential(credential); err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 验证成功,设置 session
- PasskeyVerifyAndSetSession(c)
-
- // 记录日志
- model.RecordLog(userId, model.LogTypeSystem, "Passkey 安全验证成功")
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Passkey 验证成功",
- "data": gin.H{
- "verified": true,
- "expires_at": time.Now().Unix() + SecureVerificationTimeout,
- },
- })
-}
diff --git a/new-api/controller/setup.go b/new-api/controller/setup.go
deleted file mode 100644
index f9c43c270debd7a859ca574bc13636fec4691489..0000000000000000000000000000000000000000
--- a/new-api/controller/setup.go
+++ /dev/null
@@ -1,181 +0,0 @@
-package controller
-
-import (
- "github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/constant"
- "one-api/model"
- "one-api/setting/operation_setting"
- "time"
-)
-
-type Setup struct {
- Status bool `json:"status"`
- RootInit bool `json:"root_init"`
- DatabaseType string `json:"database_type"`
-}
-
-type SetupRequest struct {
- Username string `json:"username"`
- Password string `json:"password"`
- ConfirmPassword string `json:"confirmPassword"`
- SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
- DemoSiteEnabled bool `json:"DemoSiteEnabled"`
-}
-
-func GetSetup(c *gin.Context) {
- setup := Setup{
- Status: constant.Setup,
- }
- if constant.Setup {
- c.JSON(200, gin.H{
- "success": true,
- "data": setup,
- })
- return
- }
- setup.RootInit = model.RootUserExists()
- if common.UsingMySQL {
- setup.DatabaseType = "mysql"
- }
- if common.UsingPostgreSQL {
- setup.DatabaseType = "postgres"
- }
- if common.UsingSQLite {
- setup.DatabaseType = "sqlite"
- }
- c.JSON(200, gin.H{
- "success": true,
- "data": setup,
- })
-}
-
-func PostSetup(c *gin.Context) {
- // Check if setup is already completed
- if constant.Setup {
- c.JSON(200, gin.H{
- "success": false,
- "message": "系统已经初始化完成",
- })
- return
- }
-
- // Check if root user already exists
- rootExists := model.RootUserExists()
-
- var req SetupRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "请求参数有误",
- })
- return
- }
-
- // If root doesn't exist, validate and create admin account
- if !rootExists {
- // Validate username length: max 12 characters to align with model.User validation
- if len(req.Username) > 12 {
- c.JSON(200, gin.H{
- "success": false,
- "message": "用户名长度不能超过12个字符",
- })
- return
- }
- // Validate password
- if req.Password != req.ConfirmPassword {
- c.JSON(200, gin.H{
- "success": false,
- "message": "两次输入的密码不一致",
- })
- return
- }
-
- if len(req.Password) < 8 {
- c.JSON(200, gin.H{
- "success": false,
- "message": "密码长度至少为8个字符",
- })
- return
- }
-
- // Create root user
- hashedPassword, err := common.Password2Hash(req.Password)
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "系统错误: " + err.Error(),
- })
- return
- }
- rootUser := model.User{
- Username: req.Username,
- Password: hashedPassword,
- Role: common.RoleRootUser,
- Status: common.UserStatusEnabled,
- DisplayName: "Root User",
- AccessToken: nil,
- Quota: 100000000,
- }
- err = model.DB.Create(&rootUser).Error
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "创建管理员账号失败: " + err.Error(),
- })
- return
- }
- }
-
- // Set operation modes
- operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
- operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
-
- // Save operation modes to database for persistence
- err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "保存自用模式设置失败: " + err.Error(),
- })
- return
- }
-
- err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "保存演示站点模式设置失败: " + err.Error(),
- })
- return
- }
-
- // Update setup status
- constant.Setup = true
-
- setup := model.Setup{
- Version: common.Version,
- InitializedAt: time.Now().Unix(),
- }
- err = model.DB.Create(&setup).Error
- if err != nil {
- c.JSON(200, gin.H{
- "success": false,
- "message": "系统初始化失败: " + err.Error(),
- })
- return
- }
-
- c.JSON(200, gin.H{
- "success": true,
- "message": "系统初始化成功",
- })
-}
-
-func boolToString(b bool) string {
- if b {
- return "true"
- }
- return "false"
-}
\ No newline at end of file
diff --git a/new-api/controller/swag_video.go b/new-api/controller/swag_video.go
deleted file mode 100644
index 383ddaad0ef23dead5f3f6e13d924fdb104ee2ee..0000000000000000000000000000000000000000
--- a/new-api/controller/swag_video.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package controller
-
-import (
- "github.com/gin-gonic/gin"
-)
-
-// VideoGenerations
-// @Summary 生成视频
-// @Description 调用视频生成接口生成视频
-// @Description 支持多种视频生成服务:
-// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
-// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
-// @Tags Video
-// @Accept json
-// @Produce json
-// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
-// @Param request body dto.VideoRequest true "视频生成请求参数"
-// @Failure 400 {object} dto.OpenAIError "请求参数错误"
-// @Failure 401 {object} dto.OpenAIError "未授权"
-// @Failure 403 {object} dto.OpenAIError "无权限"
-// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
-// @Router /v1/video/generations [post]
-func VideoGenerations(c *gin.Context) {
-}
-
-// VideoGenerationsTaskId
-// @Summary 查询视频
-// @Description 根据任务ID查询视频生成任务的状态和结果
-// @Tags Video
-// @Accept json
-// @Produce json
-// @Security BearerAuth
-// @Param task_id path string true "Task ID"
-// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
-// @Failure 400 {object} dto.OpenAIError "请求参数错误"
-// @Failure 401 {object} dto.OpenAIError "未授权"
-// @Failure 403 {object} dto.OpenAIError "无权限"
-// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
-// @Router /v1/video/generations/{task_id} [get]
-func VideoGenerationsTaskId(c *gin.Context) {
-}
-
-// KlingText2VideoGenerations
-// @Summary 可灵文生视频
-// @Description 调用可灵AI文生视频接口,生成视频内容
-// @Tags Video
-// @Accept json
-// @Produce json
-// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
-// @Param request body KlingText2VideoRequest true "视频生成请求参数"
-// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
-// @Failure 400 {object} dto.OpenAIError "请求参数错误"
-// @Failure 401 {object} dto.OpenAIError "未授权"
-// @Failure 403 {object} dto.OpenAIError "无权限"
-// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
-// @Router /kling/v1/videos/text2video [post]
-func KlingText2VideoGenerations(c *gin.Context) {
-}
-
-type KlingText2VideoRequest struct {
- ModelName string `json:"model_name,omitempty" example:"kling-v1"`
- Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
- NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
- CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
- Mode string `json:"mode,omitempty" example:"std"`
- CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
- AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
- Duration string `json:"duration,omitempty" example:"5"`
- CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
- ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
-}
-
-type KlingCameraControl struct {
- Type string `json:"type,omitempty" example:"simple"`
- Config *KlingCameraConfig `json:"config,omitempty"`
-}
-
-type KlingCameraConfig struct {
- Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
- Vertical float64 `json:"vertical,omitempty" example:"0"`
- Pan float64 `json:"pan,omitempty" example:"0"`
- Tilt float64 `json:"tilt,omitempty" example:"0"`
- Roll float64 `json:"roll,omitempty" example:"0"`
- Zoom float64 `json:"zoom,omitempty" example:"0"`
-}
-
-// KlingImage2VideoGenerations
-// @Summary 可灵官方-图生视频
-// @Description 调用可灵AI图生视频接口,生成视频内容
-// @Tags Video
-// @Accept json
-// @Produce json
-// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
-// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
-// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
-// @Failure 400 {object} dto.OpenAIError "请求参数错误"
-// @Failure 401 {object} dto.OpenAIError "未授权"
-// @Failure 403 {object} dto.OpenAIError "无权限"
-// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
-// @Router /kling/v1/videos/image2video [post]
-func KlingImage2VideoGenerations(c *gin.Context) {
-}
-
-type KlingImage2VideoRequest struct {
- ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
- Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
- Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
- NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
- CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
- Mode string `json:"mode,omitempty" example:"std"`
- CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
- AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
- Duration string `json:"duration,omitempty" example:"5"`
- CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
- ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
-}
-
-// KlingImage2videoTaskId godoc
-// @Summary 可灵任务查询--图生视频
-// @Description Query the status and result of a Kling video generation task by task ID
-// @Tags Origin
-// @Accept json
-// @Produce json
-// @Param task_id path string true "Task ID"
-// @Router /kling/v1/videos/image2video/{task_id} [get]
-func KlingImage2videoTaskId(c *gin.Context) {}
-
-// KlingText2videoTaskId godoc
-// @Summary 可灵任务查询--文生视频
-// @Description Query the status and result of a Kling text-to-video generation task by task ID
-// @Tags Origin
-// @Accept json
-// @Produce json
-// @Param task_id path string true "Task ID"
-// @Router /kling/v1/videos/text2video/{task_id} [get]
-func KlingText2videoTaskId(c *gin.Context) {}
diff --git a/new-api/controller/task.go b/new-api/controller/task.go
deleted file mode 100644
index 3ce397860bda0ada889a1a8ea764715351dbde2d..0000000000000000000000000000000000000000
--- a/new-api/controller/task.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- "one-api/relay"
- "sort"
- "strconv"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/samber/lo"
-)
-
-func UpdateTaskBulk() {
- //revocer
- //imageModel := "midjourney"
- for {
- time.Sleep(time.Duration(15) * time.Second)
- common.SysLog("任务进度轮询开始")
- ctx := context.TODO()
- allTasks := model.GetAllUnFinishSyncTasks(500)
- platformTask := make(map[constant.TaskPlatform][]*model.Task)
- for _, t := range allTasks {
- platformTask[t.Platform] = append(platformTask[t.Platform], t)
- }
- for platform, tasks := range platformTask {
- if len(tasks) == 0 {
- continue
- }
- taskChannelM := make(map[int][]string)
- taskM := make(map[string]*model.Task)
- nullTaskIds := make([]int64, 0)
- for _, task := range tasks {
- if task.TaskID == "" {
- // 统计失败的未完成任务
- nullTaskIds = append(nullTaskIds, task.ID)
- continue
- }
- taskM[task.TaskID] = task
- taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
- }
- if len(nullTaskIds) > 0 {
- err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
- } else {
- logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
- }
- }
- if len(taskChannelM) == 0 {
- continue
- }
-
- UpdateTaskByPlatform(platform, taskChannelM, taskM)
- }
- common.SysLog("任务进度轮询完成")
- }
-}
-
-func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
- switch platform {
- case constant.TaskPlatformMidjourney:
- //_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
- case constant.TaskPlatformSuno:
- _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
- default:
- if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
- }
- }
-}
-
-func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
- }
- }
- return nil
-}
-
-func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- channel, err := model.CacheGetChannel(channelId)
- if err != nil {
- common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
- err = model.TaskBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if err != nil {
- common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
- }
- return err
- }
- adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
- if adaptor == nil {
- return errors.New("adaptor not found")
- }
- resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
- "ids": taskIds,
- })
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
- return err
- }
- if resp.StatusCode != http.StatusOK {
- logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
- }
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
- return err
- }
- var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
- err = json.Unmarshal(responseBody, &responseItems)
- if err != nil {
- logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
- return err
- }
- if !responseItems.IsSuccess() {
- common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
- return err
- }
-
- for _, responseItem := range responseItems.Data {
- task := taskM[responseItem.TaskID]
- if !checkTaskNeedUpdate(task, responseItem) {
- continue
- }
-
- task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
- task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
- task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
- task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
- task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
- if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
- logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
- task.Progress = "100%"
- //err = model.CacheUpdateUserQuota(task.UserId) ?
- if err != nil {
- logger.LogError(ctx, "error update user quota cache: "+err.Error())
- } else {
- quota := task.Quota
- if quota != 0 {
- err = model.IncreaseUserQuota(task.UserId, quota, false)
- if err != nil {
- logger.LogError(ctx, "fail to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- }
- }
- if responseItem.Status == model.TaskStatusSuccess {
- task.Progress = "100%"
- }
- task.Data = responseItem.Data
-
- err = task.Update()
- if err != nil {
- common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
- }
- }
- return nil
-}
-
-func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
-
- if oldTask.SubmitTime != newTask.SubmitTime {
- return true
- }
- if oldTask.StartTime != newTask.StartTime {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
- if string(oldTask.Status) != newTask.Status {
- return true
- }
- if oldTask.FailReason != newTask.FailReason {
- return true
- }
- if oldTask.FinishTime != newTask.FinishTime {
- return true
- }
-
- if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
- return true
- }
-
- oldData, _ := json.Marshal(oldTask.Data)
- newData, _ := json.Marshal(newTask.Data)
-
- sort.Slice(oldData, func(i, j int) bool {
- return oldData[i] < oldData[j]
- })
- sort.Slice(newData, func(i, j int) bool {
- return newData[i] < newData[j]
- })
-
- if string(oldData) != string(newData) {
- return true
- }
- return false
-}
-
-func GetAllTask(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
-
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- // 解析其他查询参数
- queryParams := model.SyncTaskQueryParams{
- Platform: constant.TaskPlatform(c.Query("platform")),
- TaskID: c.Query("task_id"),
- Status: c.Query("status"),
- Action: c.Query("action"),
- StartTimestamp: startTimestamp,
- EndTimestamp: endTimestamp,
- ChannelID: c.Query("channel_id"),
- }
-
- items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.TaskCountAllTasks(queryParams)
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
-}
-
-func GetUserTask(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
-
- userId := c.GetInt("id")
-
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
-
- queryParams := model.SyncTaskQueryParams{
- Platform: constant.TaskPlatform(c.Query("platform")),
- TaskID: c.Query("task_id"),
- Status: c.Query("status"),
- Action: c.Query("action"),
- StartTimestamp: startTimestamp,
- EndTimestamp: endTimestamp,
- }
-
- items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
- total := model.TaskCountAllUserTask(userId, queryParams)
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(items)
- common.ApiSuccess(c, pageInfo)
-}
diff --git a/new-api/controller/task_video.go b/new-api/controller/task_video.go
deleted file mode 100644
index f75bae7dc1a9c7990e5081ad6076330e91578079..0000000000000000000000000000000000000000
--- a/new-api/controller/task_video.go
+++ /dev/null
@@ -1,184 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- "one-api/relay"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "time"
-)
-
-func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
- for channelId, taskIds := range taskChannelM {
- if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
- }
- }
- return nil
-}
-
-func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
- if len(taskIds) == 0 {
- return nil
- }
- cacheGetChannel, err := model.CacheGetChannel(channelId)
- if err != nil {
- errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
- "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
- "status": "FAILURE",
- "progress": "100%",
- })
- if errUpdate != nil {
- common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
- }
- return fmt.Errorf("CacheGetChannel failed: %w", err)
- }
- adaptor := relay.GetTaskAdaptor(platform)
- if adaptor == nil {
- return fmt.Errorf("video adaptor not found")
- }
- for _, taskId := range taskIds {
- if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
- logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
- }
- }
- return nil
-}
-
-func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
- baseURL := constant.ChannelBaseURLs[channel.Type]
- if channel.GetBaseURL() != "" {
- baseURL = channel.GetBaseURL()
- }
-
- task := taskM[taskId]
- if task == nil {
- logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
- return fmt.Errorf("task %s not found", taskId)
- }
- resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
- "task_id": taskId,
- "action": task.Action,
- })
- if err != nil {
- return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
- }
- //if resp.StatusCode != http.StatusOK {
- //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
- //}
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
- }
-
- taskResult := &relaycommon.TaskInfo{}
- // try parse as New API response format
- var responseItems dto.TaskResponse[model.Task]
- if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
- t := responseItems.Data
- taskResult.TaskID = t.TaskID
- taskResult.Status = string(t.Status)
- taskResult.Url = t.FailReason
- taskResult.Progress = t.Progress
- taskResult.Reason = t.FailReason
- } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
- return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
- } else {
- task.Data = redactVideoResponseBody(responseBody)
- }
-
- now := time.Now().Unix()
- if taskResult.Status == "" {
- return fmt.Errorf("task %s status is empty", taskId)
- }
- task.Status = model.TaskStatus(taskResult.Status)
- switch taskResult.Status {
- case model.TaskStatusSubmitted:
- task.Progress = "10%"
- case model.TaskStatusQueued:
- task.Progress = "20%"
- case model.TaskStatusInProgress:
- task.Progress = "30%"
- if task.StartTime == 0 {
- task.StartTime = now
- }
- case model.TaskStatusSuccess:
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
- task.FailReason = taskResult.Url
- }
- case model.TaskStatusFailure:
- task.Status = model.TaskStatusFailure
- task.Progress = "100%"
- if task.FinishTime == 0 {
- task.FinishTime = now
- }
- task.FailReason = taskResult.Reason
- logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
- quota := task.Quota
- if quota != 0 {
- if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
- logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
- }
- logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
- model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
- }
- default:
- return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
- }
- if taskResult.Progress != "" {
- task.Progress = taskResult.Progress
- }
- if err := task.Update(); err != nil {
- common.SysLog("UpdateVideoTask task error: " + err.Error())
- }
-
- return nil
-}
-
-func redactVideoResponseBody(body []byte) []byte {
- var m map[string]any
- if err := json.Unmarshal(body, &m); err != nil {
- return body
- }
- resp, _ := m["response"].(map[string]any)
- if resp != nil {
- delete(resp, "bytesBase64Encoded")
- if v, ok := resp["video"].(string); ok {
- resp["video"] = truncateBase64(v)
- }
- if vs, ok := resp["videos"].([]any); ok {
- for i := range vs {
- if vm, ok := vs[i].(map[string]any); ok {
- delete(vm, "bytesBase64Encoded")
- }
- }
- }
- }
- b, err := json.Marshal(m)
- if err != nil {
- return body
- }
- return b
-}
-
-func truncateBase64(s string) string {
- const maxKeep = 256
- if len(s) <= maxKeep {
- return s
- }
- return s[:maxKeep] + "..."
-}
diff --git a/new-api/controller/telegram.go b/new-api/controller/telegram.go
deleted file mode 100644
index e4286608438517cbf11ff5ae5f9a5c8a42662c62..0000000000000000000000000000000000000000
--- a/new-api/controller/telegram.go
+++ /dev/null
@@ -1,124 +0,0 @@
-package controller
-
-import (
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "io"
- "net/http"
- "one-api/common"
- "one-api/model"
- "sort"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-func TelegramBind(c *gin.Context) {
- if !common.TelegramOAuthEnabled {
- c.JSON(200, gin.H{
- "message": "管理员未开启通过 Telegram 登录以及注册",
- "success": false,
- })
- return
- }
- params := c.Request.URL.Query()
- if !checkTelegramAuthorization(params, common.TelegramBotToken) {
- c.JSON(200, gin.H{
- "message": "无效的请求",
- "success": false,
- })
- return
- }
- telegramId := params["id"][0]
- if model.IsTelegramIdAlreadyTaken(telegramId) {
- c.JSON(200, gin.H{
- "message": "该 Telegram 账户已被绑定",
- "success": false,
- })
- return
- }
-
- session := sessions.Default(c)
- id := session.Get("id")
- user := model.User{Id: id.(int)}
- if err := user.FillUserById(); err != nil {
- c.JSON(200, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
- if user.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已注销",
- })
- return
- }
- user.TelegramId = telegramId
- if err := user.Update(false); err != nil {
- c.JSON(200, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
-
- c.Redirect(302, "/console/personal")
-}
-
-func TelegramLogin(c *gin.Context) {
- if !common.TelegramOAuthEnabled {
- c.JSON(200, gin.H{
- "message": "管理员未开启通过 Telegram 登录以及注册",
- "success": false,
- })
- return
- }
- params := c.Request.URL.Query()
- if !checkTelegramAuthorization(params, common.TelegramBotToken) {
- c.JSON(200, gin.H{
- "message": "无效的请求",
- "success": false,
- })
- return
- }
-
- telegramId := params["id"][0]
- user := model.User{TelegramId: telegramId}
- if err := user.FillUserByTelegramId(); err != nil {
- c.JSON(200, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
- setupLogin(&user, c)
-}
-
-func checkTelegramAuthorization(params map[string][]string, token string) bool {
- strs := []string{}
- var hash = ""
- for k, v := range params {
- if k == "hash" {
- hash = v[0]
- continue
- }
- strs = append(strs, k+"="+v[0])
- }
- sort.Strings(strs)
- var imploded = ""
- for _, s := range strs {
- if imploded != "" {
- imploded += "\n"
- }
- imploded += s
- }
- sha256hash := sha256.New()
- io.WriteString(sha256hash, token)
- hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
- io.WriteString(hmachash, imploded)
- ss := hex.EncodeToString(hmachash.Sum(nil))
- return hash == ss
-}
diff --git a/new-api/controller/token.go b/new-api/controller/token.go
deleted file mode 100644
index 37796b655c53c32cb105a586ff0c581e8de2e03a..0000000000000000000000000000000000000000
--- a/new-api/controller/token.go
+++ /dev/null
@@ -1,288 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetAllTokens(c *gin.Context) {
- userId := c.GetInt("id")
- pageInfo := common.GetPageQuery(c)
- tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- total, _ := model.CountUserTokens(userId)
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(tokens)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func SearchTokens(c *gin.Context) {
- userId := c.GetInt("id")
- keyword := c.Query("keyword")
- token := c.Query("token")
- tokens, err := model.SearchUserTokens(userId, keyword, token)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": tokens,
- })
- return
-}
-
-func GetToken(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- userId := c.GetInt("id")
- if err != nil {
- common.ApiError(c, err)
- return
- }
- token, err := model.GetTokenByIds(id, userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": token,
- })
- return
-}
-
-func GetTokenStatus(c *gin.Context) {
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- token, err := model.GetTokenByIds(tokenId, userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- expiredAt := token.ExpiredTime
- if expiredAt == -1 {
- expiredAt = 0
- }
- c.JSON(http.StatusOK, gin.H{
- "object": "credit_summary",
- "total_granted": token.RemainQuota,
- "total_used": 0, // not supported currently
- "total_available": token.RemainQuota,
- "expires_at": expiredAt * 1000,
- })
-}
-
-func GetTokenUsage(c *gin.Context) {
- authHeader := c.GetHeader("Authorization")
- if authHeader == "" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "No Authorization header",
- })
- return
- }
-
- parts := strings.Split(authHeader, " ")
- if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "Invalid Bearer token",
- })
- return
- }
- tokenKey := parts[1]
-
- token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- expiredAt := token.ExpiredTime
- if expiredAt == -1 {
- expiredAt = 0
- }
-
- c.JSON(http.StatusOK, gin.H{
- "code": true,
- "message": "ok",
- "data": gin.H{
- "object": "token_usage",
- "name": token.Name,
- "total_granted": token.RemainQuota + token.UsedQuota,
- "total_used": token.UsedQuota,
- "total_available": token.RemainQuota,
- "unlimited_quota": token.UnlimitedQuota,
- "model_limits": token.GetModelLimitsMap(),
- "model_limits_enabled": token.ModelLimitsEnabled,
- "expires_at": expiredAt,
- },
- })
-}
-
-func AddToken(c *gin.Context) {
- token := model.Token{}
- err := c.ShouldBindJSON(&token)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if len(token.Name) > 30 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "令牌名称过长",
- })
- return
- }
- key, err := common.GenerateKey()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成令牌失败",
- })
- common.SysLog("failed to generate token key: " + err.Error())
- return
- }
- cleanToken := model.Token{
- UserId: c.GetInt("id"),
- Name: token.Name,
- Key: key,
- CreatedTime: common.GetTimestamp(),
- AccessedTime: common.GetTimestamp(),
- ExpiredTime: token.ExpiredTime,
- RemainQuota: token.RemainQuota,
- UnlimitedQuota: token.UnlimitedQuota,
- ModelLimitsEnabled: token.ModelLimitsEnabled,
- ModelLimits: token.ModelLimits,
- AllowIps: token.AllowIps,
- Group: token.Group,
- }
- err = cleanToken.Insert()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func DeleteToken(c *gin.Context) {
- id, _ := strconv.Atoi(c.Param("id"))
- userId := c.GetInt("id")
- err := model.DeleteTokenById(id, userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func UpdateToken(c *gin.Context) {
- userId := c.GetInt("id")
- statusOnly := c.Query("status_only")
- token := model.Token{}
- err := c.ShouldBindJSON(&token)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if len(token.Name) > 30 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "令牌名称过长",
- })
- return
- }
- cleanToken, err := model.GetTokenByIds(token.Id, userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if token.Status == common.TokenStatusEnabled {
- if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
- })
- return
- }
- if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
- })
- return
- }
- }
- if statusOnly != "" {
- cleanToken.Status = token.Status
- } else {
- // If you add more fields, please also update token.Update()
- cleanToken.Name = token.Name
- cleanToken.ExpiredTime = token.ExpiredTime
- cleanToken.RemainQuota = token.RemainQuota
- cleanToken.UnlimitedQuota = token.UnlimitedQuota
- cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
- cleanToken.ModelLimits = token.ModelLimits
- cleanToken.AllowIps = token.AllowIps
- cleanToken.Group = token.Group
- }
- err = cleanToken.Update()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": cleanToken,
- })
- return
-}
-
-type TokenBatch struct {
- Ids []int `json:"ids"`
-}
-
-func DeleteTokenBatch(c *gin.Context) {
- tokenBatch := TokenBatch{}
- if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
- userId := c.GetInt("id")
- count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": count,
- })
-}
diff --git a/new-api/controller/topup.go b/new-api/controller/topup.go
deleted file mode 100644
index dd073f5429857c6dc0837a52768d0e5e5ba5a680..0000000000000000000000000000000000000000
--- a/new-api/controller/topup.go
+++ /dev/null
@@ -1,314 +0,0 @@
-package controller
-
-import (
- "fmt"
- "log"
- "net/url"
- "one-api/common"
- "one-api/logger"
- "one-api/model"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
- "one-api/setting/system_setting"
- "strconv"
- "sync"
- "time"
-
- "github.com/Calcium-Ion/go-epay/epay"
- "github.com/gin-gonic/gin"
- "github.com/samber/lo"
- "github.com/shopspring/decimal"
-)
-
-func GetTopUpInfo(c *gin.Context) {
- // 获取支付方式
- payMethods := operation_setting.PayMethods
-
- // 如果启用了 Stripe 支付,添加到支付方法列表
- if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
- // 检查是否已经包含 Stripe
- hasStripe := false
- for _, method := range payMethods {
- if method["type"] == "stripe" {
- hasStripe = true
- break
- }
- }
-
- if !hasStripe {
- stripeMethod := map[string]string{
- "name": "Stripe",
- "type": "stripe",
- "color": "rgba(var(--semi-purple-5), 1)",
- "min_topup": strconv.Itoa(setting.StripeMinTopUp),
- }
- payMethods = append(payMethods, stripeMethod)
- }
- }
-
- data := gin.H{
- "enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
- "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
- "pay_methods": payMethods,
- "min_topup": operation_setting.MinTopUp,
- "stripe_min_topup": setting.StripeMinTopUp,
- "amount_options": operation_setting.GetPaymentSetting().AmountOptions,
- "discount": operation_setting.GetPaymentSetting().AmountDiscount,
- }
- common.ApiSuccess(c, data)
-}
-
-type EpayRequest struct {
- Amount int64 `json:"amount"`
- PaymentMethod string `json:"payment_method"`
- TopUpCode string `json:"top_up_code"`
-}
-
-type AmountRequest struct {
- Amount int64 `json:"amount"`
- TopUpCode string `json:"top_up_code"`
-}
-
-func GetEpayClient() *epay.Client {
- if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
- return nil
- }
- withUrl, err := epay.NewClient(&epay.Config{
- PartnerID: operation_setting.EpayId,
- Key: operation_setting.EpayKey,
- }, operation_setting.PayAddress)
- if err != nil {
- return nil
- }
- return withUrl
-}
-
-func getPayMoney(amount int64, group string) float64 {
- dAmount := decimal.NewFromInt(amount)
-
- if !common.DisplayInCurrencyEnabled {
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
- dAmount = dAmount.Div(dQuotaPerUnit)
- }
-
- topupGroupRatio := common.GetTopupGroupRatio(group)
- if topupGroupRatio == 0 {
- topupGroupRatio = 1
- }
-
- dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
- dPrice := decimal.NewFromFloat(operation_setting.Price)
- // apply optional preset discount by the original request amount (if configured), default 1.0
- discount := 1.0
- if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
- if ds > 0 {
- discount = ds
- }
- }
- dDiscount := decimal.NewFromFloat(discount)
-
- payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
-
- return payMoney.InexactFloat64()
-}
-
-func getMinTopup() int64 {
- minTopup := operation_setting.MinTopUp
- if !common.DisplayInCurrencyEnabled {
- dMinTopup := decimal.NewFromInt(int64(minTopup))
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
- minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
- }
- return int64(minTopup)
-}
-
-func RequestEpay(c *gin.Context) {
- var req EpayRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
- return
- }
- if req.Amount < getMinTopup() {
- c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
- return
- }
-
- id := c.GetInt("id")
- group, err := model.GetUserGroup(id, true)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
- return
- }
- payMoney := getPayMoney(req.Amount, group)
- if payMoney < 0.01 {
- c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
- return
- }
-
- if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
- c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
- return
- }
-
- callBackAddress := service.GetCallbackAddress()
- returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
- notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
- tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
- tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
- client := GetEpayClient()
- if client == nil {
- c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
- return
- }
- uri, params, err := client.Purchase(&epay.PurchaseArgs{
- Type: req.PaymentMethod,
- ServiceTradeNo: tradeNo,
- Name: fmt.Sprintf("TUC%d", req.Amount),
- Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
- Device: epay.PC,
- NotifyUrl: notifyUrl,
- ReturnUrl: returnUrl,
- })
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
- return
- }
- amount := req.Amount
- if !common.DisplayInCurrencyEnabled {
- dAmount := decimal.NewFromInt(int64(amount))
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
- amount = dAmount.Div(dQuotaPerUnit).IntPart()
- }
- topUp := &model.TopUp{
- UserId: id,
- Amount: amount,
- Money: payMoney,
- TradeNo: tradeNo,
- CreateTime: time.Now().Unix(),
- Status: "pending",
- }
- err = topUp.Insert()
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
- return
- }
- c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
-}
-
-// tradeNo lock
-var orderLocks sync.Map
-var createLock sync.Mutex
-
-// LockOrder 尝试对给定订单号加锁
-func LockOrder(tradeNo string) {
- lock, ok := orderLocks.Load(tradeNo)
- if !ok {
- createLock.Lock()
- defer createLock.Unlock()
- lock, ok = orderLocks.Load(tradeNo)
- if !ok {
- lock = new(sync.Mutex)
- orderLocks.Store(tradeNo, lock)
- }
- }
- lock.(*sync.Mutex).Lock()
-}
-
-// UnlockOrder 释放给定订单号的锁
-func UnlockOrder(tradeNo string) {
- lock, ok := orderLocks.Load(tradeNo)
- if ok {
- lock.(*sync.Mutex).Unlock()
- }
-}
-
-func EpayNotify(c *gin.Context) {
- params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
- r[t] = c.Request.URL.Query().Get(t)
- return r
- }, map[string]string{})
- client := GetEpayClient()
- if client == nil {
- log.Println("易支付回调失败 未找到配置信息")
- _, err := c.Writer.Write([]byte("fail"))
- if err != nil {
- log.Println("易支付回调写入失败")
- return
- }
- }
- verifyInfo, err := client.Verify(params)
- if err == nil && verifyInfo.VerifyStatus {
- _, err := c.Writer.Write([]byte("success"))
- if err != nil {
- log.Println("易支付回调写入失败")
- }
- } else {
- _, err := c.Writer.Write([]byte("fail"))
- if err != nil {
- log.Println("易支付回调写入失败")
- }
- log.Println("易支付回调签名验证失败")
- return
- }
-
- if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
- log.Println(verifyInfo)
- LockOrder(verifyInfo.ServiceTradeNo)
- defer UnlockOrder(verifyInfo.ServiceTradeNo)
- topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
- if topUp == nil {
- log.Printf("易支付回调未找到订单: %v", verifyInfo)
- return
- }
- if topUp.Status == "pending" {
- topUp.Status = "success"
- err := topUp.Update()
- if err != nil {
- log.Printf("易支付回调更新订单失败: %v", topUp)
- return
- }
- //user, _ := model.GetUserById(topUp.UserId, false)
- //user.Quota += topUp.Amount * 500000
- dAmount := decimal.NewFromInt(int64(topUp.Amount))
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
- quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
- err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
- if err != nil {
- log.Printf("易支付回调更新用户失败: %v", topUp)
- return
- }
- log.Printf("易支付回调更新用户成功 %v", topUp)
- model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
- }
- } else {
- log.Printf("易支付异常回调: %v", verifyInfo)
- }
-}
-
-func RequestAmount(c *gin.Context) {
- var req AmountRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
- return
- }
-
- if req.Amount < getMinTopup() {
- c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
- return
- }
- id := c.GetInt("id")
- group, err := model.GetUserGroup(id, true)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
- return
- }
- payMoney := getPayMoney(req.Amount, group)
- if payMoney <= 0.01 {
- c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
- return
- }
- c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
-}
diff --git a/new-api/controller/topup_stripe.go b/new-api/controller/topup_stripe.go
deleted file mode 100644
index db6cc29942462f404d959891ba943f4708b3343f..0000000000000000000000000000000000000000
--- a/new-api/controller/topup_stripe.go
+++ /dev/null
@@ -1,286 +0,0 @@
-package controller
-
-import (
- "fmt"
- "io"
- "log"
- "net/http"
- "one-api/common"
- "one-api/model"
- "one-api/setting"
- "one-api/setting/operation_setting"
- "one-api/setting/system_setting"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/stripe/stripe-go/v81"
- "github.com/stripe/stripe-go/v81/checkout/session"
- "github.com/stripe/stripe-go/v81/webhook"
- "github.com/thanhpk/randstr"
-)
-
-const (
- PaymentMethodStripe = "stripe"
-)
-
-var stripeAdaptor = &StripeAdaptor{}
-
-type StripePayRequest struct {
- Amount int64 `json:"amount"`
- PaymentMethod string `json:"payment_method"`
-}
-
-type StripeAdaptor struct {
-}
-
-func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
- if req.Amount < getStripeMinTopup() {
- c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
- return
- }
- id := c.GetInt("id")
- group, err := model.GetUserGroup(id, true)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
- return
- }
- payMoney := getStripePayMoney(float64(req.Amount), group)
- if payMoney <= 0.01 {
- c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
- return
- }
- c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
-}
-
-func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
- if req.PaymentMethod != PaymentMethodStripe {
- c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
- return
- }
- if req.Amount < getStripeMinTopup() {
- c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
- return
- }
- if req.Amount > 10000 {
- c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
- return
- }
-
- id := c.GetInt("id")
- user, _ := model.GetUserById(id, false)
- chargedMoney := GetChargedAmount(float64(req.Amount), *user)
-
- reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
- referenceId := "ref_" + common.Sha1([]byte(reference))
-
- payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
- if err != nil {
- log.Println("获取Stripe Checkout支付链接失败", err)
- c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
- return
- }
-
- topUp := &model.TopUp{
- UserId: id,
- Amount: req.Amount,
- Money: chargedMoney,
- TradeNo: referenceId,
- CreateTime: time.Now().Unix(),
- Status: common.TopUpStatusPending,
- }
- err = topUp.Insert()
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
- return
- }
- c.JSON(200, gin.H{
- "message": "success",
- "data": gin.H{
- "pay_link": payLink,
- },
- })
-}
-
-func RequestStripeAmount(c *gin.Context) {
- var req StripePayRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
- return
- }
- stripeAdaptor.RequestAmount(c, &req)
-}
-
-func RequestStripePay(c *gin.Context) {
- var req StripePayRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
- return
- }
- stripeAdaptor.RequestPay(c, &req)
-}
-
-func StripeWebhook(c *gin.Context) {
- payload, err := io.ReadAll(c.Request.Body)
- if err != nil {
- log.Printf("解析Stripe Webhook参数失败: %v\n", err)
- c.AbortWithStatus(http.StatusServiceUnavailable)
- return
- }
-
- signature := c.GetHeader("Stripe-Signature")
- endpointSecret := setting.StripeWebhookSecret
- event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
- IgnoreAPIVersionMismatch: true,
- })
-
- if err != nil {
- log.Printf("Stripe Webhook验签失败: %v\n", err)
- c.AbortWithStatus(http.StatusBadRequest)
- return
- }
-
- switch event.Type {
- case stripe.EventTypeCheckoutSessionCompleted:
- sessionCompleted(event)
- case stripe.EventTypeCheckoutSessionExpired:
- sessionExpired(event)
- default:
- log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
- }
-
- c.Status(http.StatusOK)
-}
-
-func sessionCompleted(event stripe.Event) {
- customerId := event.GetObjectValue("customer")
- referenceId := event.GetObjectValue("client_reference_id")
- status := event.GetObjectValue("status")
- if "complete" != status {
- log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
- return
- }
-
- err := model.Recharge(referenceId, customerId)
- if err != nil {
- log.Println(err.Error(), referenceId)
- return
- }
-
- total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
- currency := strings.ToUpper(event.GetObjectValue("currency"))
- log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
-}
-
-func sessionExpired(event stripe.Event) {
- referenceId := event.GetObjectValue("client_reference_id")
- status := event.GetObjectValue("status")
- if "expired" != status {
- log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
- return
- }
-
- if len(referenceId) == 0 {
- log.Println("未提供支付单号")
- return
- }
-
- topUp := model.GetTopUpByTradeNo(referenceId)
- if topUp == nil {
- log.Println("充值订单不存在", referenceId)
- return
- }
-
- if topUp.Status != common.TopUpStatusPending {
- log.Println("充值订单状态错误", referenceId)
- }
-
- topUp.Status = common.TopUpStatusExpired
- err := topUp.Update()
- if err != nil {
- log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
- return
- }
-
- log.Println("充值订单已过期", referenceId)
-}
-
-func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
- if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
- return "", fmt.Errorf("无效的Stripe API密钥")
- }
-
- stripe.Key = setting.StripeApiSecret
-
- params := &stripe.CheckoutSessionParams{
- ClientReferenceID: stripe.String(referenceId),
- SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
- CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
- LineItems: []*stripe.CheckoutSessionLineItemParams{
- {
- Price: stripe.String(setting.StripePriceId),
- Quantity: stripe.Int64(amount),
- },
- },
- Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
- AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
- }
-
- if "" == customerId {
- if "" != email {
- params.CustomerEmail = stripe.String(email)
- }
-
- params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
- } else {
- params.Customer = stripe.String(customerId)
- }
-
- result, err := session.New(params)
- if err != nil {
- return "", err
- }
-
- return result.URL, nil
-}
-
-func GetChargedAmount(count float64, user model.User) float64 {
- topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
- if topUpGroupRatio == 0 {
- topUpGroupRatio = 1
- }
-
- return count * topUpGroupRatio
-}
-
-func getStripePayMoney(amount float64, group string) float64 {
- originalAmount := amount
- if !common.DisplayInCurrencyEnabled {
- amount = amount / common.QuotaPerUnit
- }
- // Using float64 for monetary calculations is acceptable here due to the small amounts involved
- topupGroupRatio := common.GetTopupGroupRatio(group)
- if topupGroupRatio == 0 {
- topupGroupRatio = 1
- }
- // apply optional preset discount by the original request amount (if configured), default 1.0
- discount := 1.0
- if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
- if ds > 0 {
- discount = ds
- }
- }
- payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
- return payMoney
-}
-
-func getStripeMinTopup() int64 {
- minTopup := setting.StripeMinTopUp
- if !common.DisplayInCurrencyEnabled {
- minTopup = minTopup * int(common.QuotaPerUnit)
- }
- return int64(minTopup)
-}
diff --git a/new-api/controller/twofa.go b/new-api/controller/twofa.go
deleted file mode 100644
index bb6f2d42fd4dacbeb5eaec7c02969fff63f09aca..0000000000000000000000000000000000000000
--- a/new-api/controller/twofa.go
+++ /dev/null
@@ -1,553 +0,0 @@
-package controller
-
-import (
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-// Setup2FARequest 设置2FA请求结构
-type Setup2FARequest struct {
- Code string `json:"code" binding:"required"`
-}
-
-// Verify2FARequest 验证2FA请求结构
-type Verify2FARequest struct {
- Code string `json:"code" binding:"required"`
-}
-
-// Setup2FAResponse 设置2FA响应结构
-type Setup2FAResponse struct {
- Secret string `json:"secret"`
- QRCodeData string `json:"qr_code_data"`
- BackupCodes []string `json:"backup_codes"`
-}
-
-// Setup2FA 初始化2FA设置
-func Setup2FA(c *gin.Context) {
- userId := c.GetInt("id")
-
- // 检查用户是否已经启用2FA
- existing, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if existing != nil && existing.IsEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已启用2FA,请先禁用后重新设置",
- })
- return
- }
-
- // 如果存在已禁用的2FA记录,先删除它
- if existing != nil && !existing.IsEnabled {
- if err := existing.Delete(); err != nil {
- common.ApiError(c, err)
- return
- }
- existing = nil // 重置为nil,后续将创建新记录
- }
-
- // 获取用户信息
- user, err := model.GetUserById(userId, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 生成TOTP密钥
- key, err := common.GenerateTOTPSecret(user.Username)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成2FA密钥失败",
- })
- common.SysLog("生成TOTP密钥失败: " + err.Error())
- return
- }
-
- // 生成备用码
- backupCodes, err := common.GenerateBackupCodes()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成备用码失败",
- })
- common.SysLog("生成备用码失败: " + err.Error())
- return
- }
-
- // 生成二维码数据
- qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
-
- // 创建或更新2FA记录(暂未启用)
- twoFA := &model.TwoFA{
- UserId: userId,
- Secret: key.Secret(),
- IsEnabled: false,
- }
-
- if existing != nil {
- // 更新现有记录
- twoFA.Id = existing.Id
- err = twoFA.Update()
- } else {
- // 创建新记录
- err = twoFA.Create()
- }
-
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 创建备用码记录
- if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "保存备用码失败",
- })
- common.SysLog("保存备用码失败: " + err.Error())
- return
- }
-
- // 记录操作日志
- model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
- "data": Setup2FAResponse{
- Secret: key.Secret(),
- QRCodeData: qrCodeData,
- BackupCodes: backupCodes,
- },
- })
-}
-
-// Enable2FA 启用2FA
-func Enable2FA(c *gin.Context) {
- var req Setup2FARequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
-
- userId := c.GetInt("id")
-
- // 获取2FA记录
- twoFA, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if twoFA == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "请先完成2FA初始化设置",
- })
- return
- }
- if twoFA.IsEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "2FA已经启用",
- })
- return
- }
-
- // 验证TOTP验证码
- cleanCode, err := common.ValidateNumericCode(req.Code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码或备用码错误,请重试",
- })
- return
- }
-
- // 启用2FA
- if err := twoFA.Enable(); err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 记录操作日志
- model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "两步验证启用成功",
- })
-}
-
-// Disable2FA 禁用2FA
-func Disable2FA(c *gin.Context) {
- var req Verify2FARequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
-
- userId := c.GetInt("id")
-
- // 获取2FA记录
- twoFA, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if twoFA == nil || !twoFA.IsEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户未启用2FA",
- })
- return
- }
-
- // 验证TOTP验证码或备用码
- cleanCode, err := common.ValidateNumericCode(req.Code)
- isValidTOTP := false
- isValidBackup := false
-
- if err == nil {
- // 尝试验证TOTP
- isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
- }
-
- if !isValidTOTP {
- // 尝试验证备用码
- isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- }
-
- if !isValidTOTP && !isValidBackup {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码或备用码错误,请重试",
- })
- return
- }
-
- // 禁用2FA
- if err := model.DisableTwoFA(userId); err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 记录操作日志
- model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "两步验证已禁用",
- })
-}
-
-// Get2FAStatus 获取用户2FA状态
-func Get2FAStatus(c *gin.Context) {
- userId := c.GetInt("id")
-
- twoFA, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- status := map[string]interface{}{
- "enabled": false,
- "locked": false,
- }
-
- if twoFA != nil {
- status["enabled"] = twoFA.IsEnabled
- status["locked"] = twoFA.IsLocked()
- if twoFA.IsEnabled {
- // 获取剩余备用码数量
- backupCount, err := model.GetUnusedBackupCodeCount(userId)
- if err != nil {
- common.SysLog("获取备用码数量失败: " + err.Error())
- } else {
- status["backup_codes_remaining"] = backupCount
- }
- }
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": status,
- })
-}
-
-// RegenerateBackupCodes 重新生成备用码
-func RegenerateBackupCodes(c *gin.Context) {
- var req Verify2FARequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
-
- userId := c.GetInt("id")
-
- // 获取2FA记录
- twoFA, err := model.GetTwoFAByUserId(userId)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if twoFA == nil || !twoFA.IsEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户未启用2FA",
- })
- return
- }
-
- // 验证TOTP验证码
- cleanCode, err := common.ValidateNumericCode(req.Code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- if !valid {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码或备用码错误,请重试",
- })
- return
- }
-
- // 生成新的备用码
- backupCodes, err := common.GenerateBackupCodes()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成备用码失败",
- })
- common.SysLog("生成备用码失败: " + err.Error())
- return
- }
-
- // 保存新的备用码
- if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "保存备用码失败",
- })
- common.SysLog("保存备用码失败: " + err.Error())
- return
- }
-
- // 记录操作日志
- model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "备用码重新生成成功",
- "data": map[string]interface{}{
- "backup_codes": backupCodes,
- },
- })
-}
-
-// Verify2FALogin 登录时验证2FA
-func Verify2FALogin(c *gin.Context) {
- var req Verify2FARequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "参数错误",
- })
- return
- }
-
- // 从会话中获取pending用户信息
- session := sessions.Default(c)
- pendingUserId := session.Get("pending_user_id")
- if pendingUserId == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "会话已过期,请重新登录",
- })
- return
- }
- userId, ok := pendingUserId.(int)
- if !ok {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "会话数据无效,请重新登录",
- })
- return
- }
- // 获取用户信息
- user, err := model.GetUserById(userId, false)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户不存在",
- })
- return
- }
-
- // 获取2FA记录
- twoFA, err := model.GetTwoFAByUserId(user.Id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if twoFA == nil || !twoFA.IsEnabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户未启用2FA",
- })
- return
- }
-
- // 验证TOTP验证码或备用码
- cleanCode, err := common.ValidateNumericCode(req.Code)
- isValidTOTP := false
- isValidBackup := false
-
- if err == nil {
- // 尝试验证TOTP
- isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
- }
-
- if !isValidTOTP {
- // 尝试验证备用码
- isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- }
-
- if !isValidTOTP && !isValidBackup {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码或备用码错误,请重试",
- })
- return
- }
-
- // 2FA验证成功,清理pending会话信息并完成登录
- session.Delete("pending_username")
- session.Delete("pending_user_id")
- session.Save()
-
- setupLogin(user, c)
-}
-
-// Admin2FAStats 管理员获取2FA统计信息
-func Admin2FAStats(c *gin.Context) {
- stats, err := model.GetTwoFAStats()
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": stats,
- })
-}
-
-// AdminDisable2FA 管理员强制禁用用户2FA
-func AdminDisable2FA(c *gin.Context) {
- userIdStr := c.Param("id")
- userId, err := strconv.Atoi(userIdStr)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户ID格式错误",
- })
- return
- }
-
- // 检查目标用户权限
- targetUser, err := model.GetUserById(userId, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- myRole := c.GetInt("role")
- if myRole <= targetUser.Role && myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权操作同级或更高级用户的2FA设置",
- })
- return
- }
-
- // 禁用2FA
- if err := model.DisableTwoFA(userId); err != nil {
- if errors.Is(err, model.ErrTwoFANotEnabled) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户未启用2FA",
- })
- return
- }
- common.ApiError(c, err)
- return
- }
-
- // 记录操作日志
- adminId := c.GetInt("id")
- model.RecordLog(userId, model.LogTypeManage,
- fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "用户2FA已被强制禁用",
- })
-}
diff --git a/new-api/controller/uptime_kuma.go b/new-api/controller/uptime_kuma.go
deleted file mode 100644
index a8cd9037bf83fc1808724a9f4db135714173647b..0000000000000000000000000000000000000000
--- a/new-api/controller/uptime_kuma.go
+++ /dev/null
@@ -1,154 +0,0 @@
-package controller
-
-import (
- "context"
- "encoding/json"
- "errors"
- "net/http"
- "one-api/setting/console_setting"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
- "golang.org/x/sync/errgroup"
-)
-
-const (
- requestTimeout = 30 * time.Second
- httpTimeout = 10 * time.Second
- uptimeKeySuffix = "_24"
- apiStatusPath = "/api/status-page/"
- apiHeartbeatPath = "/api/status-page/heartbeat/"
-)
-
-type Monitor struct {
- Name string `json:"name"`
- Uptime float64 `json:"uptime"`
- Status int `json:"status"`
- Group string `json:"group,omitempty"`
-}
-
-type UptimeGroupResult struct {
- CategoryName string `json:"categoryName"`
- Monitors []Monitor `json:"monitors"`
-}
-
-func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
- if err != nil {
- return err
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- return errors.New("non-200 status")
- }
-
- return json.NewDecoder(resp.Body).Decode(dest)
-}
-
-func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
- url, _ := groupConfig["url"].(string)
- slug, _ := groupConfig["slug"].(string)
- categoryName, _ := groupConfig["categoryName"].(string)
-
- result := UptimeGroupResult{
- CategoryName: categoryName,
- Monitors: []Monitor{},
- }
-
- if url == "" || slug == "" {
- return result
- }
-
- baseURL := strings.TrimSuffix(url, "/")
-
- var statusData struct {
- PublicGroupList []struct {
- ID int `json:"id"`
- Name string `json:"name"`
- MonitorList []struct {
- ID int `json:"id"`
- Name string `json:"name"`
- } `json:"monitorList"`
- } `json:"publicGroupList"`
- }
-
- var heartbeatData struct {
- HeartbeatList map[string][]struct {
- Status int `json:"status"`
- } `json:"heartbeatList"`
- UptimeList map[string]float64 `json:"uptimeList"`
- }
-
- g, gCtx := errgroup.WithContext(ctx)
- g.Go(func() error {
- return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
- })
- g.Go(func() error {
- return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
- })
-
- if g.Wait() != nil {
- return result
- }
-
- for _, pg := range statusData.PublicGroupList {
- if len(pg.MonitorList) == 0 {
- continue
- }
-
- for _, m := range pg.MonitorList {
- monitor := Monitor{
- Name: m.Name,
- Group: pg.Name,
- }
-
- monitorID := strconv.Itoa(m.ID)
-
- if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
- monitor.Uptime = uptime
- }
-
- if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
- monitor.Status = heartbeats[0].Status
- }
-
- result.Monitors = append(result.Monitors, monitor)
- }
- }
-
- return result
-}
-
-func GetUptimeKumaStatus(c *gin.Context) {
- groups := console_setting.GetUptimeKumaGroups()
- if len(groups) == 0 {
- c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
- return
- }
-
- ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
- defer cancel()
-
- client := &http.Client{Timeout: httpTimeout}
- results := make([]UptimeGroupResult, len(groups))
-
- g, gCtx := errgroup.WithContext(ctx)
- for i, group := range groups {
- i, group := i, group
- g.Go(func() error {
- results[i] = fetchGroupData(gCtx, client, group)
- return nil
- })
- }
-
- g.Wait()
- c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
-}
diff --git a/new-api/controller/usedata.go b/new-api/controller/usedata.go
deleted file mode 100644
index 5489c8dfdb8720bd02c6efcb5f67b80adcc8bf1a..0000000000000000000000000000000000000000
--- a/new-api/controller/usedata.go
+++ /dev/null
@@ -1,52 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetAllQuotaDates(c *gin.Context) {
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- username := c.Query("username")
- dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": dates,
- })
- return
-}
-
-func GetUserQuotaDates(c *gin.Context) {
- userId := c.GetInt("id")
- startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
- endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
- // 判断时间跨度是否超过 1 个月
- if endTimestamp-startTimestamp > 2592000 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "时间跨度不能超过 1 个月",
- })
- return
- }
- dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": dates,
- })
- return
-}
diff --git a/new-api/controller/user.go b/new-api/controller/user.go
deleted file mode 100644
index 5c9100a0f98d9bfd0cee7f7502ead0a28c93992a..0000000000000000000000000000000000000000
--- a/new-api/controller/user.go
+++ /dev/null
@@ -1,1242 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "net/url"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- "one-api/setting"
- "strconv"
- "strings"
- "sync"
-
- "one-api/constant"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-type LoginRequest struct {
- Username string `json:"username"`
- Password string `json:"password"`
-}
-
-func Login(c *gin.Context) {
- if !common.PasswordLoginEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "管理员关闭了密码登录",
- "success": false,
- })
- return
- }
- var loginRequest LoginRequest
- err := json.NewDecoder(c.Request.Body).Decode(&loginRequest)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": "无效的参数",
- "success": false,
- })
- return
- }
- username := loginRequest.Username
- password := loginRequest.Password
- if username == "" || password == "" {
- c.JSON(http.StatusOK, gin.H{
- "message": "无效的参数",
- "success": false,
- })
- return
- }
- user := model.User{
- Username: username,
- Password: password,
- }
- err = user.ValidateAndFill()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
-
- // 检查是否启用2FA
- if model.IsTwoFAEnabled(user.Id) {
- // 设置pending session,等待2FA验证
- session := sessions.Default(c)
- session.Set("pending_username", user.Username)
- session.Set("pending_user_id", user.Id)
- err := session.Save()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": "无法保存会话信息,请重试",
- "success": false,
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "message": "请输入两步验证码",
- "success": true,
- "data": map[string]interface{}{
- "require_2fa": true,
- },
- })
- return
- }
-
- setupLogin(&user, c)
-}
-
-// setup session & cookies and then return user info
-func setupLogin(user *model.User, c *gin.Context) {
- session := sessions.Default(c)
- session.Set("id", user.Id)
- session.Set("username", user.Username)
- session.Set("role", user.Role)
- session.Set("status", user.Status)
- session.Set("group", user.Group)
- err := session.Save()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": "无法保存会话信息,请重试",
- "success": false,
- })
- return
- }
- cleanUser := model.User{
- Id: user.Id,
- Username: user.Username,
- DisplayName: user.DisplayName,
- Role: user.Role,
- Status: user.Status,
- Group: user.Group,
- }
- c.JSON(http.StatusOK, gin.H{
- "message": "",
- "success": true,
- "data": cleanUser,
- })
-}
-
-func Logout(c *gin.Context) {
- session := sessions.Default(c)
- session.Clear()
- err := session.Save()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "message": "",
- "success": true,
- })
-}
-
-func Register(c *gin.Context) {
- if !common.RegisterEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "管理员关闭了新用户注册",
- "success": false,
- })
- return
- }
- if !common.PasswordRegisterEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
- "success": false,
- })
- return
- }
- var user model.User
- err := json.NewDecoder(c.Request.Body).Decode(&user)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- if err := common.Validate.Struct(&user); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "输入不合法 " + err.Error(),
- })
- return
- }
- if common.EmailVerificationEnabled {
- if user.Email == "" || user.VerificationCode == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员开启了邮箱验证,请输入邮箱地址和验证码",
- })
- return
- }
- if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码错误或已过期",
- })
- return
- }
- }
- exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "数据库错误,请稍后重试",
- })
- common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
- return
- }
- if exist {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户名已存在,或已注销",
- })
- return
- }
- affCode := user.AffCode // this code is the inviter's code, not the user's own code
- inviterId, _ := model.GetUserIdByAffCode(affCode)
- cleanUser := model.User{
- Username: user.Username,
- Password: user.Password,
- DisplayName: user.Username,
- InviterId: inviterId,
- Role: common.RoleCommonUser, // 明确设置角色为普通用户
- }
- if common.EmailVerificationEnabled {
- cleanUser.Email = user.Email
- }
- if err := cleanUser.Insert(inviterId); err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 获取插入后的用户ID
- var insertedUser model.User
- if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户注册失败或用户ID获取失败",
- })
- return
- }
- // 生成默认令牌
- if constant.GenerateDefaultToken {
- key, err := common.GenerateKey()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成默认令牌失败",
- })
- common.SysLog("failed to generate token key: " + err.Error())
- return
- }
- // 生成默认令牌
- token := model.Token{
- UserId: insertedUser.Id, // 使用插入后的用户ID
- Name: cleanUser.Username + "的初始令牌",
- Key: key,
- CreatedTime: common.GetTimestamp(),
- AccessedTime: common.GetTimestamp(),
- ExpiredTime: -1, // 永不过期
- RemainQuota: 500000, // 示例额度
- UnlimitedQuota: true,
- ModelLimitsEnabled: false,
- }
- if setting.DefaultUseAutoGroup {
- token.Group = "auto"
- }
- if err := token.Insert(); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "创建默认令牌失败",
- })
- return
- }
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func GetAllUsers(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- users, total, err := model.GetAllUsers(pageInfo)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(users)
-
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func SearchUsers(c *gin.Context) {
- keyword := c.Query("keyword")
- group := c.Query("group")
- pageInfo := common.GetPageQuery(c)
- users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(users)
- common.ApiSuccess(c, pageInfo)
- return
-}
-
-func GetUser(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user, err := model.GetUserById(id, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权获取同级或更高等级用户的信息",
- })
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": user,
- })
- return
-}
-
-func GenerateAccessToken(c *gin.Context) {
- id := c.GetInt("id")
- user, err := model.GetUserById(id, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // get rand int 28-32
- randI := common.GetRandomInt(4)
- key, err := common.GenerateRandomKey(29 + randI)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "生成失败",
- })
- common.SysLog("failed to generate key: " + err.Error())
- return
- }
- user.SetAccessToken(key)
-
- if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "请重试,系统生成的 UUID 竟然重复了!",
- })
- return
- }
-
- if err := user.Update(false); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": user.AccessToken,
- })
- return
-}
-
-type TransferAffQuotaRequest struct {
- Quota int `json:"quota" binding:"required"`
-}
-
-func TransferAffQuota(c *gin.Context) {
- id := c.GetInt("id")
- user, err := model.GetUserById(id, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- tran := TransferAffQuotaRequest{}
- if err := c.ShouldBindJSON(&tran); err != nil {
- common.ApiError(c, err)
- return
- }
- err = user.TransferAffQuotaToQuota(tran.Quota)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "划转失败 " + err.Error(),
- })
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "划转成功",
- })
-}
-
-func GetAffCode(c *gin.Context) {
- id := c.GetInt("id")
- user, err := model.GetUserById(id, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if user.AffCode == "" {
- user.AffCode = common.GetRandomString(4)
- if err := user.Update(false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": user.AffCode,
- })
- return
-}
-
-func GetSelf(c *gin.Context) {
- id := c.GetInt("id")
- userRole := c.GetInt("role")
- user, err := model.GetUserById(id, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
- user.Remark = ""
-
- // 计算用户权限信息
- permissions := calculateUserPermissions(userRole)
-
- // 获取用户设置并提取sidebar_modules
- userSetting := user.GetSetting()
-
- // 构建响应数据,包含用户信息和权限
- responseData := map[string]interface{}{
- "id": user.Id,
- "username": user.Username,
- "display_name": user.DisplayName,
- "role": user.Role,
- "status": user.Status,
- "email": user.Email,
- "github_id": user.GitHubId,
- "oidc_id": user.OidcId,
- "wechat_id": user.WeChatId,
- "telegram_id": user.TelegramId,
- "group": user.Group,
- "quota": user.Quota,
- "used_quota": user.UsedQuota,
- "request_count": user.RequestCount,
- "aff_code": user.AffCode,
- "aff_count": user.AffCount,
- "aff_quota": user.AffQuota,
- "aff_history_quota": user.AffHistoryQuota,
- "inviter_id": user.InviterId,
- "linux_do_id": user.LinuxDOId,
- "setting": user.Setting,
- "stripe_customer": user.StripeCustomer,
- "sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
- "permissions": permissions, // 新增权限字段
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": responseData,
- })
- return
-}
-
-// 计算用户权限的辅助函数
-func calculateUserPermissions(userRole int) map[string]interface{} {
- permissions := map[string]interface{}{}
-
- // 根据用户角色计算权限
- if userRole == common.RoleRootUser {
- // 超级管理员不需要边栏设置功能
- permissions["sidebar_settings"] = false
- permissions["sidebar_modules"] = map[string]interface{}{}
- } else if userRole == common.RoleAdminUser {
- // 管理员可以设置边栏,但不包含系统设置功能
- permissions["sidebar_settings"] = true
- permissions["sidebar_modules"] = map[string]interface{}{
- "admin": map[string]interface{}{
- "setting": false, // 管理员不能访问系统设置
- },
- }
- } else {
- // 普通用户只能设置个人功能,不包含管理员区域
- permissions["sidebar_settings"] = true
- permissions["sidebar_modules"] = map[string]interface{}{
- "admin": false, // 普通用户不能访问管理员区域
- }
- }
-
- return permissions
-}
-
-// 根据用户角色生成默认的边栏配置
-func generateDefaultSidebarConfig(userRole int) string {
- defaultConfig := map[string]interface{}{}
-
- // 聊天区域 - 所有用户都可以访问
- defaultConfig["chat"] = map[string]interface{}{
- "enabled": true,
- "playground": true,
- "chat": true,
- }
-
- // 控制台区域 - 所有用户都可以访问
- defaultConfig["console"] = map[string]interface{}{
- "enabled": true,
- "detail": true,
- "token": true,
- "log": true,
- "midjourney": true,
- "task": true,
- }
-
- // 个人中心区域 - 所有用户都可以访问
- defaultConfig["personal"] = map[string]interface{}{
- "enabled": true,
- "topup": true,
- "personal": true,
- }
-
- // 管理员区域 - 根据角色决定
- if userRole == common.RoleAdminUser {
- // 管理员可以访问管理员区域,但不能访问系统设置
- defaultConfig["admin"] = map[string]interface{}{
- "enabled": true,
- "channel": true,
- "models": true,
- "redemption": true,
- "user": true,
- "setting": false, // 管理员不能访问系统设置
- }
- } else if userRole == common.RoleRootUser {
- // 超级管理员可以访问所有功能
- defaultConfig["admin"] = map[string]interface{}{
- "enabled": true,
- "channel": true,
- "models": true,
- "redemption": true,
- "user": true,
- "setting": true,
- }
- }
- // 普通用户不包含admin区域
-
- // 转换为JSON字符串
- configBytes, err := json.Marshal(defaultConfig)
- if err != nil {
- common.SysLog("生成默认边栏配置失败: " + err.Error())
- return ""
- }
-
- return string(configBytes)
-}
-
-func GetUserModels(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- id = c.GetInt("id")
- }
- user, err := model.GetUserCache(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- groups := setting.GetUserUsableGroups(user.Group)
- var models []string
- for group := range groups {
- for _, g := range model.GetGroupEnabledModels(group) {
- if !common.StringsContains(models, g) {
- models = append(models, g)
- }
- }
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": models,
- })
- return
-}
-
-func UpdateUser(c *gin.Context) {
- var updatedUser model.User
- err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
- if err != nil || updatedUser.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- if updatedUser.Password == "" {
- updatedUser.Password = "$I_LOVE_U" // make Validator happy :)
- }
- if err := common.Validate.Struct(&updatedUser); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "输入不合法 " + err.Error(),
- })
- return
- }
- originUser, err := model.GetUserById(updatedUser.Id, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- myRole := c.GetInt("role")
- if myRole <= originUser.Role && myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权更新同权限等级或更高权限等级的用户信息",
- })
- return
- }
- if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
- })
- return
- }
- if updatedUser.Password == "$I_LOVE_U" {
- updatedUser.Password = "" // rollback to what it should be
- }
- updatePassword := updatedUser.Password != ""
- if err := updatedUser.Edit(updatePassword); err != nil {
- common.ApiError(c, err)
- return
- }
- if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func UpdateSelf(c *gin.Context) {
- var requestData map[string]interface{}
- err := json.NewDecoder(c.Request.Body).Decode(&requestData)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
-
- // 检查是否是sidebar_modules更新请求
- if sidebarModules, exists := requestData["sidebar_modules"]; exists {
- userId := c.GetInt("id")
- user, err := model.GetUserById(userId, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 获取当前用户设置
- currentSetting := user.GetSetting()
-
- // 更新sidebar_modules字段
- if sidebarModulesStr, ok := sidebarModules.(string); ok {
- currentSetting.SidebarModules = sidebarModulesStr
- }
-
- // 保存更新后的设置
- user.SetSetting(currentSetting)
- if err := user.Update(false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "更新设置失败: " + err.Error(),
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "设置更新成功",
- })
- return
- }
-
- // 原有的用户信息更新逻辑
- var user model.User
- requestDataBytes, err := json.Marshal(requestData)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- err = json.Unmarshal(requestDataBytes, &user)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
-
- if user.Password == "" {
- user.Password = "$I_LOVE_U" // make Validator happy :)
- }
- if err := common.Validate.Struct(&user); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "输入不合法 " + err.Error(),
- })
- return
- }
-
- cleanUser := model.User{
- Id: c.GetInt("id"),
- Username: user.Username,
- Password: user.Password,
- DisplayName: user.DisplayName,
- }
- if user.Password == "$I_LOVE_U" {
- user.Password = "" // rollback to what it should be
- cleanUser.Password = ""
- }
- updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if err := cleanUser.Update(updatePassword); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
- var currentUser *model.User
- currentUser, err = model.GetUserById(userId, true)
- if err != nil {
- return
- }
- if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
- err = fmt.Errorf("原密码错误")
- return
- }
- if newPassword == "" {
- return
- }
- updatePassword = true
- return
-}
-
-func DeleteUser(c *gin.Context) {
- id, err := strconv.Atoi(c.Param("id"))
- if err != nil {
- common.ApiError(c, err)
- return
- }
- originUser, err := model.GetUserById(id, false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- myRole := c.GetInt("role")
- if myRole <= originUser.Role {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权删除同权限等级或更高权限等级的用户",
- })
- return
- }
- err = model.HardDeleteUserById(id)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
- }
-}
-
-func DeleteSelf(c *gin.Context) {
- id := c.GetInt("id")
- user, _ := model.GetUserById(id, false)
-
- if user.Role == common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "不能删除超级管理员账户",
- })
- return
- }
-
- err := model.DeleteUserById(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-func CreateUser(c *gin.Context) {
- var user model.User
- err := json.NewDecoder(c.Request.Body).Decode(&user)
- user.Username = strings.TrimSpace(user.Username)
- if err != nil || user.Username == "" || user.Password == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- if err := common.Validate.Struct(&user); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "输入不合法 " + err.Error(),
- })
- return
- }
- if user.DisplayName == "" {
- user.DisplayName = user.Username
- }
- myRole := c.GetInt("role")
- if user.Role >= myRole {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法创建权限大于等于自己的用户",
- })
- return
- }
- // Even for admin users, we cannot fully trust them!
- cleanUser := model.User{
- Username: user.Username,
- Password: user.Password,
- DisplayName: user.DisplayName,
- Role: user.Role, // 保持管理员设置的角色
- }
- if err := cleanUser.Insert(0); err != nil {
- common.ApiError(c, err)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-type ManageRequest struct {
- Id int `json:"id"`
- Action string `json:"action"`
-}
-
-// ManageUser Only admin user can do this
-func ManageUser(c *gin.Context) {
- var req ManageRequest
- err := json.NewDecoder(c.Request.Body).Decode(&req)
-
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
- user := model.User{
- Id: req.Id,
- }
- // Fill attributes
- model.DB.Unscoped().Where(&user).First(&user)
- if user.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户不存在",
- })
- return
- }
- myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权更新同权限等级或更高权限等级的用户信息",
- })
- return
- }
- switch req.Action {
- case "disable":
- user.Status = common.UserStatusDisabled
- if user.Role == common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法禁用超级管理员用户",
- })
- return
- }
- case "enable":
- user.Status = common.UserStatusEnabled
- case "delete":
- if user.Role == common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法删除超级管理员用户",
- })
- return
- }
- if err := user.Delete(); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- case "promote":
- if myRole != common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "普通管理员用户无法提升其他用户为管理员",
- })
- return
- }
- if user.Role >= common.RoleAdminUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户已经是管理员",
- })
- return
- }
- user.Role = common.RoleAdminUser
- case "demote":
- if user.Role == common.RoleRootUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无法降级超级管理员用户",
- })
- return
- }
- if user.Role == common.RoleCommonUser {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该用户已经是普通用户",
- })
- return
- }
- user.Role = common.RoleCommonUser
- }
-
- if err := user.Update(false); err != nil {
- common.ApiError(c, err)
- return
- }
- clearUser := model.User{
- Role: user.Role,
- Status: user.Status,
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": clearUser,
- })
- return
-}
-
-func EmailBind(c *gin.Context) {
- email := c.Query("email")
- code := c.Query("code")
- if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "验证码错误或已过期",
- })
- return
- }
- session := sessions.Default(c)
- id := session.Get("id")
- user := model.User{
- Id: id.(int),
- }
- err := user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user.Email = email
- // no need to check if this email already taken, because we have used verification code to check it
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
-
-type topUpRequest struct {
- Key string `json:"key"`
-}
-
-var topUpLocks sync.Map
-var topUpCreateLock sync.Mutex
-
-type topUpTryLock struct {
- ch chan struct{}
-}
-
-func newTopUpTryLock() *topUpTryLock {
- return &topUpTryLock{ch: make(chan struct{}, 1)}
-}
-
-func (l *topUpTryLock) TryLock() bool {
- select {
- case l.ch <- struct{}{}:
- return true
- default:
- return false
- }
-}
-
-func (l *topUpTryLock) Unlock() {
- select {
- case <-l.ch:
- default:
- }
-}
-
-func getTopUpLock(userID int) *topUpTryLock {
- if v, ok := topUpLocks.Load(userID); ok {
- return v.(*topUpTryLock)
- }
- topUpCreateLock.Lock()
- defer topUpCreateLock.Unlock()
- if v, ok := topUpLocks.Load(userID); ok {
- return v.(*topUpTryLock)
- }
- l := newTopUpTryLock()
- topUpLocks.Store(userID, l)
- return l
-}
-
-func TopUp(c *gin.Context) {
- id := c.GetInt("id")
- lock := getTopUpLock(id)
- if !lock.TryLock() {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "充值处理中,请稍后重试",
- })
- return
- }
- defer lock.Unlock()
- req := topUpRequest{}
- err := c.ShouldBindJSON(&req)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- quota, err := model.Redeem(req.Key, id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": quota,
- })
-}
-
-type UpdateUserSettingRequest struct {
- QuotaWarningType string `json:"notify_type"`
- QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
- WebhookUrl string `json:"webhook_url,omitempty"`
- WebhookSecret string `json:"webhook_secret,omitempty"`
- NotificationEmail string `json:"notification_email,omitempty"`
- BarkUrl string `json:"bark_url,omitempty"`
- AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
- RecordIpLog bool `json:"record_ip_log"`
-}
-
-func UpdateUserSetting(c *gin.Context) {
- var req UpdateUserSettingRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的参数",
- })
- return
- }
-
- // 验证预警类型
- if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的预警类型",
- })
- return
- }
-
- // 验证预警阈值
- if req.QuotaWarningThreshold <= 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "预警阈值必须大于0",
- })
- return
- }
-
- // 如果是webhook类型,验证webhook地址
- if req.QuotaWarningType == dto.NotifyTypeWebhook {
- if req.WebhookUrl == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Webhook地址不能为空",
- })
- return
- }
- // 验证URL格式
- if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的Webhook地址",
- })
- return
- }
- }
-
- // 如果是邮件类型,验证邮箱地址
- if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
- // 验证邮箱格式
- if !strings.Contains(req.NotificationEmail, "@") {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的邮箱地址",
- })
- return
- }
- }
-
- // 如果是Bark类型,验证Bark URL
- if req.QuotaWarningType == dto.NotifyTypeBark {
- if req.BarkUrl == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Bark推送URL不能为空",
- })
- return
- }
- // 验证URL格式
- if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无效的Bark推送URL",
- })
- return
- }
- // 检查是否是HTTP或HTTPS
- if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Bark推送URL必须以http://或https://开头",
- })
- return
- }
- }
-
- userId := c.GetInt("id")
- user, err := model.GetUserById(userId, true)
- if err != nil {
- common.ApiError(c, err)
- return
- }
-
- // 构建设置
- settings := dto.UserSetting{
- NotifyType: req.QuotaWarningType,
- QuotaWarningThreshold: req.QuotaWarningThreshold,
- AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
- RecordIpLog: req.RecordIpLog,
- }
-
- // 如果是webhook类型,添加webhook相关设置
- if req.QuotaWarningType == dto.NotifyTypeWebhook {
- settings.WebhookUrl = req.WebhookUrl
- if req.WebhookSecret != "" {
- settings.WebhookSecret = req.WebhookSecret
- }
- }
-
- // 如果提供了通知邮箱,添加到设置中
- if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
- settings.NotificationEmail = req.NotificationEmail
- }
-
- // 如果是Bark类型,添加Bark URL到设置中
- if req.QuotaWarningType == dto.NotifyTypeBark {
- settings.BarkUrl = req.BarkUrl
- }
-
- // 更新用户设置
- user.SetSetting(settings)
- if err := user.Update(false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "更新设置失败: " + err.Error(),
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "设置已更新",
- })
-}
diff --git a/new-api/controller/vendor_meta.go b/new-api/controller/vendor_meta.go
deleted file mode 100644
index f79c7c75241df93c0c7f06dbe757aec7312a06bf..0000000000000000000000000000000000000000
--- a/new-api/controller/vendor_meta.go
+++ /dev/null
@@ -1,124 +0,0 @@
-package controller
-
-import (
- "strconv"
-
- "one-api/common"
- "one-api/model"
-
- "github.com/gin-gonic/gin"
-)
-
-// GetAllVendors 获取供应商列表(分页)
-func GetAllVendors(c *gin.Context) {
- pageInfo := common.GetPageQuery(c)
- vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- var total int64
- model.DB.Model(&model.Vendor{}).Count(&total)
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(vendors)
- common.ApiSuccess(c, pageInfo)
-}
-
-// SearchVendors 搜索供应商
-func SearchVendors(c *gin.Context) {
- keyword := c.Query("keyword")
- pageInfo := common.GetPageQuery(c)
- vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
- if err != nil {
- common.ApiError(c, err)
- return
- }
- pageInfo.SetTotal(int(total))
- pageInfo.SetItems(vendors)
- common.ApiSuccess(c, pageInfo)
-}
-
-// GetVendorMeta 根据 ID 获取供应商
-func GetVendorMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- v, err := model.GetVendorByID(id)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, v)
-}
-
-// CreateVendorMeta 新建供应商
-func CreateVendorMeta(c *gin.Context) {
- var v model.Vendor
- if err := c.ShouldBindJSON(&v); err != nil {
- common.ApiError(c, err)
- return
- }
- if v.Name == "" {
- common.ApiErrorMsg(c, "供应商名称不能为空")
- return
- }
- // 创建前先检查名称
- if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "供应商名称已存在")
- return
- }
-
- if err := v.Insert(); err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, &v)
-}
-
-// UpdateVendorMeta 更新供应商
-func UpdateVendorMeta(c *gin.Context) {
- var v model.Vendor
- if err := c.ShouldBindJSON(&v); err != nil {
- common.ApiError(c, err)
- return
- }
- if v.Id == 0 {
- common.ApiErrorMsg(c, "缺少供应商 ID")
- return
- }
- // 名称冲突检查
- if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
- common.ApiError(c, err)
- return
- } else if dup {
- common.ApiErrorMsg(c, "供应商名称已存在")
- return
- }
-
- if err := v.Update(); err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, &v)
-}
-
-// DeleteVendorMeta 删除供应商
-func DeleteVendorMeta(c *gin.Context) {
- idStr := c.Param("id")
- id, err := strconv.Atoi(idStr)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
- common.ApiError(c, err)
- return
- }
- common.ApiSuccess(c, nil)
-}
diff --git a/new-api/controller/wechat.go b/new-api/controller/wechat.go
deleted file mode 100644
index 1f325cf437730b8e0e39eb3f2245df8538306094..0000000000000000000000000000000000000000
--- a/new-api/controller/wechat.go
+++ /dev/null
@@ -1,168 +0,0 @@
-package controller
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strconv"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-type wechatLoginResponse struct {
- Success bool `json:"success"`
- Message string `json:"message"`
- Data string `json:"data"`
-}
-
-func getWeChatIdByCode(code string) (string, error) {
- if code == "" {
- return "", errors.New("无效的参数")
- }
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
- if err != nil {
- return "", err
- }
- req.Header.Set("Authorization", common.WeChatServerToken)
- client := http.Client{
- Timeout: 5 * time.Second,
- }
- httpResponse, err := client.Do(req)
- if err != nil {
- return "", err
- }
- defer httpResponse.Body.Close()
- var res wechatLoginResponse
- err = json.NewDecoder(httpResponse.Body).Decode(&res)
- if err != nil {
- return "", err
- }
- if !res.Success {
- return "", errors.New(res.Message)
- }
- if res.Data == "" {
- return "", errors.New("验证码错误或已过期")
- }
- return res.Data, nil
-}
-
-func WeChatAuth(c *gin.Context) {
- if !common.WeChatAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "管理员未开启通过微信登录以及注册",
- "success": false,
- })
- return
- }
- code := c.Query("code")
- wechatId, err := getWeChatIdByCode(code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
- user := model.User{
- WeChatId: wechatId,
- }
- if model.IsWeChatIdAlreadyTaken(wechatId) {
- err := user.FillUserByWeChatId()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- if user.Id == 0 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已注销",
- })
- return
- }
- } else {
- if common.RegisterEnabled {
- user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
- user.DisplayName = "WeChat User"
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
-
- if err := user.Insert(0); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "管理员关闭了新用户注册",
- })
- return
- }
- }
-
- if user.Status != common.UserStatusEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "用户已被封禁",
- "success": false,
- })
- return
- }
- setupLogin(&user, c)
-}
-
-func WeChatBind(c *gin.Context) {
- if !common.WeChatAuthEnabled {
- c.JSON(http.StatusOK, gin.H{
- "message": "管理员未开启通过微信登录以及注册",
- "success": false,
- })
- return
- }
- code := c.Query("code")
- wechatId, err := getWeChatIdByCode(code)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": err.Error(),
- "success": false,
- })
- return
- }
- if model.IsWeChatIdAlreadyTaken(wechatId) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "该微信账号已被绑定",
- })
- return
- }
- session := sessions.Default(c)
- id := session.Get("id")
- user := model.User{
- Id: id.(int),
- }
- err = user.FillUserById()
- if err != nil {
- common.ApiError(c, err)
- return
- }
- user.WeChatId = wechatId
- err = user.Update(false)
- if err != nil {
- common.ApiError(c, err)
- return
- }
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- })
- return
-}
diff --git a/new-api/docker-compose.yml b/new-api/docker-compose.yml
deleted file mode 100644
index 62603cf061a3d15430457cf5c92e838602017a8a..0000000000000000000000000000000000000000
--- a/new-api/docker-compose.yml
+++ /dev/null
@@ -1,52 +0,0 @@
-version: '3.4'
-
-services:
- new-api:
- image: calciumion/new-api:latest
- container_name: new-api
- restart: always
- command: --log-dir /app/logs
- ports:
- - "3000:3000"
- volumes:
- - ./data:/data
- - ./logs:/app/logs
- environment:
- - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- - REDIS_CONN_STRING=redis://redis
- - TZ=Asia/Shanghai
- - ERROR_LOG_ENABLED=true # 是否启用错误日志记录
- # - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
- # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
- # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
- # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
- # - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
-
- depends_on:
- - redis
- - mysql
- healthcheck:
- test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
- interval: 30s
- timeout: 10s
- retries: 3
-
- redis:
- image: redis:latest
- container_name: redis
- restart: always
-
- mysql:
- image: mysql:8.2
- container_name: mysql
- restart: always
- environment:
- MYSQL_ROOT_PASSWORD: 123456 # Ensure this matches the password in SQL_DSN
- MYSQL_DATABASE: new-api
- volumes:
- - mysql_data:/var/lib/mysql
- # ports:
- # - "3306:3306" # If you want to access MySQL from outside Docker, uncomment
-
-volumes:
- mysql_data:
diff --git a/new-api/docs/api/api_auth.md b/new-api/docs/api/api_auth.md
deleted file mode 100644
index 220d427b79ae65cfe3bc60f6c2fb988814ec7294..0000000000000000000000000000000000000000
--- a/new-api/docs/api/api_auth.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# API 鉴权文档
-
-## 认证方式
-
-### Access Token
-
-对于需要鉴权的 API 接口,必须同时提供以下两个请求头来进行 Access Token 认证:
-
-1. **请求头中的 `Authorization` 字段**
-
- 将 Access Token 放置于 HTTP 请求头部的 `Authorization` 字段中,格式如下:
-
- ```
- Authorization:
- ```
-
- 其中 `` 需要替换为实际的 Access Token 值。
-
-2. **请求头中的 `New-Api-User` 字段**
-
- 将用户 ID 放置于 HTTP 请求头部的 `New-Api-User` 字段中,格式如下:
-
- ```
- New-Api-User:
- ```
-
- 其中 `` 需要替换为实际的用户 ID。
-
-**注意:**
-
-* **必须同时提供 `Authorization` 和 `New-Api-User` 两个请求头才能通过鉴权。**
-* 如果只提供其中一个请求头,或者两个请求头都未提供,则会返回 `401 Unauthorized` 错误。
-* 如果 `Authorization` 中的 Access Token 无效,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,access token 无效”。
-* 如果 `New-Api-User` 中的用户 ID 与 Access Token 不匹配,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,与登录用户不匹配,请重新登录”。
-* 如果没有提供 `New-Api-User` 请求头,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,未提供 New-Api-User”。
-* 如果 `New-Api-User` 请求头格式错误,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,New-Api-User 格式错误”。
-* 如果用户已被禁用,则会返回 `403 Forbidden` 错误,并提示“用户已被封禁”。
-* 如果用户权限不足,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,权限不足”。
-* 如果用户信息无效,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,用户信息无效”。
-
-## Curl 示例
-
-假设您的 Access Token 为 `access_token`,用户 ID 为 `123`,要访问的 API 接口为 `/api/user/self`,则可以使用以下 curl 命令:
-
-```bash
-curl -X GET \
- -H "Authorization: access_token" \
- -H "New-Api-User: 123" \
- https://your-domain.com/api/user/self
-```
-
-请将 `access_token`、`123` 和 `https://your-domain.com` 替换为实际的值。
-
diff --git a/new-api/docs/api/web_api.md b/new-api/docs/api/web_api.md
deleted file mode 100644
index d8d6277d591632e2bd76464c2e353f8ccad65af6..0000000000000000000000000000000000000000
--- a/new-api/docs/api/web_api.md
+++ /dev/null
@@ -1,197 +0,0 @@
-# New API – Web 界面后端接口文档
-
-> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
->
-> 接口前缀统一为 `https://`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
->
-> 鉴权级别说明:
-> * **公开** – 不需要登录即可调用
-> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
-> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
-> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
-
----
-
-## 1. 初始化 / 系统状态
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/setup | 公开 | 获取系统初始化状态 |
-| POST | /api/setup | 公开 | 完成首次安装向导 |
-| GET | /api/status | 公开 | 获取运行状态摘要 |
-| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
-| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
-
-## 2. 公共信息
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/models | 用户 | 获取前端可用模型列表 |
-| GET | /api/notice | 公开 | 获取公告栏内容 |
-| GET | /api/about | 公开 | 关于页面信息 |
-| GET | /api/home_page_content | 公开 | 首页自定义内容 |
-| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
-| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
-
-## 3. 邮件 / 身份验证
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
-| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
-| POST | /api/user/reset | 公开 | 提交重置密码请求 |
-
-## 4. OAuth / 第三方登录
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
-| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
-| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
-| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
-| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
-| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
-| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
-| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
-| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
-
-## 5. 用户模块
-### 5.1 账号注册/登录
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| POST | /api/user/register | 公开 | 注册新账号 |
-| POST | /api/user/login | 公开 | 用户登录 |
-| GET | /api/user/logout | 用户 | 退出登录 |
-| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
-| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
-
-### 5.2 用户自身操作 (需登录)
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
-| GET | /api/user/self | 用户 | 获取个人资料 |
-| GET | /api/user/models | 用户 | 获取模型可见性 |
-| PUT | /api/user/self | 用户 | 修改个人资料 |
-| DELETE | /api/user/self | 用户 | 注销账号 |
-| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
-| GET | /api/user/aff | 用户 | 获取推广码信息 |
-| POST | /api/user/topup | 用户 | 余额直充 |
-| POST | /api/user/pay | 用户 | 提交支付订单 |
-| POST | /api/user/amount | 用户 | 余额支付 |
-| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
-| PUT | /api/user/setting | 用户 | 更新用户设置 |
-
-### 5.3 管理员用户管理
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/user/ | 管理员 | 获取全部用户列表 |
-| GET | /api/user/search | 管理员 | 搜索用户 |
-| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
-| POST | /api/user/ | 管理员 | 创建用户 |
-| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
-| PUT | /api/user/ | 管理员 | 更新用户 |
-| DELETE | /api/user/:id | 管理员 | 删除用户 |
-
-## 6. 站点选项 (Root)
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/option/ | Root | 获取全局配置 |
-| PUT | /api/option/ | Root | 更新全局配置 |
-| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
-| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
-
-## 7. 模型倍率同步 (Root)
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
-| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
-
-## 8. 渠道管理 (管理员)
-| 方法 | 路径 | 说明 |
-|------|------|------|
-| GET | /api/channel/ | 获取渠道列表 |
-| GET | /api/channel/search | 搜索渠道 |
-| GET | /api/channel/models | 查询渠道模型能力 |
-| GET | /api/channel/models_enabled | 查询启用模型能力 |
-| GET | /api/channel/:id | 获取单个渠道 |
-| GET | /api/channel/test | 批量测试渠道连通性 |
-| GET | /api/channel/test/:id | 单个渠道测试 |
-| GET | /api/channel/update_balance | 批量刷新余额 |
-| GET | /api/channel/update_balance/:id | 单个刷新余额 |
-| POST | /api/channel/ | 新增渠道 |
-| PUT | /api/channel/ | 更新渠道 |
-| DELETE | /api/channel/disabled | 删除已禁用渠道 |
-| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
-| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
-| PUT | /api/channel/tag | 编辑渠道标签 |
-| DELETE | /api/channel/:id | 删除渠道 |
-| POST | /api/channel/batch | 批量删除渠道 |
-| POST | /api/channel/fix | 修复渠道能力表 |
-| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
-| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
-| POST | /api/channel/batch/tag | 批量设置渠道标签 |
-| GET | /api/channel/tag/models | 根据标签获取模型 |
-| POST | /api/channel/copy/:id | 复制渠道 |
-
-## 9. Token 管理
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/token/ | 用户 | 获取全部 Token |
-| GET | /api/token/search | 用户 | 搜索 Token |
-| GET | /api/token/:id | 用户 | 获取单个 Token |
-| POST | /api/token/ | 用户 | 创建 Token |
-| PUT | /api/token/ | 用户 | 更新 Token |
-| DELETE | /api/token/:id | 用户 | 删除 Token |
-| POST | /api/token/batch | 用户 | 批量删除 Token |
-
-## 10. 兑换码管理 (管理员)
-| 方法 | 路径 | 说明 |
-|------|------|------|
-| GET | /api/redemption/ | 获取兑换码列表 |
-| GET | /api/redemption/search | 搜索兑换码 |
-| GET | /api/redemption/:id | 获取单个兑换码 |
-| POST | /api/redemption/ | 创建兑换码 |
-| PUT | /api/redemption/ | 更新兑换码 |
-| DELETE | /api/redemption/invalid | 删除无效兑换码 |
-| DELETE | /api/redemption/:id | 删除兑换码 |
-
-## 11. 日志
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/log/ | 管理员 | 获取全部日志 |
-| DELETE | /api/log/ | 管理员 | 删除历史日志 |
-| GET | /api/log/stat | 管理员 | 日志统计 |
-| GET | /api/log/self/stat | 用户 | 我的日志统计 |
-| GET | /api/log/search | 管理员 | 搜索全部日志 |
-| GET | /api/log/self | 用户 | 获取我的日志 |
-| GET | /api/log/self/search | 用户 | 搜索我的日志 |
-| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
-
-## 12. 数据统计
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
-| GET | /api/data/self | 用户 | 我的用量按日期统计 |
-
-## 13. 分组
-| GET | /api/group/ | 管理员 | 获取全部分组列表 |
-
-## 14. Midjourney 任务
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
-| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
-
-## 15. 任务中心
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /api/task/self | 用户 | 获取我的任务 |
-| GET | /api/task/ | 管理员 | 获取全部任务 |
-
-## 16. 账户计费面板 (Dashboard)
-| 方法 | 路径 | 鉴权 | 说明 |
-|------|------|------|------|
-| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
-| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
-| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
-| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
-
----
-
-> **更新日期**:2025.07.17
diff --git a/new-api/docs/channel/other_setting.md b/new-api/docs/channel/other_setting.md
deleted file mode 100644
index 2c9b999aa53c937f46a06f680404fff1e3e1677a..0000000000000000000000000000000000000000
--- a/new-api/docs/channel/other_setting.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# 渠道而外设置说明
-
-该配置用于设置一些额外的渠道参数,可以通过 JSON 对象进行配置。主要包含以下两个设置项:
-
-1. force_format
- - 用于标识是否对数据进行强制格式化为 OpenAI 格式
- - 类型为布尔值,设置为 true 时启用强制格式化
-
-2. proxy
- - 用于配置网络代理
- - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
-
-3. thinking_to_content
- - 用于标识是否将思考内容`reasoning_content`转换为``标签拼接到内容中返回
- - 类型为布尔值,设置为 true 时启用思考内容转换
-
---------------------------------------------------------------
-
-## JSON 格式示例
-
-以下是一个示例配置,启用强制格式化并设置了代理地址:
-
-```json
-{
- "force_format": true,
- "thinking_to_content": true,
- "proxy": "socks5://xxxxxxx"
-}
-```
-
---------------------------------------------------------------
-
-通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。
diff --git a/new-api/docs/images/aliyun.png b/new-api/docs/images/aliyun.png
deleted file mode 100644
index 6266bfbff3603b0969ee557143a5ec18e7d9e045..0000000000000000000000000000000000000000
Binary files a/new-api/docs/images/aliyun.png and /dev/null differ
diff --git a/new-api/docs/images/cherry-studio.png b/new-api/docs/images/cherry-studio.png
deleted file mode 100644
index a58a77137ffee47008b2dbf6990ba1c31b523dc9..0000000000000000000000000000000000000000
Binary files a/new-api/docs/images/cherry-studio.png and /dev/null differ
diff --git a/new-api/docs/images/io-net.png b/new-api/docs/images/io-net.png
deleted file mode 100644
index fb47534d3d60553ab7a0b2331ead32c652e30a2f..0000000000000000000000000000000000000000
Binary files a/new-api/docs/images/io-net.png and /dev/null differ
diff --git a/new-api/docs/images/pku.png b/new-api/docs/images/pku.png
deleted file mode 100644
index a058c3ce2338608f24c8051925850d89d71dc926..0000000000000000000000000000000000000000
Binary files a/new-api/docs/images/pku.png and /dev/null differ
diff --git a/new-api/docs/images/ucloud.png b/new-api/docs/images/ucloud.png
deleted file mode 100644
index 16cca7642bdca8096c2045904c2b55efe5e9ef7c..0000000000000000000000000000000000000000
Binary files a/new-api/docs/images/ucloud.png and /dev/null differ
diff --git a/new-api/docs/installation/BT.md b/new-api/docs/installation/BT.md
deleted file mode 100644
index e57cdab792db4d347622bbe9d915e8caf1fc7a42..0000000000000000000000000000000000000000
--- a/new-api/docs/installation/BT.md
+++ /dev/null
@@ -1,3 +0,0 @@
-密钥为环境变量SESSION_SECRET
-
-
diff --git a/new-api/docs/models/Midjourney.md b/new-api/docs/models/Midjourney.md
deleted file mode 100644
index 3ccce63dee7f74228d2bef446b7bd3755a594712..0000000000000000000000000000000000000000
--- a/new-api/docs/models/Midjourney.md
+++ /dev/null
@@ -1,82 +0,0 @@
-# Midjourney Proxy API文档
-
-**简介**:Midjourney Proxy API文档
-
-## 接口列表
-支持的接口如下:
-+ [x] /mj/submit/imagine
-+ [x] /mj/submit/change
-+ [x] /mj/submit/blend
-+ [x] /mj/submit/describe
-+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
-+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
-+ [x] /task/list-by-condition
-+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
-+ [x] /mj/submit/modal
-+ [x] /mj/submit/shorten
-+ [x] /mj/task/{id}/image-seed
-+ [x] /mj/insight-face/swap (InsightFace)
-
-## 模型列表
-
-### midjourney-proxy支持
-
-- mj_imagine (绘图)
-- mj_variation (变换)
-- mj_reroll (重绘)
-- mj_blend (混合)
-- mj_upscale (放大)
-- mj_describe (图生文)
-
-### 仅midjourney-proxy-plus支持
-
-- mj_zoom (比例变焦)
-- mj_shorten (提示词缩短)
-- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
-- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
-- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
-- mj_high_variation (强变换)
-- mj_low_variation (弱变换)
-- mj_pan (平移)
-- swap_face (换脸)
-
-## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
-```json
-{
- "mj_imagine": 0.1,
- "mj_variation": 0.1,
- "mj_reroll": 0.1,
- "mj_blend": 0.1,
- "mj_modal": 0.1,
- "mj_zoom": 0.1,
- "mj_shorten": 0.1,
- "mj_high_variation": 0.1,
- "mj_low_variation": 0.1,
- "mj_pan": 0.1,
- "mj_inpaint": 0,
- "mj_custom_zoom": 0,
- "mj_describe": 0.05,
- "mj_upscale": 0.05,
- "swap_face": 0.05
-}
-```
-其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
-
-## 渠道设置
-
-### 对接 midjourney-proxy(plus)
-
-1.
-
-部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
-
-2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
- ,模型请参考上方模型列表
-3. **代理**填写midjourney-proxy部署的地址,例如:http://localhost:8080
-4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
-
-### 对接上游new api
-
-1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
-2. **代理**填写上游new api的地址,例如:http://localhost:3000
-3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/new-api/docs/models/Rerank.md b/new-api/docs/models/Rerank.md
deleted file mode 100644
index db16a92d9733fbf7e7736d1969800e9a0a7c4071..0000000000000000000000000000000000000000
--- a/new-api/docs/models/Rerank.md
+++ /dev/null
@@ -1,62 +0,0 @@
-# Rerank API文档
-
-**简介**:Rerank API文档
-
-## 接入Dify
-模型供应商选择Jina,按要求填写模型信息即可接入Dify。
-
-## 请求方式
-
-Post: /v1/rerank
-
-Request:
-
-```json
-{
- "model": "jina-reranker-v2-base-multilingual",
- "query": "What is the capital of the United States?",
- "top_n": 3,
- "documents": [
- "Carson City is the capital city of the American state of Nevada.",
- "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
- "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
- "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
- "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
- ]
-}
-```
-
-Response:
-
-```json
-{
- "results": [
- {
- "document": {
- "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
- },
- "index": 2,
- "relevance_score": 0.9999702
- },
- {
- "document": {
- "text": "Carson City is the capital city of the American state of Nevada."
- },
- "index": 0,
- "relevance_score": 0.67800725
- },
- {
- "document": {
- "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
- },
- "index": 3,
- "relevance_score": 0.02800752
- }
- ],
- "usage": {
- "prompt_tokens": 158,
- "completion_tokens": 0,
- "total_tokens": 158
- }
-}
-```
\ No newline at end of file
diff --git a/new-api/docs/models/Suno.md b/new-api/docs/models/Suno.md
deleted file mode 100644
index 3d9720a21d0fe413c799cc38e1f9ce824110dc87..0000000000000000000000000000000000000000
--- a/new-api/docs/models/Suno.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# Suno API文档
-
-**简介**:Suno API文档
-
-## 接口列表
-支持的接口如下:
-+ [x] /suno/submit/music
-+ [x] /suno/submit/lyrics
-+ [x] /suno/fetch
-+ [x] /suno/fetch/:id
-
-## 模型列表
-
-### Suno API支持
-
-- suno_music (自定义模式、灵感模式、续写)
-- suno_lyrics (生成歌词)
-
-
-## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
-```json
-{
- "suno_music": 0.3,
- "suno_lyrics": 0.01
-}
-```
-
-## 渠道设置
-
-### 对接 Suno API
-
-1.
-部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
-
-2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
- ,模型请参考上方模型列表
-3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
-4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
-
-### 对接上游new api
-
-1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
-2. **代理**填写上游new api的地址,例如:http://localhost:3000
-3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/new-api/dto/audio.go b/new-api/dto/audio.go
deleted file mode 100644
index 347b1284b1c400efc4f1702a33ae2fe99170025c..0000000000000000000000000000000000000000
--- a/new-api/dto/audio.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package dto
-
-import (
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type AudioRequest struct {
- Model string `json:"model"`
- Input string `json:"input"`
- Voice string `json:"voice"`
- Speed float64 `json:"speed,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
-}
-
-func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
- meta := &types.TokenCountMeta{
- CombineText: r.Input,
- TokenType: types.TokenTypeTextNumber,
- }
- return meta
-}
-
-func (r *AudioRequest) IsStream(c *gin.Context) bool {
- return false
-}
-
-func (r *AudioRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-type AudioResponse struct {
- Text string `json:"text"`
-}
-
-type WhisperVerboseJSONResponse struct {
- Task string `json:"task,omitempty"`
- Language string `json:"language,omitempty"`
- Duration float64 `json:"duration,omitempty"`
- Text string `json:"text,omitempty"`
- Segments []Segment `json:"segments,omitempty"`
-}
-
-type Segment struct {
- Id int `json:"id"`
- Seek int `json:"seek"`
- Start float64 `json:"start"`
- End float64 `json:"end"`
- Text string `json:"text"`
- Tokens []int `json:"tokens"`
- Temperature float64 `json:"temperature"`
- AvgLogprob float64 `json:"avg_logprob"`
- CompressionRatio float64 `json:"compression_ratio"`
- NoSpeechProb float64 `json:"no_speech_prob"`
-}
diff --git a/new-api/dto/channel_settings.go b/new-api/dto/channel_settings.go
deleted file mode 100644
index b1c1ba20a1bbf37e838fdcac0974d6f20cb6728c..0000000000000000000000000000000000000000
--- a/new-api/dto/channel_settings.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package dto
-
-type ChannelSettings struct {
- ForceFormat bool `json:"force_format,omitempty"`
- ThinkingToContent bool `json:"thinking_to_content,omitempty"`
- Proxy string `json:"proxy"`
- PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
- SystemPrompt string `json:"system_prompt,omitempty"`
- SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
-}
-
-type VertexKeyType string
-
-const (
- VertexKeyTypeJSON VertexKeyType = "json"
- VertexKeyTypeAPIKey VertexKeyType = "api_key"
-)
-
-type ChannelOtherSettings struct {
- AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
- VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
- OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
-}
-
-func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
- if s == nil || s.OpenRouterEnterprise == nil {
- return false
- }
- return *s.OpenRouterEnterprise
-}
diff --git a/new-api/dto/claude.go b/new-api/dto/claude.go
deleted file mode 100644
index 0855cbb67265c3b09fb68da3841adc7f9d11dc92..0000000000000000000000000000000000000000
--- a/new-api/dto/claude.go
+++ /dev/null
@@ -1,514 +0,0 @@
-package dto
-
-import (
- "encoding/json"
- "fmt"
- "one-api/common"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type ClaudeMetadata struct {
- UserId string `json:"user_id"`
-}
-
-type ClaudeMediaMessage struct {
- Type string `json:"type,omitempty"`
- Text *string `json:"text,omitempty"`
- Model string `json:"model,omitempty"`
- Source *ClaudeMessageSource `json:"source,omitempty"`
- Usage *ClaudeUsage `json:"usage,omitempty"`
- StopReason *string `json:"stop_reason,omitempty"`
- PartialJson *string `json:"partial_json,omitempty"`
- Role string `json:"role,omitempty"`
- Thinking string `json:"thinking,omitempty"`
- Signature string `json:"signature,omitempty"`
- Delta string `json:"delta,omitempty"`
- CacheControl json.RawMessage `json:"cache_control,omitempty"`
- // tool_calls
- Id string `json:"id,omitempty"`
- Name string `json:"name,omitempty"`
- Input any `json:"input,omitempty"`
- Content any `json:"content,omitempty"`
- ToolUseId string `json:"tool_use_id,omitempty"`
-}
-
-func (c *ClaudeMediaMessage) SetText(s string) {
- c.Text = &s
-}
-
-func (c *ClaudeMediaMessage) GetText() string {
- if c.Text == nil {
- return ""
- }
- return *c.Text
-}
-
-func (c *ClaudeMediaMessage) IsStringContent() bool {
- if c.Content == nil {
- return false
- }
- _, ok := c.Content.(string)
- if ok {
- return true
- }
- return false
-}
-
-func (c *ClaudeMediaMessage) GetStringContent() string {
- if c.Content == nil {
- return ""
- }
- switch c.Content.(type) {
- case string:
- return c.Content.(string)
- case []any:
- var contentStr string
- for _, contentItem := range c.Content.([]any) {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
-
- return ""
-}
-
-func (c *ClaudeMediaMessage) GetJsonRowString() string {
- jsonContent, _ := common.Marshal(c)
- return string(jsonContent)
-}
-
-func (c *ClaudeMediaMessage) SetContent(content any) {
- c.Content = content
-}
-
-func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
- mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
- return mediaContent
-}
-
-type ClaudeMessageSource struct {
- Type string `json:"type"`
- MediaType string `json:"media_type,omitempty"`
- Data any `json:"data,omitempty"`
- Url string `json:"url,omitempty"`
-}
-
-type ClaudeMessage struct {
- Role string `json:"role"`
- Content any `json:"content"`
-}
-
-func (c *ClaudeMessage) IsStringContent() bool {
- if c.Content == nil {
- return false
- }
- _, ok := c.Content.(string)
- return ok
-}
-
-func (c *ClaudeMessage) GetStringContent() string {
- if c.Content == nil {
- return ""
- }
- switch c.Content.(type) {
- case string:
- return c.Content.(string)
- case []any:
- var contentStr string
- for _, contentItem := range c.Content.([]any) {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
-
- return ""
-}
-
-func (c *ClaudeMessage) SetStringContent(content string) {
- c.Content = content
-}
-
-func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
- return common.Any2Type[[]ClaudeMediaMessage](c.Content)
-}
-
-type Tool struct {
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- InputSchema map[string]interface{} `json:"input_schema"`
-}
-
-type InputSchema struct {
- Type string `json:"type"`
- Properties any `json:"properties,omitempty"`
- Required any `json:"required,omitempty"`
-}
-
-type ClaudeWebSearchTool struct {
- Type string `json:"type"`
- Name string `json:"name"`
- MaxUses int `json:"max_uses,omitempty"`
- UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
-}
-
-type ClaudeWebSearchUserLocation struct {
- Type string `json:"type"`
- Timezone string `json:"timezone,omitempty"`
- Country string `json:"country,omitempty"`
- Region string `json:"region,omitempty"`
- City string `json:"city,omitempty"`
-}
-
-type ClaudeToolChoice struct {
- Type string `json:"type"`
- Name string `json:"name,omitempty"`
- DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
-}
-
-type ClaudeRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt,omitempty"`
- System any `json:"system,omitempty"`
- Messages []ClaudeMessage `json:"messages,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
- StopSequences []string `json:"stop_sequences,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- //ClaudeMetadata `json:"metadata,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Tools any `json:"tools,omitempty"`
- ContextManagement json.RawMessage `json:"context_management,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- Thinking *Thinking `json:"thinking,omitempty"`
-}
-
-func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var tokenCountMeta = types.TokenCountMeta{
- TokenType: types.TokenTypeTokenizer,
- MaxTokens: int(c.MaxTokens),
- }
-
- var texts = make([]string, 0)
- var fileMeta = make([]*types.FileMeta, 0)
-
- // system
- if c.System != nil {
- if c.IsStringSystem() {
- sys := c.GetStringSystem()
- if sys != "" {
- texts = append(texts, sys)
- }
- } else {
- systemMedia := c.ParseSystem()
- for _, media := range systemMedia {
- switch media.Type {
- case "text":
- texts = append(texts, media.GetText())
- case "image":
- if media.Source != nil {
- data := media.Source.Url
- if data == "" {
- data = common.Interface2String(media.Source.Data)
- }
- if data != "" {
- fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
- }
- }
- }
- }
- }
- }
-
- // messages
- for _, message := range c.Messages {
- tokenCountMeta.MessagesCount++
- texts = append(texts, message.Role)
- if message.IsStringContent() {
- content := message.GetStringContent()
- if content != "" {
- texts = append(texts, content)
- }
- continue
- }
-
- content, _ := message.ParseContent()
- for _, media := range content {
- switch media.Type {
- case "text":
- texts = append(texts, media.GetText())
- case "image":
- if media.Source != nil {
- data := media.Source.Url
- if data == "" {
- data = common.Interface2String(media.Source.Data)
- }
- if data != "" {
- fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
- }
- }
- case "tool_use":
- if media.Name != "" {
- texts = append(texts, media.Name)
- }
- if media.Input != nil {
- b, _ := common.Marshal(media.Input)
- texts = append(texts, string(b))
- }
- case "tool_result":
- if media.Content != nil {
- b, _ := common.Marshal(media.Content)
- texts = append(texts, string(b))
- }
- }
- }
- }
-
- // tools
- if c.Tools != nil {
- tools := c.GetTools()
- normalTools, webSearchTools := ProcessTools(tools)
- if normalTools != nil {
- for _, t := range normalTools {
- tokenCountMeta.ToolsCount++
- if t.Name != "" {
- texts = append(texts, t.Name)
- }
- if t.Description != "" {
- texts = append(texts, t.Description)
- }
- if t.InputSchema != nil {
- b, _ := common.Marshal(t.InputSchema)
- texts = append(texts, string(b))
- }
- }
- }
- if webSearchTools != nil {
- for _, t := range webSearchTools {
- tokenCountMeta.ToolsCount++
- if t.Name != "" {
- texts = append(texts, t.Name)
- }
- if t.UserLocation != nil {
- b, _ := common.Marshal(t.UserLocation)
- texts = append(texts, string(b))
- }
- }
- }
- }
-
- tokenCountMeta.CombineText = strings.Join(texts, "\n")
- tokenCountMeta.Files = fileMeta
- return &tokenCountMeta
-}
-
-func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
- return c.Stream
-}
-
-func (c *ClaudeRequest) SetModelName(modelName string) {
- if modelName != "" {
- c.Model = modelName
- }
-}
-
-func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
- for _, message := range c.Messages {
- content, _ := message.ParseContent()
- for _, mediaMessage := range content {
- if mediaMessage.Id == toolCallId {
- return mediaMessage.Name
- }
- }
- }
- return ""
-}
-
-// AddTool 添加工具到请求中
-func (c *ClaudeRequest) AddTool(tool any) {
- if c.Tools == nil {
- c.Tools = make([]any, 0)
- }
-
- switch tools := c.Tools.(type) {
- case []any:
- c.Tools = append(tools, tool)
- default:
- // 如果Tools不是[]any类型,重新初始化为[]any
- c.Tools = []any{tool}
- }
-}
-
-// GetTools 获取工具列表
-func (c *ClaudeRequest) GetTools() []any {
- if c.Tools == nil {
- return nil
- }
-
- switch tools := c.Tools.(type) {
- case []any:
- return tools
- default:
- return nil
- }
-}
-
-// ProcessTools 处理工具列表,支持类型断言
-func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
- var normalTools []*Tool
- var webSearchTools []*ClaudeWebSearchTool
-
- for _, tool := range tools {
- switch t := tool.(type) {
- case *Tool:
- normalTools = append(normalTools, t)
- case *ClaudeWebSearchTool:
- webSearchTools = append(webSearchTools, t)
- case Tool:
- normalTools = append(normalTools, &t)
- case ClaudeWebSearchTool:
- webSearchTools = append(webSearchTools, &t)
- default:
- // 未知类型,跳过
- continue
- }
- }
-
- return normalTools, webSearchTools
-}
-
-type Thinking struct {
- Type string `json:"type"`
- BudgetTokens *int `json:"budget_tokens,omitempty"`
-}
-
-func (c *Thinking) GetBudgetTokens() int {
- if c.BudgetTokens == nil {
- return 0
- }
- return *c.BudgetTokens
-}
-
-func (c *ClaudeRequest) IsStringSystem() bool {
- _, ok := c.System.(string)
- return ok
-}
-
-func (c *ClaudeRequest) GetStringSystem() string {
- if c.IsStringSystem() {
- return c.System.(string)
- }
- return ""
-}
-
-func (c *ClaudeRequest) SetStringSystem(system string) {
- c.System = system
-}
-
-func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
- mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
- return mediaContent
-}
-
-type ClaudeErrorWithStatusCode struct {
- Error types.ClaudeError `json:"error"`
- StatusCode int `json:"status_code"`
- LocalError bool
-}
-
-type ClaudeResponse struct {
- Id string `json:"id,omitempty"`
- Type string `json:"type"`
- Role string `json:"role,omitempty"`
- Content []ClaudeMediaMessage `json:"content,omitempty"`
- Completion string `json:"completion,omitempty"`
- StopReason string `json:"stop_reason,omitempty"`
- Model string `json:"model,omitempty"`
- Error any `json:"error,omitempty"`
- Usage *ClaudeUsage `json:"usage,omitempty"`
- Index *int `json:"index,omitempty"`
- ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
- Delta *ClaudeMediaMessage `json:"delta,omitempty"`
- Message *ClaudeMediaMessage `json:"message,omitempty"`
-}
-
-// set index
-func (c *ClaudeResponse) SetIndex(i int) {
- c.Index = &i
-}
-
-// get index
-func (c *ClaudeResponse) GetIndex() int {
- if c.Index == nil {
- return 0
- }
- return *c.Index
-}
-
-// GetClaudeError 从动态错误类型中提取ClaudeError结构
-func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
- if c.Error == nil {
- return nil
- }
-
- switch err := c.Error.(type) {
- case types.ClaudeError:
- return &err
- case *types.ClaudeError:
- return err
- case map[string]interface{}:
- // 处理从JSON解析来的map结构
- claudeErr := &types.ClaudeError{}
- if errType, ok := err["type"].(string); ok {
- claudeErr.Type = errType
- }
- if errMsg, ok := err["message"].(string); ok {
- claudeErr.Message = errMsg
- }
- return claudeErr
- case string:
- // 处理简单字符串错误
- return &types.ClaudeError{
- Type: "upstream_error",
- Message: err,
- }
- default:
- // 未知类型,尝试转换为字符串
- return &types.ClaudeError{
- Type: "unknown_upstream_error",
- Message: fmt.Sprintf("unknown_error: %v", err),
- }
- }
-}
-
-type ClaudeUsage struct {
- InputTokens int `json:"input_tokens"`
- CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
- CacheReadInputTokens int `json:"cache_read_input_tokens"`
- OutputTokens int `json:"output_tokens"`
- ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
-}
-
-type ClaudeServerToolUse struct {
- WebSearchRequests int `json:"web_search_requests"`
-}
diff --git a/new-api/dto/embedding.go b/new-api/dto/embedding.go
deleted file mode 100644
index 2681cd17c6cec40f5f19f509d8c31f6acd491881..0000000000000000000000000000000000000000
--- a/new-api/dto/embedding.go
+++ /dev/null
@@ -1,87 +0,0 @@
-package dto
-
-import (
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type EmbeddingOptions struct {
- Seed int `json:"seed,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopK int `json:"top_k,omitempty"`
- TopP *float64 `json:"top_p,omitempty"`
- FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty *float64 `json:"presence_penalty,omitempty"`
- NumPredict int `json:"num_predict,omitempty"`
- NumCtx int `json:"num_ctx,omitempty"`
-}
-
-type EmbeddingRequest struct {
- Model string `json:"model"`
- Input any `json:"input"`
- EncodingFormat string `json:"encoding_format,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- User string `json:"user,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
-}
-
-func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var texts = make([]string, 0)
-
- inputs := r.ParseInput()
- for _, input := range inputs {
- texts = append(texts, input)
- }
-
- return &types.TokenCountMeta{
- CombineText: strings.Join(texts, "\n"),
- }
-}
-
-func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
- return false
-}
-
-func (r *EmbeddingRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-func (r *EmbeddingRequest) ParseInput() []string {
- if r.Input == nil {
- return make([]string, 0)
- }
- var input []string
- switch r.Input.(type) {
- case string:
- input = []string{r.Input.(string)}
- case []any:
- input = make([]string, 0, len(r.Input.([]any)))
- for _, item := range r.Input.([]any) {
- if str, ok := item.(string); ok {
- input = append(input, str)
- }
- }
- }
- return input
-}
-
-type EmbeddingResponseItem struct {
- Object string `json:"object"`
- Index int `json:"index"`
- Embedding []float64 `json:"embedding"`
-}
-
-type EmbeddingResponse struct {
- Object string `json:"object"`
- Data []EmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
-}
diff --git a/new-api/dto/error.go b/new-api/dto/error.go
deleted file mode 100644
index 9c09e4e846b434adabce55d0223bffc47517c474..0000000000000000000000000000000000000000
--- a/new-api/dto/error.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package dto
-
-import "one-api/types"
-
-type OpenAIError struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param string `json:"param"`
- Code any `json:"code"`
-}
-
-type OpenAIErrorWithStatusCode struct {
- Error OpenAIError `json:"error"`
- StatusCode int `json:"status_code"`
- LocalError bool
-}
-
-type GeneralErrorResponse struct {
- Error types.OpenAIError `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
- Header struct {
- Message string `json:"message"`
- } `json:"header"`
- Response struct {
- Error struct {
- Message string `json:"message"`
- } `json:"error"`
- } `json:"response"`
-}
-
-func (e GeneralErrorResponse) ToMessage() string {
- if e.Error.Message != "" {
- return e.Error.Message
- }
- if e.Message != "" {
- return e.Message
- }
- if e.Msg != "" {
- return e.Msg
- }
- if e.Err != "" {
- return e.Err
- }
- if e.ErrorMsg != "" {
- return e.ErrorMsg
- }
- if e.Header.Message != "" {
- return e.Header.Message
- }
- if e.Response.Error.Message != "" {
- return e.Response.Error.Message
- }
- return ""
-}
diff --git a/new-api/dto/gemini.go b/new-api/dto/gemini.go
deleted file mode 100644
index 077443e055afe3d6bfaace1b050e7afed0f85f0e..0000000000000000000000000000000000000000
--- a/new-api/dto/gemini.go
+++ /dev/null
@@ -1,416 +0,0 @@
-package dto
-
-import (
- "encoding/json"
- "github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/logger"
- "one-api/types"
- "strings"
-)
-
-type GeminiChatRequest struct {
- Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
- Tools json.RawMessage `json:"tools,omitempty"`
- ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
- SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
- CachedContent string `json:"cachedContent,omitempty"`
-}
-
-type ToolConfig struct {
- FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
- RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
-}
-
-type FunctionCallingConfig struct {
- Mode FunctionCallingConfigMode `json:"mode,omitempty"`
- AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
-}
-type FunctionCallingConfigMode string
-
-type RetrievalConfig struct {
- LatLng *LatLng `json:"latLng,omitempty"`
- LanguageCode string `json:"languageCode,omitempty"`
-}
-
-type LatLng struct {
- Latitude *float64 `json:"latitude,omitempty"`
- Longitude *float64 `json:"longitude,omitempty"`
-}
-
-func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var files []*types.FileMeta = make([]*types.FileMeta, 0)
-
- var maxTokens int
-
- if r.GenerationConfig.MaxOutputTokens > 0 {
- maxTokens = int(r.GenerationConfig.MaxOutputTokens)
- }
-
- var inputTexts []string
- for _, content := range r.Contents {
- for _, part := range content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- if part.InlineData != nil && part.InlineData.Data != "" {
- if strings.HasPrefix(part.InlineData.MimeType, "image/") {
- files = append(files, &types.FileMeta{
- FileType: types.FileTypeImage,
- OriginData: part.InlineData.Data,
- })
- } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
- files = append(files, &types.FileMeta{
- FileType: types.FileTypeAudio,
- OriginData: part.InlineData.Data,
- })
- } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
- files = append(files, &types.FileMeta{
- FileType: types.FileTypeVideo,
- OriginData: part.InlineData.Data,
- })
- } else {
- files = append(files, &types.FileMeta{
- FileType: types.FileTypeFile,
- OriginData: part.InlineData.Data,
- })
- }
- }
- }
- }
-
- inputText := strings.Join(inputTexts, "\n")
- return &types.TokenCountMeta{
- CombineText: inputText,
- Files: files,
- MaxTokens: maxTokens,
- }
-}
-
-func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
- if c.Query("alt") == "sse" {
- return true
- }
- return false
-}
-
-func (r *GeminiChatRequest) SetModelName(modelName string) {
- // GeminiChatRequest does not have a model field, so this method does nothing.
-}
-
-func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
- var tools []GeminiChatTool
- if strings.HasSuffix(string(r.Tools), "[") {
- // is array
- if err := common.Unmarshal(r.Tools, &tools); err != nil {
- logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
- return nil
- }
- } else if strings.HasPrefix(string(r.Tools), "{") {
- // is object
- singleTool := GeminiChatTool{}
- if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
- logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
- return nil
- }
- tools = []GeminiChatTool{singleTool}
- }
- return tools
-}
-
-func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
- if len(tools) == 0 {
- r.Tools = json.RawMessage("[]")
- return
- }
-
- // Marshal the tools to JSON
- data, err := common.Marshal(tools)
- if err != nil {
- logger.LogError(nil, "error_marshalling_tools: "+err.Error())
- return
- }
- r.Tools = data
-}
-
-type GeminiThinkingConfig struct {
- IncludeThoughts bool `json:"includeThoughts,omitempty"`
- ThinkingBudget *int `json:"thinkingBudget,omitempty"`
-}
-
-func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
- c.ThinkingBudget = &budget
-}
-
-type GeminiInlineData struct {
- MimeType string `json:"mimeType"`
- Data string `json:"data"`
-}
-
-// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
-func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
- type Alias GeminiInlineData // Use type alias to avoid recursion
- var aux struct {
- Alias
- MimeTypeSnake string `json:"mime_type"`
- }
-
- if err := common.Unmarshal(data, &aux); err != nil {
- return err
- }
-
- *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
-
- // Prioritize snake_case if present
- if aux.MimeTypeSnake != "" {
- g.MimeType = aux.MimeTypeSnake
- } else if aux.MimeType != "" { // Fallback to camelCase from Alias
- g.MimeType = aux.MimeType
- }
- // g.Data would be populated by aux.Alias.Data
- return nil
-}
-
-type FunctionCall struct {
- FunctionName string `json:"name"`
- Arguments any `json:"args"`
-}
-
-type GeminiFunctionResponse struct {
- Name string `json:"name"`
- Response map[string]interface{} `json:"response"`
-}
-
-type GeminiPartExecutableCode struct {
- Language string `json:"language,omitempty"`
- Code string `json:"code,omitempty"`
-}
-
-type GeminiPartCodeExecutionResult struct {
- Outcome string `json:"outcome,omitempty"`
- Output string `json:"output,omitempty"`
-}
-
-type GeminiFileData struct {
- MimeType string `json:"mimeType,omitempty"`
- FileUri string `json:"fileUri,omitempty"`
-}
-
-type GeminiPart struct {
- Text string `json:"text,omitempty"`
- Thought bool `json:"thought,omitempty"`
- InlineData *GeminiInlineData `json:"inlineData,omitempty"`
- FunctionCall *FunctionCall `json:"functionCall,omitempty"`
- FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
- FileData *GeminiFileData `json:"fileData,omitempty"`
- ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
- CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
-}
-
-// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
-func (p *GeminiPart) UnmarshalJSON(data []byte) error {
- // Alias to avoid recursion during unmarshalling
- type Alias GeminiPart
- var aux struct {
- Alias
- InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
- }
-
- if err := common.Unmarshal(data, &aux); err != nil {
- return err
- }
-
- // Assign fields from alias
- *p = GeminiPart(aux.Alias)
-
- // Prioritize snake_case for InlineData if present
- if aux.InlineDataSnake != nil {
- p.InlineData = aux.InlineDataSnake
- } else if aux.InlineData != nil { // Fallback to camelCase from Alias
- p.InlineData = aux.InlineData
- }
- // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
-
- return nil
-}
-
-type GeminiChatContent struct {
- Role string `json:"role,omitempty"`
- Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
- Category string `json:"category"`
- Threshold string `json:"threshold"`
-}
-
-type GeminiChatTool struct {
- GoogleSearch any `json:"googleSearch,omitempty"`
- GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
- CodeExecution any `json:"codeExecution,omitempty"`
- FunctionDeclarations any `json:"functionDeclarations,omitempty"`
- URLContext any `json:"urlContext,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
- ResponseMimeType string `json:"responseMimeType,omitempty"`
- ResponseSchema any `json:"responseSchema,omitempty"`
- ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
- PresencePenalty *float32 `json:"presencePenalty,omitempty"`
- FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
- ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
- Logprobs *int32 `json:"logprobs,omitempty"`
- MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- ResponseModalities []string `json:"responseModalities,omitempty"`
- ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
- SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
-}
-
-type MediaResolution string
-
-type GeminiChatCandidate struct {
- Content GeminiChatContent `json:"content"`
- FinishReason *string `json:"finishReason"`
- Index int64 `json:"index"`
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatSafetyRating struct {
- Category string `json:"category"`
- Probability string `json:"probability"`
-}
-
-type GeminiChatPromptFeedback struct {
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatResponse struct {
- Candidates []GeminiChatCandidate `json:"candidates"`
- PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
- UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
-}
-
-type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount"`
- PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
-}
-
-type GeminiPromptTokensDetails struct {
- Modality string `json:"modality"`
- TokenCount int `json:"tokenCount"`
-}
-
-// Imagen related structs
-type GeminiImageRequest struct {
- Instances []GeminiImageInstance `json:"instances"`
- Parameters GeminiImageParameters `json:"parameters"`
-}
-
-type GeminiImageInstance struct {
- Prompt string `json:"prompt"`
-}
-
-type GeminiImageParameters struct {
- SampleCount int `json:"sampleCount,omitempty"`
- AspectRatio string `json:"aspectRatio,omitempty"`
- PersonGeneration string `json:"personGeneration,omitempty"`
-}
-
-type GeminiImageResponse struct {
- Predictions []GeminiImagePrediction `json:"predictions"`
-}
-
-type GeminiImagePrediction struct {
- MimeType string `json:"mimeType"`
- BytesBase64Encoded string `json:"bytesBase64Encoded"`
- RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
- SafetyAttributes any `json:"safetyAttributes,omitempty"`
-}
-
-// Embedding related structs
-type GeminiEmbeddingRequest struct {
- Model string `json:"model,omitempty"`
- Content GeminiChatContent `json:"content"`
- TaskType string `json:"taskType,omitempty"`
- Title string `json:"title,omitempty"`
- OutputDimensionality int `json:"outputDimensionality,omitempty"`
-}
-
-func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
- // Gemini embedding requests are not streamed
- return false
-}
-
-func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var inputTexts []string
- for _, part := range r.Content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- inputText := strings.Join(inputTexts, "\n")
- return &types.TokenCountMeta{
- CombineText: inputText,
- }
-}
-
-func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-type GeminiBatchEmbeddingRequest struct {
- Requests []*GeminiEmbeddingRequest `json:"requests"`
-}
-
-func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
- // Gemini batch embedding requests are not streamed
- return false
-}
-
-func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var inputTexts []string
- for _, request := range r.Requests {
- meta := request.GetTokenCountMeta()
- if meta != nil && meta.CombineText != "" {
- inputTexts = append(inputTexts, meta.CombineText)
- }
- }
- inputText := strings.Join(inputTexts, "\n")
- return &types.TokenCountMeta{
- CombineText: inputText,
- }
-}
-
-func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
- if modelName != "" {
- for _, req := range r.Requests {
- req.SetModelName(modelName)
- }
- }
-}
-
-type GeminiEmbeddingResponse struct {
- Embedding ContentEmbedding `json:"embedding"`
-}
-
-type GeminiBatchEmbeddingResponse struct {
- Embeddings []*ContentEmbedding `json:"embeddings"`
-}
-
-type ContentEmbedding struct {
- Values []float64 `json:"values"`
-}
diff --git a/new-api/dto/midjourney.go b/new-api/dto/midjourney.go
deleted file mode 100644
index 0057aaa495a02fdf64d6282ba658a66bb0000b1d..0000000000000000000000000000000000000000
--- a/new-api/dto/midjourney.go
+++ /dev/null
@@ -1,107 +0,0 @@
-package dto
-
-//type SimpleMjRequest struct {
-// Prompt string `json:"prompt"`
-// CustomId string `json:"customId"`
-// Action string `json:"action"`
-// Content string `json:"content"`
-//}
-
-type SwapFaceRequest struct {
- SourceBase64 string `json:"sourceBase64"`
- TargetBase64 string `json:"targetBase64"`
-}
-
-type MidjourneyRequest struct {
- Prompt string `json:"prompt"`
- CustomId string `json:"customId"`
- BotType string `json:"botType"`
- NotifyHook string `json:"notifyHook"`
- Action string `json:"action"`
- Index int `json:"index"`
- State string `json:"state"`
- TaskId string `json:"taskId"`
- Base64Array []string `json:"base64Array"`
- Content string `json:"content"`
- MaskBase64 string `json:"maskBase64"`
-}
-
-type MidjourneyResponse struct {
- Code int `json:"code"`
- Description string `json:"description"`
- Properties interface{} `json:"properties"`
- Result string `json:"result"`
-}
-
-type MidjourneyUploadResponse struct {
- Code int `json:"code"`
- Description string `json:"description"`
- Result []string `json:"result"`
-}
-
-type MidjourneyResponseWithStatusCode struct {
- StatusCode int `json:"statusCode"`
- Response MidjourneyResponse
-}
-
-type MidjourneyDto struct {
- MjId string `json:"id"`
- Action string `json:"action"`
- CustomId string `json:"customId"`
- BotType string `json:"botType"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"promptEn"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submitTime"`
- StartTime int64 `json:"startTime"`
- FinishTime int64 `json:"finishTime"`
- ImageUrl string `json:"imageUrl"`
- VideoUrl string `json:"videoUrl"`
- VideoUrls []ImgUrls `json:"videoUrls"`
- Status string `json:"status"`
- Progress string `json:"progress"`
- FailReason string `json:"failReason"`
- Buttons any `json:"buttons"`
- MaskBase64 string `json:"maskBase64"`
- Properties *Properties `json:"properties"`
-}
-
-type ImgUrls struct {
- Url string `json:"url"`
-}
-
-type MidjourneyStatus struct {
- Status int `json:"status"`
-}
-type MidjourneyWithoutStatus struct {
- Id int `json:"id"`
- Code int `json:"code"`
- UserId int `json:"user_id" gorm:"index"`
- Action string `json:"action"`
- MjId string `json:"mj_id" gorm:"index"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"prompt_en"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- ImageUrl string `json:"image_url"`
- Progress string `json:"progress"`
- FailReason string `json:"fail_reason"`
- ChannelId int `json:"channel_id"`
-}
-
-type ActionButton struct {
- CustomId any `json:"customId"`
- Emoji any `json:"emoji"`
- Label any `json:"label"`
- Type any `json:"type"`
- Style any `json:"style"`
-}
-
-type Properties struct {
- FinalPrompt string `json:"finalPrompt"`
- FinalZhPrompt string `json:"finalZhPrompt"`
-}
diff --git a/new-api/dto/notify.go b/new-api/dto/notify.go
deleted file mode 100644
index 1c0e9b837d5f102cbae49eee0150f5ca1ea009f4..0000000000000000000000000000000000000000
--- a/new-api/dto/notify.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package dto
-
-type Notify struct {
- Type string `json:"type"`
- Title string `json:"title"`
- Content string `json:"content"`
- Values []interface{} `json:"values"`
-}
-
-const ContentValueParam = "{{value}}"
-
-const (
- NotifyTypeQuotaExceed = "quota_exceed"
- NotifyTypeChannelUpdate = "channel_update"
- NotifyTypeChannelTest = "channel_test"
-)
-
-func NewNotify(t string, title string, content string, values []interface{}) Notify {
- return Notify{
- Type: t,
- Title: title,
- Content: content,
- Values: values,
- }
-}
diff --git a/new-api/dto/openai_image.go b/new-api/dto/openai_image.go
deleted file mode 100644
index 66f6dca7c17fbe548045e2123ef6aa1908df8295..0000000000000000000000000000000000000000
--- a/new-api/dto/openai_image.go
+++ /dev/null
@@ -1,172 +0,0 @@
-package dto
-
-import (
- "encoding/json"
- "one-api/common"
- "one-api/types"
- "reflect"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type ImageRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt" binding:"required"`
- N uint `json:"n,omitempty"`
- Size string `json:"size,omitempty"`
- Quality string `json:"quality,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
- Style json.RawMessage `json:"style,omitempty"`
- User json.RawMessage `json:"user,omitempty"`
- ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
- Background json.RawMessage `json:"background,omitempty"`
- Moderation json.RawMessage `json:"moderation,omitempty"`
- OutputFormat json.RawMessage `json:"output_format,omitempty"`
- OutputCompression json.RawMessage `json:"output_compression,omitempty"`
- PartialImages json.RawMessage `json:"partial_images,omitempty"`
- // Stream bool `json:"stream,omitempty"`
- Watermark *bool `json:"watermark,omitempty"`
- // 用匿名参数接收额外参数
- Extra map[string]json.RawMessage `json:"-"`
-}
-
-func (i *ImageRequest) UnmarshalJSON(data []byte) error {
- // 先解析成 map[string]interface{}
- var rawMap map[string]json.RawMessage
- if err := common.Unmarshal(data, &rawMap); err != nil {
- return err
- }
-
- // 用 struct tag 获取所有已定义字段名
- knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
-
- // 再正常解析已定义字段
- type Alias ImageRequest
- var known Alias
- if err := common.Unmarshal(data, &known); err != nil {
- return err
- }
- *i = ImageRequest(known)
-
- // 提取多余字段
- i.Extra = make(map[string]json.RawMessage)
- for k, v := range rawMap {
- if _, ok := knownFields[k]; !ok {
- i.Extra[k] = v
- }
- }
- return nil
-}
-
-// 序列化时需要重新把字段平铺
-func (r ImageRequest) MarshalJSON() ([]byte, error) {
- // 将已定义字段转为 map
- type Alias ImageRequest
- alias := Alias(r)
- base, err := common.Marshal(alias)
- if err != nil {
- return nil, err
- }
-
- var baseMap map[string]json.RawMessage
- if err := common.Unmarshal(base, &baseMap); err != nil {
- return nil, err
- }
-
- // 合并 ExtraFields
- for k, v := range r.Extra {
- if _, exists := baseMap[k]; !exists {
- baseMap[k] = v
- }
- }
-
- return json.Marshal(baseMap)
-}
-
-func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
- fields := make(map[string]struct{})
- for i := 0; i < t.NumField(); i++ {
- field := t.Field(i)
-
- // 跳过匿名字段(例如 ExtraFields)
- if field.Anonymous {
- continue
- }
-
- tag := field.Tag.Get("json")
- if tag == "-" || tag == "" {
- continue
- }
-
- // 取逗号前字段名(排除 omitempty 等)
- name := tag
- if commaIdx := indexComma(tag); commaIdx != -1 {
- name = tag[:commaIdx]
- }
- fields[name] = struct{}{}
- }
- return fields
-}
-
-func indexComma(s string) int {
- for i := 0; i < len(s); i++ {
- if s[i] == ',' {
- return i
- }
- }
- return -1
-}
-
-func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var sizeRatio = 1.0
- var qualityRatio = 1.0
-
- if strings.HasPrefix(i.Model, "dall-e") {
- // Size
- if i.Size == "256x256" {
- sizeRatio = 0.4
- } else if i.Size == "512x512" {
- sizeRatio = 0.45
- } else if i.Size == "1024x1024" {
- sizeRatio = 1
- } else if i.Size == "1024x1792" || i.Size == "1792x1024" {
- sizeRatio = 2
- }
-
- if i.Model == "dall-e-3" && i.Quality == "hd" {
- qualityRatio = 2.0
- if i.Size == "1024x1792" || i.Size == "1792x1024" {
- qualityRatio = 1.5
- }
- }
- }
-
- // not support token count for dalle
- return &types.TokenCountMeta{
- CombineText: i.Prompt,
- MaxTokens: 1584,
- ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
- }
-}
-
-func (i *ImageRequest) IsStream(c *gin.Context) bool {
- return false
-}
-
-func (i *ImageRequest) SetModelName(modelName string) {
- if modelName != "" {
- i.Model = modelName
- }
-}
-
-type ImageResponse struct {
- Data []ImageData `json:"data"`
- Created int64 `json:"created"`
- Extra any `json:"extra,omitempty"`
-}
-type ImageData struct {
- Url string `json:"url"`
- B64Json string `json:"b64_json"`
- RevisedPrompt string `json:"revised_prompt"`
-}
diff --git a/new-api/dto/openai_request.go b/new-api/dto/openai_request.go
deleted file mode 100644
index 5ab77e7beb41821ce9c412681cd3569ddb337e11..0000000000000000000000000000000000000000
--- a/new-api/dto/openai_request.go
+++ /dev/null
@@ -1,959 +0,0 @@
-package dto
-
-import (
- "encoding/json"
- "fmt"
- "one-api/common"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type ResponseFormat struct {
- Type string `json:"type,omitempty"`
- JsonSchema json.RawMessage `json:"json_schema,omitempty"`
-}
-
-type FormatJsonSchema struct {
- Description string `json:"description,omitempty"`
- Name string `json:"name"`
- Schema any `json:"schema,omitempty"`
- Strict json.RawMessage `json:"strict,omitempty"`
-}
-
-type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Prefix any `json:"prefix,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- Stream bool `json:"stream,omitempty"`
- StreamOptions *StreamOptions `json:"stream_options,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
- Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions json.RawMessage `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
- Tools []ToolCallRequest `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
- LogProbs bool `json:"logprobs,omitempty"`
- TopLogProbs int `json:"top_logprobs,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- Modalities json.RawMessage `json:"modalities,omitempty"`
- Audio json.RawMessage `json:"audio,omitempty"`
- // gemini
- ExtraBody json.RawMessage `json:"extra_body,omitempty"`
- //xai
- SearchParameters json.RawMessage `json:"search_parameters,omitempty"`
- // claude
- WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
- // OpenRouter Params
- Usage json.RawMessage `json:"usage,omitempty"`
- Reasoning json.RawMessage `json:"reasoning,omitempty"`
- // Ali Qwen Params
- VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
- EnableThinking any `json:"enable_thinking,omitempty"`
- // ollama Params
- Think json.RawMessage `json:"think,omitempty"`
- // baidu v2
- WebSearch json.RawMessage `json:"web_search,omitempty"`
- // doubao,zhipu_v4
- THINKING json.RawMessage `json:"thinking,omitempty"`
-}
-
-func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var tokenCountMeta types.TokenCountMeta
- var texts = make([]string, 0)
- var fileMeta = make([]*types.FileMeta, 0)
-
- if r.Prompt != nil {
- switch v := r.Prompt.(type) {
- case string:
- texts = append(texts, v)
- case []any:
- for _, item := range v {
- if str, ok := item.(string); ok {
- texts = append(texts, str)
- }
- }
- default:
- texts = append(texts, fmt.Sprintf("%v", r.Prompt))
- }
- }
-
- if r.Input != nil {
- inputs := r.ParseInput()
- texts = append(texts, inputs...)
- }
-
- if r.MaxCompletionTokens > r.MaxTokens {
- tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
- } else {
- tokenCountMeta.MaxTokens = int(r.MaxTokens)
- }
-
- for _, message := range r.Messages {
- tokenCountMeta.MessagesCount++
- texts = append(texts, message.Role)
- if message.Content != nil {
- if message.Name != nil {
- tokenCountMeta.NameCount++
- texts = append(texts, *message.Name)
- }
- arrayContent := message.ParseContent()
- for _, m := range arrayContent {
- if m.Type == ContentTypeImageURL {
- imageUrl := m.GetImageMedia()
- if imageUrl != nil {
- if imageUrl.Url != "" {
- meta := &types.FileMeta{
- FileType: types.FileTypeImage,
- }
- meta.OriginData = imageUrl.Url
- meta.Detail = imageUrl.Detail
- fileMeta = append(fileMeta, meta)
- }
- }
- } else if m.Type == ContentTypeInputAudio {
- inputAudio := m.GetInputAudio()
- if inputAudio != nil {
- meta := &types.FileMeta{
- FileType: types.FileTypeAudio,
- }
- meta.OriginData = inputAudio.Data
- fileMeta = append(fileMeta, meta)
- }
- } else if m.Type == ContentTypeFile {
- file := m.GetFile()
- if file != nil {
- meta := &types.FileMeta{
- FileType: types.FileTypeFile,
- }
- meta.OriginData = file.FileData
- fileMeta = append(fileMeta, meta)
- }
- } else if m.Type == ContentTypeVideoUrl {
- videoUrl := m.GetVideoUrl()
- if videoUrl != nil && videoUrl.Url != "" {
- meta := &types.FileMeta{
- FileType: types.FileTypeVideo,
- }
- meta.OriginData = videoUrl.Url
- fileMeta = append(fileMeta, meta)
- }
- } else {
- texts = append(texts, m.Text)
- }
- }
- }
- }
-
- if r.Tools != nil {
- openaiTools := r.Tools
- for _, tool := range openaiTools {
- tokenCountMeta.ToolsCount++
- texts = append(texts, tool.Function.Name)
- if tool.Function.Description != "" {
- texts = append(texts, tool.Function.Description)
- }
- if tool.Function.Parameters != nil {
- texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
- }
- }
- //toolTokens := CountTokenInput(countStr, request.Model)
- //tkm += 8
- //tkm += toolTokens
- }
- tokenCountMeta.CombineText = strings.Join(texts, "\n")
- tokenCountMeta.Files = fileMeta
- return &tokenCountMeta
-}
-
-func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
- return r.Stream
-}
-
-func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-func (r *GeneralOpenAIRequest) ToMap() map[string]any {
- result := make(map[string]any)
- data, _ := common.Marshal(r)
- _ = common.Unmarshal(data, &result)
- return result
-}
-
-func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
- if strings.HasPrefix(r.Model, "o") {
- if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
- return "developer"
- }
- } else if strings.HasPrefix(r.Model, "gpt-5") {
- return "developer"
- }
- return "system"
-}
-
-type ToolCallRequest struct {
- ID string `json:"id,omitempty"`
- Type string `json:"type"`
- Function FunctionRequest `json:"function"`
-}
-
-type FunctionRequest struct {
- Description string `json:"description,omitempty"`
- Name string `json:"name"`
- Parameters any `json:"parameters,omitempty"`
- Arguments string `json:"arguments,omitempty"`
-}
-
-type StreamOptions struct {
- IncludeUsage bool `json:"include_usage,omitempty"`
-}
-
-func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
- if r.MaxCompletionTokens != 0 {
- return r.MaxCompletionTokens
- }
- return r.MaxTokens
-}
-
-func (r *GeneralOpenAIRequest) ParseInput() []string {
- if r.Input == nil {
- return nil
- }
- var input []string
- switch r.Input.(type) {
- case string:
- input = []string{r.Input.(string)}
- case []any:
- input = make([]string, 0, len(r.Input.([]any)))
- for _, item := range r.Input.([]any) {
- if str, ok := item.(string); ok {
- input = append(input, str)
- }
- }
- }
- return input
-}
-
-type Message struct {
- Role string `json:"role"`
- Content any `json:"content"`
- Name *string `json:"name,omitempty"`
- Prefix *bool `json:"prefix,omitempty"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- Reasoning string `json:"reasoning,omitempty"`
- ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
- ToolCallId string `json:"tool_call_id,omitempty"`
- parsedContent []MediaContent
- //parsedStringContent *string
-}
-
-type MediaContent struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- ImageUrl any `json:"image_url,omitempty"`
- InputAudio any `json:"input_audio,omitempty"`
- File any `json:"file,omitempty"`
- VideoUrl any `json:"video_url,omitempty"`
- // OpenRouter Params
- CacheControl json.RawMessage `json:"cache_control,omitempty"`
-}
-
-func (m *MediaContent) GetImageMedia() *MessageImageUrl {
- if m.ImageUrl != nil {
- if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
- return m.ImageUrl.(*MessageImageUrl)
- }
- if itemMap, ok := m.ImageUrl.(map[string]any); ok {
- out := &MessageImageUrl{
- Url: common.Interface2String(itemMap["url"]),
- Detail: common.Interface2String(itemMap["detail"]),
- MimeType: common.Interface2String(itemMap["mime_type"]),
- }
- return out
- }
- }
- return nil
-}
-
-func (m *MediaContent) GetInputAudio() *MessageInputAudio {
- if m.InputAudio != nil {
- if _, ok := m.InputAudio.(*MessageInputAudio); ok {
- return m.InputAudio.(*MessageInputAudio)
- }
- if itemMap, ok := m.InputAudio.(map[string]any); ok {
- out := &MessageInputAudio{
- Data: common.Interface2String(itemMap["data"]),
- Format: common.Interface2String(itemMap["format"]),
- }
- return out
- }
- }
- return nil
-}
-
-func (m *MediaContent) GetFile() *MessageFile {
- if m.File != nil {
- if _, ok := m.File.(*MessageFile); ok {
- return m.File.(*MessageFile)
- }
- if itemMap, ok := m.File.(map[string]any); ok {
- out := &MessageFile{
- FileName: common.Interface2String(itemMap["file_name"]),
- FileData: common.Interface2String(itemMap["file_data"]),
- FileId: common.Interface2String(itemMap["file_id"]),
- }
- return out
- }
- }
- return nil
-}
-
-func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
- if m.VideoUrl != nil {
- if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
- return m.VideoUrl.(*MessageVideoUrl)
- }
- if itemMap, ok := m.VideoUrl.(map[string]any); ok {
- out := &MessageVideoUrl{
- Url: common.Interface2String(itemMap["url"]),
- }
- return out
- }
- }
- return nil
-}
-
-type MessageImageUrl struct {
- Url string `json:"url"`
- Detail string `json:"detail"`
- MimeType string
-}
-
-func (m *MessageImageUrl) IsRemoteImage() bool {
- return strings.HasPrefix(m.Url, "http")
-}
-
-type MessageInputAudio struct {
- Data string `json:"data"` //base64
- Format string `json:"format"`
-}
-
-type MessageFile struct {
- FileName string `json:"filename,omitempty"`
- FileData string `json:"file_data,omitempty"`
- FileId string `json:"file_id,omitempty"`
-}
-
-type MessageVideoUrl struct {
- Url string `json:"url"`
-}
-
-const (
- ContentTypeText = "text"
- ContentTypeImageURL = "image_url"
- ContentTypeInputAudio = "input_audio"
- ContentTypeFile = "file"
- ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
- //ContentTypeAudioUrl = "audio_url"
-)
-
-func (m *Message) GetPrefix() bool {
- if m.Prefix == nil {
- return false
- }
- return *m.Prefix
-}
-
-func (m *Message) SetPrefix(prefix bool) {
- m.Prefix = &prefix
-}
-
-func (m *Message) ParseToolCalls() []ToolCallRequest {
- if m.ToolCalls == nil {
- return nil
- }
- var toolCalls []ToolCallRequest
- if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
- return toolCalls
- }
- return toolCalls
-}
-
-func (m *Message) SetToolCalls(toolCalls any) {
- toolCallsJson, _ := json.Marshal(toolCalls)
- m.ToolCalls = toolCallsJson
-}
-
-func (m *Message) StringContent() string {
- switch m.Content.(type) {
- case string:
- return m.Content.(string)
- case []any:
- var contentStr string
- for _, contentItem := range m.Content.([]any) {
- contentMap, ok := contentItem.(map[string]any)
- if !ok {
- continue
- }
- if contentMap["type"] == ContentTypeText {
- if subStr, ok := contentMap["text"].(string); ok {
- contentStr += subStr
- }
- }
- }
- return contentStr
- }
-
- return ""
-}
-
-func (m *Message) SetNullContent() {
- m.Content = nil
- m.parsedContent = nil
-}
-
-func (m *Message) SetStringContent(content string) {
- m.Content = content
- m.parsedContent = nil
-}
-
-func (m *Message) SetMediaContent(content []MediaContent) {
- m.Content = content
- m.parsedContent = content
-}
-
-func (m *Message) IsStringContent() bool {
- _, ok := m.Content.(string)
- if ok {
- return true
- }
- return false
-}
-
-func (m *Message) ParseContent() []MediaContent {
- if m.Content == nil {
- return nil
- }
- if len(m.parsedContent) > 0 {
- return m.parsedContent
- }
-
- var contentList []MediaContent
- // 先尝试解析为字符串
- content, ok := m.Content.(string)
- if ok {
- contentList = []MediaContent{{
- Type: ContentTypeText,
- Text: content,
- }}
- m.parsedContent = contentList
- return contentList
- }
-
- // 尝试解析为数组
- //var arrayContent []map[string]interface{}
-
- arrayContent, ok := m.Content.([]any)
- if !ok {
- return contentList
- }
-
- for _, contentItemAny := range arrayContent {
- mediaItem, ok := contentItemAny.(MediaContent)
- if ok {
- contentList = append(contentList, mediaItem)
- continue
- }
-
- contentItem, ok := contentItemAny.(map[string]any)
- if !ok {
- continue
- }
- contentType, ok := contentItem["type"].(string)
- if !ok {
- continue
- }
-
- switch contentType {
- case ContentTypeText:
- if text, ok := contentItem["text"].(string); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeText,
- Text: text,
- })
- }
-
- case ContentTypeImageURL:
- imageUrl := contentItem["image_url"]
- temp := &MessageImageUrl{
- Detail: "high",
- }
- switch v := imageUrl.(type) {
- case string:
- temp.Url = v
- case map[string]interface{}:
- url, ok1 := v["url"].(string)
- detail, ok2 := v["detail"].(string)
- if ok2 {
- temp.Detail = detail
- }
- if ok1 {
- temp.Url = url
- }
- }
- contentList = append(contentList, MediaContent{
- Type: ContentTypeImageURL,
- ImageUrl: temp,
- })
-
- case ContentTypeInputAudio:
- if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
- data, ok1 := audioData["data"].(string)
- format, ok2 := audioData["format"].(string)
- if ok1 && ok2 {
- temp := &MessageInputAudio{
- Data: data,
- Format: format,
- }
- contentList = append(contentList, MediaContent{
- Type: ContentTypeInputAudio,
- InputAudio: temp,
- })
- }
- }
- case ContentTypeFile:
- if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
- fileId, ok3 := fileData["file_id"].(string)
- if ok3 {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeFile,
- File: &MessageFile{
- FileId: fileId,
- },
- })
- } else {
- fileName, ok1 := fileData["filename"].(string)
- fileDataStr, ok2 := fileData["file_data"].(string)
- if ok1 && ok2 {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeFile,
- File: &MessageFile{
- FileName: fileName,
- FileData: fileDataStr,
- },
- })
- }
- }
- }
- case ContentTypeVideoUrl:
- if videoUrl, ok := contentItem["video_url"].(string); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeVideoUrl,
- VideoUrl: &MessageVideoUrl{
- Url: videoUrl,
- },
- })
- }
- }
- }
-
- if len(contentList) > 0 {
- m.parsedContent = contentList
- }
- return contentList
-}
-
-// old code
-/*func (m *Message) StringContent() string {
- if m.parsedStringContent != nil {
- return *m.parsedStringContent
- }
-
- var stringContent string
- if err := json.Unmarshal(m.Content, &stringContent); err == nil {
- m.parsedStringContent = &stringContent
- return stringContent
- }
-
- contentStr := new(strings.Builder)
- arrayContent := m.ParseContent()
- for _, content := range arrayContent {
- if content.Type == ContentTypeText {
- contentStr.WriteString(content.Text)
- }
- }
- stringContent = contentStr.String()
- m.parsedStringContent = &stringContent
-
- return stringContent
-}
-
-func (m *Message) SetNullContent() {
- m.Content = nil
- m.parsedStringContent = nil
- m.parsedContent = nil
-}
-
-func (m *Message) SetStringContent(content string) {
- jsonContent, _ := json.Marshal(content)
- m.Content = jsonContent
- m.parsedStringContent = &content
- m.parsedContent = nil
-}
-
-func (m *Message) SetMediaContent(content []MediaContent) {
- jsonContent, _ := json.Marshal(content)
- m.Content = jsonContent
- m.parsedContent = nil
- m.parsedStringContent = nil
-}
-
-func (m *Message) IsStringContent() bool {
- if m.parsedStringContent != nil {
- return true
- }
- var stringContent string
- if err := json.Unmarshal(m.Content, &stringContent); err == nil {
- m.parsedStringContent = &stringContent
- return true
- }
- return false
-}
-
-func (m *Message) ParseContent() []MediaContent {
- if m.parsedContent != nil {
- return m.parsedContent
- }
-
- var contentList []MediaContent
-
- // 先尝试解析为字符串
- var stringContent string
- if err := json.Unmarshal(m.Content, &stringContent); err == nil {
- contentList = []MediaContent{{
- Type: ContentTypeText,
- Text: stringContent,
- }}
- m.parsedContent = contentList
- return contentList
- }
-
- // 尝试解析为数组
- var arrayContent []map[string]interface{}
- if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
- for _, contentItem := range arrayContent {
- contentType, ok := contentItem["type"].(string)
- if !ok {
- continue
- }
-
- switch contentType {
- case ContentTypeText:
- if text, ok := contentItem["text"].(string); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeText,
- Text: text,
- })
- }
-
- case ContentTypeImageURL:
- imageUrl := contentItem["image_url"]
- temp := &MessageImageUrl{
- Detail: "high",
- }
- switch v := imageUrl.(type) {
- case string:
- temp.Url = v
- case map[string]interface{}:
- url, ok1 := v["url"].(string)
- detail, ok2 := v["detail"].(string)
- if ok2 {
- temp.Detail = detail
- }
- if ok1 {
- temp.Url = url
- }
- }
- contentList = append(contentList, MediaContent{
- Type: ContentTypeImageURL,
- ImageUrl: temp,
- })
-
- case ContentTypeInputAudio:
- if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
- data, ok1 := audioData["data"].(string)
- format, ok2 := audioData["format"].(string)
- if ok1 && ok2 {
- temp := &MessageInputAudio{
- Data: data,
- Format: format,
- }
- contentList = append(contentList, MediaContent{
- Type: ContentTypeInputAudio,
- InputAudio: temp,
- })
- }
- }
- case ContentTypeFile:
- if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
- fileId, ok3 := fileData["file_id"].(string)
- if ok3 {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeFile,
- File: &MessageFile{
- FileId: fileId,
- },
- })
- } else {
- fileName, ok1 := fileData["filename"].(string)
- fileDataStr, ok2 := fileData["file_data"].(string)
- if ok1 && ok2 {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeFile,
- File: &MessageFile{
- FileName: fileName,
- FileData: fileDataStr,
- },
- })
- }
- }
- }
- case ContentTypeVideoUrl:
- if videoUrl, ok := contentItem["video_url"].(string); ok {
- contentList = append(contentList, MediaContent{
- Type: ContentTypeVideoUrl,
- VideoUrl: &MessageVideoUrl{
- Url: videoUrl,
- },
- })
- }
- }
- }
- }
-
- if len(contentList) > 0 {
- m.parsedContent = contentList
- }
- return contentList
-}*/
-
-type WebSearchOptions struct {
- SearchContextSize string `json:"search_context_size,omitempty"`
- UserLocation json.RawMessage `json:"user_location,omitempty"`
-}
-
-// https://platform.openai.com/docs/api-reference/responses/create
-type OpenAIResponsesRequest struct {
- Model string `json:"model"`
- Input json.RawMessage `json:"input,omitempty"`
- Include json.RawMessage `json:"include,omitempty"`
- Instructions json.RawMessage `json:"instructions,omitempty"`
- MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
- Metadata json.RawMessage `json:"metadata,omitempty"`
- ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
- PreviousResponseID string `json:"previous_response_id,omitempty"`
- Reasoning *Reasoning `json:"reasoning,omitempty"`
- ServiceTier string `json:"service_tier,omitempty"`
- Store json.RawMessage `json:"store,omitempty"`
- PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- Text json.RawMessage `json:"text,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
- Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
- TopP float64 `json:"top_p,omitempty"`
- Truncation string `json:"truncation,omitempty"`
- User string `json:"user,omitempty"`
- MaxToolCalls uint `json:"max_tool_calls,omitempty"`
- Prompt json.RawMessage `json:"prompt,omitempty"`
-}
-
-func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var fileMeta = make([]*types.FileMeta, 0)
- var texts = make([]string, 0)
-
- if r.Input != nil {
- inputs := r.ParseInput()
- for _, input := range inputs {
- if input.Type == "input_image" {
- if input.ImageUrl != "" {
- fileMeta = append(fileMeta, &types.FileMeta{
- FileType: types.FileTypeImage,
- OriginData: input.ImageUrl,
- Detail: input.Detail,
- })
- }
- } else if input.Type == "input_file" {
- if input.FileUrl != "" {
- fileMeta = append(fileMeta, &types.FileMeta{
- FileType: types.FileTypeFile,
- OriginData: input.FileUrl,
- })
- }
- } else {
- texts = append(texts, input.Text)
- }
- }
- }
-
- if len(r.Instructions) > 0 {
- texts = append(texts, string(r.Instructions))
- }
-
- if len(r.Metadata) > 0 {
- texts = append(texts, string(r.Metadata))
- }
-
- if len(r.Text) > 0 {
- texts = append(texts, string(r.Text))
- }
-
- if len(r.ToolChoice) > 0 {
- texts = append(texts, string(r.ToolChoice))
- }
-
- if len(r.Prompt) > 0 {
- texts = append(texts, string(r.Prompt))
- }
-
- if len(r.Tools) > 0 {
- texts = append(texts, string(r.Tools))
- }
-
- return &types.TokenCountMeta{
- CombineText: strings.Join(texts, "\n"),
- Files: fileMeta,
- MaxTokens: int(r.MaxOutputTokens),
- }
-}
-
-func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
- return r.Stream
-}
-
-func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
- var toolsMap []map[string]any
- if len(r.Tools) > 0 {
- _ = common.Unmarshal(r.Tools, &toolsMap)
- }
- return toolsMap
-}
-
-type Reasoning struct {
- Effort string `json:"effort,omitempty"`
- Summary string `json:"summary,omitempty"`
-}
-
-type MediaInput struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- FileUrl string `json:"file_url,omitempty"`
- ImageUrl string `json:"image_url,omitempty"`
- Detail string `json:"detail,omitempty"` // 仅 input_image 有效
-}
-
-// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
-// Reference implementation mirrors Message.ParseContent:
-// - input can be a string, treated as an input_text item
-// - input can be an array of objects with a `type` field
-// supported types: input_text, input_image, input_file
-func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
- if r.Input == nil {
- return nil
- }
-
- var inputs []MediaInput
-
- // Try string first
- // if str, ok := common.GetJsonType(r.Input); ok {
- // inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
- // return inputs
- // }
- if common.GetJsonType(r.Input) == "string" {
- var str string
- _ = common.Unmarshal(r.Input, &str)
- inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
- return inputs
- }
-
- // Try array of parts
- if common.GetJsonType(r.Input) == "array" {
- var array []any
- _ = common.Unmarshal(r.Input, &array)
- for _, itemAny := range array {
- // Already parsed MediaInput
- if media, ok := itemAny.(MediaInput); ok {
- inputs = append(inputs, media)
- continue
- }
- // Generic map
- item, ok := itemAny.(map[string]any)
- if !ok {
- continue
- }
- typeVal, ok := item["type"].(string)
- if !ok {
- continue
- }
- switch typeVal {
- case "input_text":
- text, _ := item["text"].(string)
- inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
- case "input_image":
- // image_url may be string or object with url field
- var imageUrl string
- switch v := item["image_url"].(type) {
- case string:
- imageUrl = v
- case map[string]any:
- if url, ok := v["url"].(string); ok {
- imageUrl = url
- }
- }
- inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
- case "input_file":
- // file_url may be string or object with url field
- var fileUrl string
- switch v := item["file_url"].(type) {
- case string:
- fileUrl = v
- case map[string]any:
- if url, ok := v["url"].(string); ok {
- fileUrl = url
- }
- }
- inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
- }
- }
- }
-
- return inputs
-}
diff --git a/new-api/dto/openai_response.go b/new-api/dto/openai_response.go
deleted file mode 100644
index c152b3d7f45e8dc628403c403840886381082d69..0000000000000000000000000000000000000000
--- a/new-api/dto/openai_response.go
+++ /dev/null
@@ -1,398 +0,0 @@
-package dto
-
-import (
- "encoding/json"
- "fmt"
- "one-api/types"
-)
-
-const (
- ResponsesOutputTypeImageGenerationCall = "image_generation_call"
-)
-
-type SimpleResponse struct {
- Usage `json:"usage"`
- Error any `json:"error"`
-}
-
-// GetOpenAIError 从动态错误类型中提取OpenAIError结构
-func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
- return GetOpenAIError(s.Error)
-}
-
-type TextResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []OpenAITextResponseChoice `json:"choices"`
- Usage `json:"usage"`
-}
-
-type OpenAITextResponseChoice struct {
- Index int `json:"index"`
- Message `json:"message"`
- FinishReason string `json:"finish_reason"`
-}
-
-type OpenAITextResponse struct {
- Id string `json:"id"`
- Model string `json:"model"`
- Object string `json:"object"`
- Created any `json:"created"`
- Choices []OpenAITextResponseChoice `json:"choices"`
- Error any `json:"error,omitempty"`
- Usage `json:"usage"`
-}
-
-// GetOpenAIError 从动态错误类型中提取OpenAIError结构
-func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
- return GetOpenAIError(o.Error)
-}
-
-type OpenAIEmbeddingResponseItem struct {
- Object string `json:"object"`
- Index int `json:"index"`
- Embedding []float64 `json:"embedding"`
-}
-
-type OpenAIEmbeddingResponse struct {
- Object string `json:"object"`
- Data []OpenAIEmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
-}
-
-type FlexibleEmbeddingResponseItem struct {
- Object string `json:"object"`
- Index int `json:"index"`
- Embedding any `json:"embedding"`
-}
-
-type FlexibleEmbeddingResponse struct {
- Object string `json:"object"`
- Data []FlexibleEmbeddingResponseItem `json:"data"`
- Model string `json:"model"`
- Usage `json:"usage"`
-}
-
-type ChatCompletionsStreamResponseChoice struct {
- Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
- Logprobs *any `json:"logprobs"`
- FinishReason *string `json:"finish_reason"`
- Index int `json:"index"`
-}
-
-type ChatCompletionsStreamResponseChoiceDelta struct {
- Content *string `json:"content,omitempty"`
- ReasoningContent *string `json:"reasoning_content,omitempty"`
- Reasoning *string `json:"reasoning,omitempty"`
- Role string `json:"role,omitempty"`
- ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
-}
-
-func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
- c.Content = &s
-}
-
-func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
- if c.Content == nil {
- return ""
- }
- return *c.Content
-}
-
-func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
- if c.ReasoningContent == nil && c.Reasoning == nil {
- return ""
- }
- if c.ReasoningContent != nil {
- return *c.ReasoningContent
- }
- return *c.Reasoning
-}
-
-func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
- c.ReasoningContent = &s
- //c.Reasoning = &s
-}
-
-type ToolCallResponse struct {
- // Index is not nil only in chat completion chunk object
- Index *int `json:"index,omitempty"`
- ID string `json:"id,omitempty"`
- Type any `json:"type"`
- Function FunctionResponse `json:"function"`
-}
-
-func (c *ToolCallResponse) SetIndex(i int) {
- c.Index = &i
-}
-
-type FunctionResponse struct {
- Description string `json:"description,omitempty"`
- Name string `json:"name,omitempty"`
- // call function with arguments in JSON format
- Parameters any `json:"parameters,omitempty"` // request
- Arguments string `json:"arguments"` // response
-}
-
-type ChatCompletionsStreamResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- SystemFingerprint *string `json:"system_fingerprint"`
- Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
- Usage *Usage `json:"usage"`
-}
-
-func (c *ChatCompletionsStreamResponse) IsFinished() bool {
- if len(c.Choices) == 0 {
- return false
- }
- return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
-}
-
-func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
- if len(c.Choices) == 0 {
- return false
- }
- return len(c.Choices[0].Delta.ToolCalls) > 0
-}
-
-func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
- if c.IsToolCall() {
- return &c.Choices[0].Delta.ToolCalls[0]
- }
- return nil
-}
-
-func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
- if !c.IsToolCall() {
- return
- }
- for choiceIdx := range c.Choices {
- for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
- c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
- c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
- c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
- }
- }
-}
-
-func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
- choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
- copy(choices, c.Choices)
- return &ChatCompletionsStreamResponse{
- Id: c.Id,
- Object: c.Object,
- Created: c.Created,
- Model: c.Model,
- SystemFingerprint: c.SystemFingerprint,
- Choices: choices,
- Usage: c.Usage,
- }
-}
-
-func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
- if c.SystemFingerprint == nil {
- return ""
- }
- return *c.SystemFingerprint
-}
-
-func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
- c.SystemFingerprint = &s
-}
-
-type ChatCompletionsStreamResponseSimple struct {
- Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
- Usage *Usage `json:"usage"`
-}
-
-type CompletionsStreamResponse struct {
- Choices []struct {
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
- } `json:"choices"`
-}
-
-type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
-
- PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
- CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
- // OpenRouter Params
- Cost any `json:"cost,omitempty"`
-}
-
-type InputTokenDetails struct {
- CachedTokens int `json:"cached_tokens"`
- CachedCreationTokens int `json:"-"`
- TextTokens int `json:"text_tokens"`
- AudioTokens int `json:"audio_tokens"`
- ImageTokens int `json:"image_tokens"`
-}
-
-type OutputTokenDetails struct {
- TextTokens int `json:"text_tokens"`
- AudioTokens int `json:"audio_tokens"`
- ReasoningTokens int `json:"reasoning_tokens"`
-}
-
-type OpenAIResponsesResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- CreatedAt int `json:"created_at"`
- Status string `json:"status"`
- Error any `json:"error,omitempty"`
- IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
- Instructions string `json:"instructions"`
- MaxOutputTokens int `json:"max_output_tokens"`
- Model string `json:"model"`
- Output []ResponsesOutput `json:"output"`
- ParallelToolCalls bool `json:"parallel_tool_calls"`
- PreviousResponseID string `json:"previous_response_id"`
- Reasoning *Reasoning `json:"reasoning"`
- Store bool `json:"store"`
- Temperature float64 `json:"temperature"`
- ToolChoice string `json:"tool_choice"`
- Tools []map[string]any `json:"tools"`
- TopP float64 `json:"top_p"`
- Truncation string `json:"truncation"`
- Usage *Usage `json:"usage"`
- User json.RawMessage `json:"user"`
- Metadata json.RawMessage `json:"metadata"`
-}
-
-// GetOpenAIError 从动态错误类型中提取OpenAIError结构
-func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
- return GetOpenAIError(o.Error)
-}
-
-func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
- if len(o.Output) == 0 {
- return false
- }
- for _, output := range o.Output {
- if output.Type == ResponsesOutputTypeImageGenerationCall {
- return true
- }
- }
- return false
-}
-
-func (o *OpenAIResponsesResponse) GetQuality() string {
- if len(o.Output) == 0 {
- return ""
- }
- for _, output := range o.Output {
- if output.Type == ResponsesOutputTypeImageGenerationCall {
- return output.Quality
- }
- }
- return ""
-}
-
-func (o *OpenAIResponsesResponse) GetSize() string {
- if len(o.Output) == 0 {
- return ""
- }
- for _, output := range o.Output {
- if output.Type == ResponsesOutputTypeImageGenerationCall {
- return output.Size
- }
- }
- return ""
-}
-
-type IncompleteDetails struct {
- Reasoning string `json:"reasoning"`
-}
-
-type ResponsesOutput struct {
- Type string `json:"type"`
- ID string `json:"id"`
- Status string `json:"status"`
- Role string `json:"role"`
- Content []ResponsesOutputContent `json:"content"`
- Quality string `json:"quality"`
- Size string `json:"size"`
-}
-
-type ResponsesOutputContent struct {
- Type string `json:"type"`
- Text string `json:"text"`
- Annotations []interface{} `json:"annotations"`
-}
-
-const (
- BuildInToolWebSearchPreview = "web_search_preview"
- BuildInToolFileSearch = "file_search"
-)
-
-const (
- BuildInCallWebSearchCall = "web_search_call"
-)
-
-const (
- ResponsesOutputTypeItemAdded = "response.output_item.added"
- ResponsesOutputTypeItemDone = "response.output_item.done"
-)
-
-// ResponsesStreamResponse 用于处理 /v1/responses 流式响应
-type ResponsesStreamResponse struct {
- Type string `json:"type"`
- Response *OpenAIResponsesResponse `json:"response,omitempty"`
- Delta string `json:"delta,omitempty"`
- Item *ResponsesOutput `json:"item,omitempty"`
-}
-
-// GetOpenAIError 从动态错误类型中提取OpenAIError结构
-func GetOpenAIError(errorField any) *types.OpenAIError {
- if errorField == nil {
- return nil
- }
-
- switch err := errorField.(type) {
- case types.OpenAIError:
- return &err
- case *types.OpenAIError:
- return err
- case map[string]interface{}:
- // 处理从JSON解析来的map结构
- openaiErr := &types.OpenAIError{}
- if errType, ok := err["type"].(string); ok {
- openaiErr.Type = errType
- }
- if errMsg, ok := err["message"].(string); ok {
- openaiErr.Message = errMsg
- }
- if errParam, ok := err["param"].(string); ok {
- openaiErr.Param = errParam
- }
- if errCode, ok := err["code"]; ok {
- openaiErr.Code = errCode
- }
- return openaiErr
- case string:
- // 处理简单字符串错误
- return &types.OpenAIError{
- Type: "error",
- Message: err,
- }
- default:
- // 未知类型,尝试转换为字符串
- return &types.OpenAIError{
- Type: "unknown_error",
- Message: fmt.Sprintf("%v", err),
- }
- }
-}
diff --git a/new-api/dto/playground.go b/new-api/dto/playground.go
deleted file mode 100644
index 75f4fc6ff1a79abca7ffaea8a20d0324afa579ce..0000000000000000000000000000000000000000
--- a/new-api/dto/playground.go
+++ /dev/null
@@ -1,6 +0,0 @@
-package dto
-
-type PlayGroundRequest struct {
- Model string `json:"model,omitempty"`
- Group string `json:"group,omitempty"`
-}
diff --git a/new-api/dto/pricing.go b/new-api/dto/pricing.go
deleted file mode 100644
index 37e60c8b505ec0ce53e1d856893c04afd63e352c..0000000000000000000000000000000000000000
--- a/new-api/dto/pricing.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package dto
-
-import "one-api/constant"
-
-// 这里不好动就不动了,本来想独立出来的(
-type OpenAIModels struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int `json:"created"`
- OwnedBy string `json:"owned_by"`
- SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
-}
-
-type AnthropicModel struct {
- ID string `json:"id"`
- CreatedAt string `json:"created_at"`
- DisplayName string `json:"display_name"`
- Type string `json:"type"`
-}
-
-type GeminiModel struct {
- Name interface{} `json:"name"`
- BaseModelId interface{} `json:"baseModelId"`
- Version interface{} `json:"version"`
- DisplayName interface{} `json:"displayName"`
- Description interface{} `json:"description"`
- InputTokenLimit interface{} `json:"inputTokenLimit"`
- OutputTokenLimit interface{} `json:"outputTokenLimit"`
- SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
- Thinking interface{} `json:"thinking"`
- Temperature interface{} `json:"temperature"`
- MaxTemperature interface{} `json:"maxTemperature"`
- TopP interface{} `json:"topP"`
- TopK interface{} `json:"topK"`
-}
diff --git a/new-api/dto/ratio_sync.go b/new-api/dto/ratio_sync.go
deleted file mode 100644
index c475069ccf47baf8cc09a2d443111ff508f68f5d..0000000000000000000000000000000000000000
--- a/new-api/dto/ratio_sync.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package dto
-
-type UpstreamDTO struct {
- ID int `json:"id,omitempty"`
- Name string `json:"name" binding:"required"`
- BaseURL string `json:"base_url" binding:"required"`
- Endpoint string `json:"endpoint"`
-}
-
-type UpstreamRequest struct {
- ChannelIDs []int64 `json:"channel_ids"`
- Upstreams []UpstreamDTO `json:"upstreams"`
- Timeout int `json:"timeout"`
-}
-
-// TestResult 上游测试连通性结果
-type TestResult struct {
- Name string `json:"name"`
- Status string `json:"status"`
- Error string `json:"error,omitempty"`
-}
-
-// DifferenceItem 差异项
-// Current 为本地值,可能为 nil
-// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
-
-type DifferenceItem struct {
- Current interface{} `json:"current"`
- Upstreams map[string]interface{} `json:"upstreams"`
- Confidence map[string]bool `json:"confidence"`
-}
-
-type SyncableChannel struct {
- ID int `json:"id"`
- Name string `json:"name"`
- BaseURL string `json:"base_url"`
- Status int `json:"status"`
-}
diff --git a/new-api/dto/realtime.go b/new-api/dto/realtime.go
deleted file mode 100644
index cfc435c1839b1e3898f55ba822b587b181ccaedd..0000000000000000000000000000000000000000
--- a/new-api/dto/realtime.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package dto
-
-import "one-api/types"
-
-const (
- RealtimeEventTypeError = "error"
- RealtimeEventTypeSessionUpdate = "session.update"
- RealtimeEventTypeConversationCreate = "conversation.item.create"
- RealtimeEventTypeResponseCreate = "response.create"
- RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
-)
-
-const (
- RealtimeEventTypeResponseDone = "response.done"
- RealtimeEventTypeSessionUpdated = "session.updated"
- RealtimeEventTypeSessionCreated = "session.created"
- RealtimeEventResponseAudioDelta = "response.audio.delta"
- RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
- RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
- RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
- RealtimeEventConversationItemCreated = "conversation.item.created"
-)
-
-type RealtimeEvent struct {
- EventId string `json:"event_id"`
- Type string `json:"type"`
- //PreviousItemId string `json:"previous_item_id"`
- Session *RealtimeSession `json:"session,omitempty"`
- Item *RealtimeItem `json:"item,omitempty"`
- Error *types.OpenAIError `json:"error,omitempty"`
- Response *RealtimeResponse `json:"response,omitempty"`
- Delta string `json:"delta,omitempty"`
- Audio string `json:"audio,omitempty"`
-}
-
-type RealtimeResponse struct {
- Usage *RealtimeUsage `json:"usage"`
-}
-
-type RealtimeUsage struct {
- TotalTokens int `json:"total_tokens"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- InputTokenDetails InputTokenDetails `json:"input_token_details"`
- OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
-}
-
-type RealtimeSession struct {
- Modalities []string `json:"modalities"`
- Instructions string `json:"instructions"`
- Voice string `json:"voice"`
- InputAudioFormat string `json:"input_audio_format"`
- OutputAudioFormat string `json:"output_audio_format"`
- InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
- TurnDetection interface{} `json:"turn_detection"`
- Tools []RealTimeTool `json:"tools"`
- ToolChoice string `json:"tool_choice"`
- Temperature float64 `json:"temperature"`
- //MaxResponseOutputTokens int `json:"max_response_output_tokens"`
-}
-
-type InputAudioTranscription struct {
- Model string `json:"model"`
-}
-
-type RealTimeTool struct {
- Type string `json:"type"`
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters any `json:"parameters"`
-}
-
-type RealtimeItem struct {
- Id string `json:"id"`
- Type string `json:"type"`
- Status string `json:"status"`
- Role string `json:"role"`
- Content []RealtimeContent `json:"content"`
- Name *string `json:"name,omitempty"`
- ToolCalls any `json:"tool_calls,omitempty"`
- CallId string `json:"call_id,omitempty"`
-}
-type RealtimeContent struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
- Transcript string `json:"transcript,omitempty"`
-}
diff --git a/new-api/dto/request_common.go b/new-api/dto/request_common.go
deleted file mode 100644
index 39f92a6bfe0b95833340730d840673bc191923ca..0000000000000000000000000000000000000000
--- a/new-api/dto/request_common.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package dto
-
-import (
- "github.com/gin-gonic/gin"
- "one-api/types"
-)
-
-type Request interface {
- GetTokenCountMeta() *types.TokenCountMeta
- IsStream(c *gin.Context) bool
- SetModelName(modelName string)
-}
-
-type BaseRequest struct {
-}
-
-func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
- return &types.TokenCountMeta{
- TokenType: types.TokenTypeTokenizer,
- }
-}
-func (b *BaseRequest) IsStream(c *gin.Context) bool {
- return false
-}
-func (b *BaseRequest) SetModelName(modelName string) {}
diff --git a/new-api/dto/rerank.go b/new-api/dto/rerank.go
deleted file mode 100644
index fe11579d5083bf71ecb61d7724701a753ab6533a..0000000000000000000000000000000000000000
--- a/new-api/dto/rerank.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package dto
-
-import (
- "fmt"
- "github.com/gin-gonic/gin"
- "one-api/types"
- "strings"
-)
-
-type RerankRequest struct {
- Documents []any `json:"documents"`
- Query string `json:"query"`
- Model string `json:"model"`
- TopN int `json:"top_n,omitempty"`
- ReturnDocuments *bool `json:"return_documents,omitempty"`
- MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
- OverLapTokens int `json:"overlap_tokens,omitempty"`
-}
-
-func (r *RerankRequest) IsStream(c *gin.Context) bool {
- return false
-}
-
-func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
- var texts = make([]string, 0)
-
- for _, document := range r.Documents {
- texts = append(texts, fmt.Sprintf("%v", document))
- }
-
- if r.Query != "" {
- texts = append(texts, r.Query)
- }
-
- return &types.TokenCountMeta{
- CombineText: strings.Join(texts, "\n"),
- }
-}
-
-func (r *RerankRequest) SetModelName(modelName string) {
- if modelName != "" {
- r.Model = modelName
- }
-}
-
-func (r *RerankRequest) GetReturnDocuments() bool {
- if r.ReturnDocuments == nil {
- return false
- }
- return *r.ReturnDocuments
-}
-
-type RerankResponseResult struct {
- Document any `json:"document,omitempty"`
- Index int `json:"index"`
- RelevanceScore float64 `json:"relevance_score"`
-}
-
-type RerankDocument struct {
- Text any `json:"text"`
-}
-
-type RerankResponse struct {
- Results []RerankResponseResult `json:"results"`
- Usage Usage `json:"usage"`
-}
diff --git a/new-api/dto/sensitive.go b/new-api/dto/sensitive.go
deleted file mode 100644
index 8b2956d0133ac9f3e03d743a338a8a0f248310b4..0000000000000000000000000000000000000000
--- a/new-api/dto/sensitive.go
+++ /dev/null
@@ -1,6 +0,0 @@
-package dto
-
-type SensitiveResponse struct {
- SensitiveWords []string `json:"sensitive_words"`
- Content string `json:"content"`
-}
diff --git a/new-api/dto/suno.go b/new-api/dto/suno.go
deleted file mode 100644
index 54e1b9ebc8ffac7f30afdbd100d5ba79d22e1f58..0000000000000000000000000000000000000000
--- a/new-api/dto/suno.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package dto
-
-import (
- "encoding/json"
-)
-
-type TaskData interface {
- SunoDataResponse | []SunoDataResponse | string | any
-}
-
-type SunoSubmitReq struct {
- GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
- Prompt string `json:"prompt,omitempty"`
- Mv string `json:"mv,omitempty"`
- Title string `json:"title,omitempty"`
- Tags string `json:"tags,omitempty"`
- ContinueAt float64 `json:"continue_at,omitempty"`
- TaskID string `json:"task_id,omitempty"`
- ContinueClipId string `json:"continue_clip_id,omitempty"`
- MakeInstrumental bool `json:"make_instrumental"`
-}
-
-type FetchReq struct {
- IDs []string `json:"ids"`
-}
-
-type SunoDataResponse struct {
- TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
- Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
- Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
- FailReason string `json:"fail_reason"`
- SubmitTime int64 `json:"submit_time" gorm:"index"`
- StartTime int64 `json:"start_time" gorm:"index"`
- FinishTime int64 `json:"finish_time" gorm:"index"`
- Data json.RawMessage `json:"data" gorm:"type:json"`
-}
-
-type SunoSong struct {
- ID string `json:"id"`
- VideoURL string `json:"video_url"`
- AudioURL string `json:"audio_url"`
- ImageURL string `json:"image_url"`
- ImageLargeURL string `json:"image_large_url"`
- MajorModelVersion string `json:"major_model_version"`
- ModelName string `json:"model_name"`
- Status string `json:"status"`
- Title string `json:"title"`
- Text string `json:"text"`
- Metadata SunoMetadata `json:"metadata"`
-}
-
-type SunoMetadata struct {
- Tags string `json:"tags"`
- Prompt string `json:"prompt"`
- GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
- AudioPromptID interface{} `json:"audio_prompt_id"`
- Duration interface{} `json:"duration"`
- ErrorType interface{} `json:"error_type"`
- ErrorMessage interface{} `json:"error_message"`
-}
-
-type SunoLyrics struct {
- ID string `json:"id"`
- Status string `json:"status"`
- Title string `json:"title"`
- Text string `json:"text"`
-}
-
-const TaskSuccessCode = "success"
-
-type TaskResponse[T TaskData] struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Data T `json:"data"`
-}
-
-func (t *TaskResponse[T]) IsSuccess() bool {
- return t.Code == TaskSuccessCode
-}
-
-type TaskDto struct {
- TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
- Action string `json:"action"` // 任务类型, song, lyrics, description-mode
- Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
- FailReason string `json:"fail_reason"`
- SubmitTime int64 `json:"submit_time"`
- StartTime int64 `json:"start_time"`
- FinishTime int64 `json:"finish_time"`
- Progress string `json:"progress"`
- Data json.RawMessage `json:"data"`
-}
-
-type SunoGoAPISubmitReq struct {
- CustomMode bool `json:"custom_mode"`
-
- Input SunoGoAPISubmitReqInput `json:"input"`
-
- NotifyHook string `json:"notify_hook,omitempty"`
-}
-
-type SunoGoAPISubmitReqInput struct {
- GptDescriptionPrompt string `json:"gpt_description_prompt"`
- Prompt string `json:"prompt"`
- Mv string `json:"mv"`
- Title string `json:"title"`
- Tags string `json:"tags"`
- ContinueAt float64 `json:"continue_at"`
- TaskID string `json:"task_id"`
- ContinueClipId string `json:"continue_clip_id"`
- MakeInstrumental bool `json:"make_instrumental"`
-}
-
-type GoAPITaskResponse[T any] struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data T `json:"data"`
- ErrorMessage string `json:"error_message,omitempty"`
-}
-
-type GoAPITaskResponseData struct {
- TaskID string `json:"task_id"`
-}
-
-type GoAPIFetchResponseData struct {
- TaskID string `json:"task_id"`
- Status string `json:"status"`
- Input string `json:"input"`
- Clips map[string]SunoSong `json:"clips"`
-}
diff --git a/new-api/dto/task.go b/new-api/dto/task.go
deleted file mode 100644
index 3300f4b3e50086c5ba0d3aa105a6e087d256094c..0000000000000000000000000000000000000000
--- a/new-api/dto/task.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package dto
-
-type TaskError struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Data any `json:"data"`
- StatusCode int `json:"-"`
- LocalError bool `json:"-"`
- Error error `json:"-"`
-}
diff --git a/new-api/dto/user_settings.go b/new-api/dto/user_settings.go
deleted file mode 100644
index b4a1eee9c875319bf83a720b5cbf9998a0c91e77..0000000000000000000000000000000000000000
--- a/new-api/dto/user_settings.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package dto
-
-type UserSetting struct {
- NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
- QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
- WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
- WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
- NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
- BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
- AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
- RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
- SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
-}
-
-var (
- NotifyTypeEmail = "email" // Email 邮件
- NotifyTypeWebhook = "webhook" // Webhook
- NotifyTypeBark = "bark" // Bark 推送
-)
diff --git a/new-api/dto/video.go b/new-api/dto/video.go
deleted file mode 100644
index 367feec5d742b939cd22b12f1af10abe12ce6e8a..0000000000000000000000000000000000000000
--- a/new-api/dto/video.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package dto
-
-type VideoRequest struct {
- Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
- Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
- Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
- Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
- Width int `json:"width" example:"512"` // Video width
- Height int `json:"height" example:"512"` // Video height
- Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
- Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
- N int `json:"n,omitempty" example:"1"` // Number of videos to generate
- ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
- User string `json:"user,omitempty" example:"user-1234"` // User identifier
- Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
-}
-
-// VideoResponse 视频生成提交任务后的响应
-type VideoResponse struct {
- TaskId string `json:"task_id"`
- Status string `json:"status"`
-}
-
-// VideoTaskResponse 查询视频生成任务状态的响应
-type VideoTaskResponse struct {
- TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
- Status string `json:"status" example:"succeeded"` // 任务状态
- Url string `json:"url,omitempty"` // 视频资源URL(成功时)
- Format string `json:"format,omitempty" example:"mp4"` // 视频格式
- Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
- Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
-}
-
-// VideoTaskMetadata 视频任务元数据
-type VideoTaskMetadata struct {
- Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
- Fps int `json:"fps" example:"30"` // 实际帧率
- Width int `json:"width" example:"512"` // 实际宽度
- Height int `json:"height" example:"512"` // 实际高度
- Seed int `json:"seed" example:"20231234"` // 使用的随机种子
-}
-
-// VideoTaskError 视频任务错误信息
-type VideoTaskError struct {
- Code int `json:"code"`
- Message string `json:"message"`
-}
diff --git a/new-api/go.mod b/new-api/go.mod
deleted file mode 100644
index a3bdd2fb44d2ff7aace5c82fd18f3a561672e8f4..0000000000000000000000000000000000000000
--- a/new-api/go.mod
+++ /dev/null
@@ -1,113 +0,0 @@
-module one-api
-
-// +heroku goVersion go1.18
-go 1.24.0
-
-toolchain go1.24.6
-
-require (
- github.com/Calcium-Ion/go-epay v0.0.4
- github.com/andybalholm/brotli v1.1.1
- github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
- github.com/aws/aws-sdk-go-v2 v1.37.2
- github.com/aws/aws-sdk-go-v2/credentials v1.17.11
- github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
- github.com/aws/smithy-go v1.22.5
- github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
- github.com/gin-contrib/cors v1.7.2
- github.com/gin-contrib/gzip v0.0.6
- github.com/gin-contrib/sessions v0.0.5
- github.com/gin-contrib/static v0.0.1
- github.com/gin-gonic/gin v1.9.1
- github.com/glebarez/sqlite v1.9.0
- github.com/go-playground/validator/v10 v10.20.0
- github.com/go-redis/redis/v8 v8.11.5
- github.com/go-webauthn/webauthn v0.14.0
- github.com/golang-jwt/jwt v3.2.2+incompatible
- github.com/google/uuid v1.6.0
- github.com/gorilla/websocket v1.5.0
- github.com/jinzhu/copier v0.4.0
- github.com/joho/godotenv v1.5.1
- github.com/pkg/errors v0.9.1
- github.com/pquerna/otp v1.5.0
- github.com/samber/lo v1.39.0
- github.com/shirou/gopsutil v3.21.11+incompatible
- github.com/shopspring/decimal v1.4.0
- github.com/stripe/stripe-go/v81 v81.4.0
- github.com/thanhpk/randstr v1.0.6
- github.com/tidwall/gjson v1.18.0
- github.com/tidwall/sjson v1.2.5
- github.com/tiktoken-go/tokenizer v0.6.2
- golang.org/x/crypto v0.42.0
- golang.org/x/image v0.23.0
- golang.org/x/net v0.43.0
- golang.org/x/sync v0.17.0
- gorm.io/driver/mysql v1.4.3
- gorm.io/driver/postgres v1.5.2
- gorm.io/gorm v1.25.2
-)
-
-require (
- github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
- github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
- github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
- github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
- github.com/boombuler/barcode v1.1.0 // indirect
- github.com/bytedance/sonic v1.11.6 // indirect
- github.com/bytedance/sonic/loader v0.1.1 // indirect
- github.com/cespare/xxhash/v2 v2.3.0 // indirect
- github.com/cloudwego/base64x v0.1.4 // indirect
- github.com/cloudwego/iasm v0.2.0 // indirect
- github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
- github.com/dlclark/regexp2 v1.11.5 // indirect
- github.com/dustin/go-humanize v1.0.1 // indirect
- github.com/fxamacker/cbor/v2 v2.9.0 // indirect
- github.com/gabriel-vasile/mimetype v1.4.3 // indirect
- github.com/gin-contrib/sse v0.1.0 // indirect
- github.com/glebarez/go-sqlite v1.21.2 // indirect
- github.com/go-ole/go-ole v1.2.6 // indirect
- github.com/go-playground/locales v0.14.1 // indirect
- github.com/go-playground/universal-translator v0.18.1 // indirect
- github.com/go-sql-driver/mysql v1.7.0 // indirect
- github.com/go-webauthn/x v0.1.25 // indirect
- github.com/goccy/go-json v0.10.2 // indirect
- github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
- github.com/google/go-cmp v0.6.0 // indirect
- github.com/google/go-tpm v0.9.5 // indirect
- github.com/gorilla/context v1.1.1 // indirect
- github.com/gorilla/securecookie v1.1.1 // indirect
- github.com/gorilla/sessions v1.2.1 // indirect
- github.com/jackc/pgpassfile v1.0.0 // indirect
- github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
- github.com/jackc/pgx/v5 v5.7.1 // indirect
- github.com/jackc/puddle/v2 v2.2.2 // indirect
- github.com/jinzhu/inflection v1.0.0 // indirect
- github.com/jinzhu/now v1.1.5 // indirect
- github.com/json-iterator/go v1.1.12 // indirect
- github.com/klauspost/cpuid/v2 v2.2.9 // indirect
- github.com/leodido/go-urn v1.4.0 // indirect
- github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/mitchellh/mapstructure v1.5.0 // indirect
- github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
- github.com/modern-go/reflect2 v1.0.2 // indirect
- github.com/pelletier/go-toml/v2 v2.2.1 // indirect
- github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
- github.com/tidwall/match v1.1.1 // indirect
- github.com/tidwall/pretty v1.2.0 // indirect
- github.com/tklauser/go-sysconf v0.3.12 // indirect
- github.com/tklauser/numcpus v0.6.1 // indirect
- github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
- github.com/ugorji/go/codec v1.2.12 // indirect
- github.com/x448/float16 v0.8.4 // indirect
- github.com/yusufpapurcu/wmi v1.2.3 // indirect
- golang.org/x/arch v0.12.0 // indirect
- golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
- golang.org/x/sys v0.36.0 // indirect
- golang.org/x/text v0.29.0 // indirect
- google.golang.org/protobuf v1.34.2 // indirect
- gopkg.in/yaml.v3 v3.0.1 // indirect
- modernc.org/libc v1.22.5 // indirect
- modernc.org/mathutil v1.5.0 // indirect
- modernc.org/memory v1.5.0 // indirect
- modernc.org/sqlite v1.23.1 // indirect
-)
diff --git a/new-api/go.sum b/new-api/go.sum
deleted file mode 100644
index 9780e5c773d0379a5c8327c89209e9274e2f13ff..0000000000000000000000000000000000000000
--- a/new-api/go.sum
+++ /dev/null
@@ -1,324 +0,0 @@
-github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
-github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
-github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
-github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
-github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
-github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
-github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
-github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
-github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
-github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
-github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
-github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
-github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
-github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
-github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
-github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
-github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
-github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
-github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
-github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
-github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
-github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
-github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
-github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
-github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
-github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
-github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
-github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
-github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
-github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
-github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
-github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
-github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
-github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
-github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
-github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
-github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
-github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
-github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
-github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
-github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
-github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
-github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
-github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
-github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
-github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
-github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
-github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U=
-github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
-github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
-github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
-github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
-github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
-github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
-github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
-github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
-github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
-github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
-github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
-github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
-github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
-github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
-github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
-github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
-github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
-github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
-github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
-github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
-github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
-github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
-github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
-github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
-github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
-github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
-github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
-github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
-github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
-github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
-github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
-github.com/go-webauthn/webauthn v0.14.0 h1:ZLNPUgPcDlAeoxe+5umWG/tEeCoQIDr7gE2Zx2QnhL0=
-github.com/go-webauthn/webauthn v0.14.0/go.mod h1:QZzPFH3LJ48u5uEPAu+8/nWJImoLBWM7iAH/kSVSo6k=
-github.com/go-webauthn/x v0.1.25 h1:g/0noooIGcz/yCVqebcFgNnGIgBlJIccS+LYAa+0Z88=
-github.com/go-webauthn/x v0.1.25/go.mod h1:ieblaPY1/BVCV0oQTsA/VAo08/TWayQuJuo5Q+XxmTY=
-github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
-github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
-github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
-github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
-github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
-github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
-github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
-github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
-github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
-github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
-github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
-github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
-github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
-github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
-github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
-github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
-github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
-github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
-github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
-github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
-github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
-github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
-github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
-github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
-github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
-github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
-github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
-github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
-github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
-github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
-github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
-github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
-github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
-github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
-github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
-github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
-github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
-github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
-github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
-github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
-github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
-github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
-github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
-github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
-github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
-github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
-github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
-github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
-github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
-github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
-github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
-github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
-github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
-github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
-github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
-github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
-github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
-github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
-github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
-github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
-github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
-github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
-github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
-github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
-github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
-github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
-github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
-github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
-github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
-github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
-github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
-github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
-github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
-github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
-github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
-github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
-github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
-github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
-github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
-github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
-github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
-github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
-github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
-github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
-github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
-github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
-github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
-github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
-github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
-github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
-github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
-github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
-github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
-github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
-github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
-github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
-github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
-github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
-github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
-github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
-github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
-github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
-github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
-github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
-github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
-github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
-go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
-go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
-golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
-golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
-golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
-golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
-golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
-golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
-golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
-golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
-golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
-golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
-golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
-golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
-golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
-golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
-golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
-google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
-gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
-gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
-gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
-gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
-gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
-gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
-gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
-gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
-modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
-modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
-modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
-modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
-modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
-modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
-modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
-modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
-nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
-rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/new-api/logger/logger.go b/new-api/logger/logger.go
deleted file mode 100644
index 42860b60514b02870abdd76ce75dbf5a25433381..0000000000000000000000000000000000000000
--- a/new-api/logger/logger.go
+++ /dev/null
@@ -1,118 +0,0 @@
-package logger
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "log"
- "one-api/common"
- "os"
- "path/filepath"
- "sync"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
-)
-
-const (
- loggerINFO = "INFO"
- loggerWarn = "WARN"
- loggerError = "ERR"
- loggerDebug = "DEBUG"
-)
-
-const maxLogCount = 1000000
-
-var logCount int
-var setupLogLock sync.Mutex
-var setupLogWorking bool
-
-func SetupLogger() {
- defer func() {
- setupLogWorking = false
- }()
- if *common.LogDir != "" {
- ok := setupLogLock.TryLock()
- if !ok {
- log.Println("setup log is already working")
- return
- }
- defer func() {
- setupLogLock.Unlock()
- }()
- logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
- fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
- if err != nil {
- log.Fatal("failed to open log file")
- }
- gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
- gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
- }
-}
-
-func LogInfo(ctx context.Context, msg string) {
- logHelper(ctx, loggerINFO, msg)
-}
-
-func LogWarn(ctx context.Context, msg string) {
- logHelper(ctx, loggerWarn, msg)
-}
-
-func LogError(ctx context.Context, msg string) {
- logHelper(ctx, loggerError, msg)
-}
-
-func LogDebug(ctx context.Context, msg string) {
- if common.DebugEnabled {
- logHelper(ctx, loggerDebug, msg)
- }
-}
-
-func logHelper(ctx context.Context, level string, msg string) {
- writer := gin.DefaultErrorWriter
- if level == loggerINFO {
- writer = gin.DefaultWriter
- }
- id := ctx.Value(common.RequestIdKey)
- if id == nil {
- id = "SYSTEM"
- }
- now := time.Now()
- _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
- logCount++ // we don't need accurate count, so no lock here
- if logCount > maxLogCount && !setupLogWorking {
- logCount = 0
- setupLogWorking = true
- gopool.Go(func() {
- SetupLogger()
- })
- }
-}
-
-func LogQuota(quota int) string {
- if common.DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d 点额度", quota)
- }
-}
-
-func FormatQuota(quota int) string {
- if common.DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
- } else {
- return fmt.Sprintf("%d", quota)
- }
-}
-
-// LogJson 仅供测试使用 only for test
-func LogJson(ctx context.Context, msg string, obj any) {
- jsonStr, err := json.Marshal(obj)
- if err != nil {
- LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
- return
- }
- LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
-}
diff --git a/new-api/main.go b/new-api/main.go
deleted file mode 100644
index 8404f1ac9da2213580ad40ae1331217d9a95ec27..0000000000000000000000000000000000000000
--- a/new-api/main.go
+++ /dev/null
@@ -1,232 +0,0 @@
-package main
-
-import (
- "bytes"
- "embed"
- "fmt"
- "log"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/controller"
- "one-api/logger"
- "one-api/middleware"
- "one-api/model"
- "one-api/router"
- "one-api/service"
- "one-api/setting/ratio_setting"
- "os"
- "strconv"
- "strings"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-contrib/sessions"
- "github.com/gin-contrib/sessions/cookie"
- "github.com/gin-gonic/gin"
- "github.com/joho/godotenv"
-
- _ "net/http/pprof"
-)
-
-//go:embed web/dist
-var buildFS embed.FS
-
-//go:embed web/dist/index.html
-var indexPage []byte
-
-func main() {
- startTime := time.Now()
-
- err := InitResources()
- if err != nil {
- common.FatalLog("failed to initialize resources: " + err.Error())
- return
- }
-
- common.SysLog("New API " + common.Version + " started")
- if os.Getenv("GIN_MODE") != "debug" {
- gin.SetMode(gin.ReleaseMode)
- }
- if common.DebugEnabled {
- common.SysLog("running in debug mode")
- }
-
- defer func() {
- err := model.CloseDB()
- if err != nil {
- common.FatalLog("failed to close database: " + err.Error())
- }
- }()
-
- if common.RedisEnabled {
- // for compatibility with old versions
- common.MemoryCacheEnabled = true
- }
- if common.MemoryCacheEnabled {
- common.SysLog("memory cache enabled")
- common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
-
- // Add panic recovery and retry for InitChannelCache
- func() {
- defer func() {
- if r := recover(); r != nil {
- common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
- // Retry once
- _, _, fixErr := model.FixAbility()
- if fixErr != nil {
- common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
- }
- }
- }()
- model.InitChannelCache()
- }()
-
- go model.SyncChannelCache(common.SyncFrequency)
- }
-
- // 热更新配置
- go model.SyncOptions(common.SyncFrequency)
-
- // 数据看板
- go model.UpdateQuotaData()
-
- if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
- frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
- if err != nil {
- common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
- }
- go controller.AutomaticallyUpdateChannels(frequency)
- }
-
- go controller.AutomaticallyTestChannels()
-
- if common.IsMasterNode && constant.UpdateTask {
- gopool.Go(func() {
- controller.UpdateMidjourneyTaskBulk()
- })
- gopool.Go(func() {
- controller.UpdateTaskBulk()
- })
- }
- if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
- common.BatchUpdateEnabled = true
- common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
- model.InitBatchUpdater()
- }
-
- if os.Getenv("ENABLE_PPROF") == "true" {
- gopool.Go(func() {
- log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
- })
- go common.Monitor()
- common.SysLog("pprof enabled")
- }
-
- // Initialize HTTP server
- server := gin.New()
- server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
- common.SysLog(fmt.Sprintf("panic detected: %v", err))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
- "type": "new_api_panic",
- },
- })
- }))
- // This will cause SSE not to work!!!
- //server.Use(gzip.Gzip(gzip.DefaultCompression))
- server.Use(middleware.RequestId())
- middleware.SetUpLogger(server)
- // Initialize session store
- store := cookie.NewStore([]byte(common.SessionSecret))
- store.Options(sessions.Options{
- Path: "/",
- MaxAge: 2592000, // 30 days
- HttpOnly: true,
- Secure: false,
- SameSite: http.SameSiteStrictMode,
- })
- server.Use(sessions.Sessions("session", store))
-
- analyticsInjectBuilder := &strings.Builder{}
- if os.Getenv("UMAMI_WEBSITE_ID") != "" {
- umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
- umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL")
- if umamiScriptURL == "" {
- umamiScriptURL = "https://analytics.umami.is/script.js"
- }
- analyticsInjectBuilder.WriteString("")
- }
- analyticsInject := analyticsInjectBuilder.String()
- indexPage = bytes.ReplaceAll(indexPage, []byte("\n"), []byte(analyticsInject))
-
- router.SetRouter(server, buildFS, indexPage)
- var port = os.Getenv("PORT")
- if port == "" {
- port = strconv.Itoa(*common.Port)
- }
-
- // Log startup success message
- common.LogStartupSuccess(startTime, port)
-
- err = server.Run(":" + port)
- if err != nil {
- common.FatalLog("failed to start HTTP server: " + err.Error())
- }
-}
-
-func InitResources() error {
- // Initialize resources here if needed
- // This is a placeholder function for future resource initialization
- err := godotenv.Load(".env")
- if err != nil {
- if common.DebugEnabled {
- common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
- }
- }
-
- // 加载环境变量
- common.InitEnv()
-
- logger.SetupLogger()
-
- // Initialize model settings
- ratio_setting.InitRatioSettings()
-
- service.InitHttpClient()
-
- service.InitTokenEncoders()
-
- // Initialize SQL Database
- err = model.InitDB()
- if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
- return err
- }
-
- model.CheckSetup()
-
- // Initialize options, should after model.InitDB()
- model.InitOptionMap()
-
- // 初始化模型
- model.GetPricing()
-
- // Initialize SQL Database
- err = model.InitLogDB()
- if err != nil {
- return err
- }
-
- // Initialize Redis
- err = common.InitRedisClient()
- if err != nil {
- return err
- }
- return nil
-}
diff --git a/new-api/makefile b/new-api/makefile
deleted file mode 100644
index 91555ebbc673a8feab44c360145e4f857df029c9..0000000000000000000000000000000000000000
--- a/new-api/makefile
+++ /dev/null
@@ -1,14 +0,0 @@
-FRONTEND_DIR = ./web
-BACKEND_DIR = .
-
-.PHONY: all build-frontend start-backend
-
-all: build-frontend start-backend
-
-build-frontend:
- @echo "Building frontend..."
- @cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
-
-start-backend:
- @echo "Starting backend dev server..."
- @cd $(BACKEND_DIR) && go run main.go &
diff --git a/new-api/middleware/auth.go b/new-api/middleware/auth.go
deleted file mode 100644
index 58404323e28e819bc8578c8e877add1a4fe642ef..0000000000000000000000000000000000000000
--- a/new-api/middleware/auth.go
+++ /dev/null
@@ -1,319 +0,0 @@
-package middleware
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/model"
- "one-api/setting"
- "one-api/setting/ratio_setting"
- "strconv"
- "strings"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-func validUserInfo(username string, role int) bool {
- // check username is empty
- if strings.TrimSpace(username) == "" {
- return false
- }
- if !common.IsValidateRole(role) {
- return false
- }
- return true
-}
-
-func authHelper(c *gin.Context, minRole int) {
- session := sessions.Default(c)
- username := session.Get("username")
- role := session.Get("role")
- id := session.Get("id")
- status := session.Get("status")
- useAccessToken := false
- if username == nil {
- // Check access token
- accessToken := c.Request.Header.Get("Authorization")
- if accessToken == "" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "无权进行此操作,未登录且未提供 access token",
- })
- c.Abort()
- return
- }
- user := model.ValidateAccessToken(accessToken)
- if user != nil && user.Username != "" {
- if !validUserInfo(user.Username, user.Role) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,用户信息无效",
- })
- c.Abort()
- return
- }
- // Token is valid
- username = user.Username
- role = user.Role
- id = user.Id
- status = user.Status
- useAccessToken = true
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,access token 无效",
- })
- c.Abort()
- return
- }
- }
- // get header New-Api-User
- apiUserIdStr := c.Request.Header.Get("New-Api-User")
- if apiUserIdStr == "" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "无权进行此操作,未提供 New-Api-User",
- })
- c.Abort()
- return
- }
- apiUserId, err := strconv.Atoi(apiUserIdStr)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "无权进行此操作,New-Api-User 格式错误",
- })
- c.Abort()
- return
-
- }
- if id != apiUserId {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "无权进行此操作,New-Api-User 与登录用户不匹配",
- })
- c.Abort()
- return
- }
- if status.(int) == common.UserStatusDisabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已被封禁",
- })
- c.Abort()
- return
- }
- if role.(int) < minRole {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,权限不足",
- })
- c.Abort()
- return
- }
- if !validUserInfo(username.(string), role.(int)) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,用户信息无效",
- })
- c.Abort()
- return
- }
- c.Set("username", username)
- c.Set("role", role)
- c.Set("id", id)
- c.Set("group", session.Get("group"))
- c.Set("user_group", session.Get("group"))
- c.Set("use_access_token", useAccessToken)
-
- //userCache, err := model.GetUserCache(id.(int))
- //if err != nil {
- // c.JSON(http.StatusOK, gin.H{
- // "success": false,
- // "message": err.Error(),
- // })
- // c.Abort()
- // return
- //}
- //userCache.WriteContext(c)
-
- c.Next()
-}
-
-func TryUserAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- session := sessions.Default(c)
- id := session.Get("id")
- if id != nil {
- c.Set("id", id)
- }
- c.Next()
- }
-}
-
-func UserAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleCommonUser)
- }
-}
-
-func AdminAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleAdminUser)
- }
-}
-
-func RootAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleRootUser)
- }
-}
-
-func WssAuth(c *gin.Context) {
-
-}
-
-func TokenAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- // 先检测是否为ws
- if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
- // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
- // read sk from Sec-WebSocket-Protocol
- key := c.Request.Header.Get("Sec-WebSocket-Protocol")
- parts := strings.Split(key, ",")
- for _, part := range parts {
- part = strings.TrimSpace(part)
- if strings.HasPrefix(part, "openai-insecure-api-key") {
- key = strings.TrimPrefix(part, "openai-insecure-api-key.")
- break
- }
- }
- c.Request.Header.Set("Authorization", "Bearer "+key)
- }
- // 检查path包含/v1/messages
- if strings.Contains(c.Request.URL.Path, "/v1/messages") {
- anthropicKey := c.Request.Header.Get("x-api-key")
- if anthropicKey != "" {
- c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
- }
- }
- // gemini api 从query中获取key
- if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
- strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
- strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
- skKey := c.Query("key")
- if skKey != "" {
- c.Request.Header.Set("Authorization", "Bearer "+skKey)
- }
- // 从x-goog-api-key header中获取key
- xGoogKey := c.Request.Header.Get("x-goog-api-key")
- if xGoogKey != "" {
- c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
- }
- }
- key := c.Request.Header.Get("Authorization")
- parts := make([]string, 0)
- key = strings.TrimPrefix(key, "Bearer ")
- if key == "" || key == "midjourney-proxy" {
- key = c.Request.Header.Get("mj-api-secret")
- key = strings.TrimPrefix(key, "Bearer ")
- key = strings.TrimPrefix(key, "sk-")
- parts = strings.Split(key, "-")
- key = parts[0]
- } else {
- key = strings.TrimPrefix(key, "sk-")
- parts = strings.Split(key, "-")
- key = parts[0]
- }
- token, err := model.ValidateUserToken(key)
- if token != nil {
- id := c.GetInt("id")
- if id == 0 {
- c.Set("id", token.UserId)
- }
- }
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
- return
- }
-
- allowIpsMap := token.GetIpLimitsMap()
- if len(allowIpsMap) != 0 {
- clientIp := c.ClientIP()
- if _, ok := allowIpsMap[clientIp]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
- return
- }
- }
-
- userCache, err := model.GetUserCache(token.UserId)
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
- return
- }
- userEnabled := userCache.Status == common.UserStatusEnabled
- if !userEnabled {
- abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
- return
- }
-
- userCache.WriteContext(c)
-
- userGroup := userCache.Group
- tokenGroup := token.Group
- if tokenGroup != "" {
- // check common.UserUsableGroups[userGroup]
- if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
- return
- }
- // check group in common.GroupRatio
- if !ratio_setting.ContainsGroupRatio(tokenGroup) {
- if tokenGroup != "auto" {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
- }
- }
- userGroup = tokenGroup
- }
- common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
-
- err = SetupContextForToken(c, token, parts...)
- if err != nil {
- return
- }
- c.Next()
- }
-}
-
-func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
- if token == nil {
- return fmt.Errorf("token is nil")
- }
- c.Set("id", token.UserId)
- c.Set("token_id", token.Id)
- c.Set("token_key", token.Key)
- c.Set("token_name", token.Name)
- c.Set("token_unlimited_quota", token.UnlimitedQuota)
- if !token.UnlimitedQuota {
- c.Set("token_quota", token.RemainQuota)
- }
- if token.ModelLimitsEnabled {
- c.Set("token_model_limit_enabled", true)
- c.Set("token_model_limit", token.GetModelLimitsMap())
- } else {
- c.Set("token_model_limit_enabled", false)
- }
- c.Set("token_group", token.Group)
- if len(parts) > 1 {
- if model.IsAdmin(token.UserId) {
- c.Set("specific_channel_id", parts[1])
- } else {
- abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
- return fmt.Errorf("普通用户不支持指定渠道")
- }
- }
- return nil
-}
diff --git a/new-api/middleware/cache.go b/new-api/middleware/cache.go
deleted file mode 100644
index 8899b12369b2c9210d83e89f74ef37000c4b0afe..0000000000000000000000000000000000000000
--- a/new-api/middleware/cache.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package middleware
-
-import (
- "github.com/gin-gonic/gin"
-)
-
-func Cache() func(c *gin.Context) {
- return func(c *gin.Context) {
- if c.Request.RequestURI == "/" {
- c.Header("Cache-Control", "no-cache")
- } else {
- c.Header("Cache-Control", "max-age=604800") // one week
- }
- c.Next()
- }
-}
diff --git a/new-api/middleware/cors.go b/new-api/middleware/cors.go
deleted file mode 100644
index 81f79a4bdbaa3f3754f541ecd545f096c9b99896..0000000000000000000000000000000000000000
--- a/new-api/middleware/cors.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package middleware
-
-import (
- "github.com/gin-contrib/cors"
- "github.com/gin-gonic/gin"
-)
-
-func CORS() gin.HandlerFunc {
- config := cors.DefaultConfig()
- config.AllowAllOrigins = true
- config.AllowCredentials = true
- config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
- config.AllowHeaders = []string{"*"}
- return cors.New(config)
-}
diff --git a/new-api/middleware/disable-cache.go b/new-api/middleware/disable-cache.go
deleted file mode 100644
index 6e2113f1026606d5e9965376fc277d9b022e2497..0000000000000000000000000000000000000000
--- a/new-api/middleware/disable-cache.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package middleware
-
-import "github.com/gin-gonic/gin"
-
-func DisableCache() gin.HandlerFunc {
- return func(c *gin.Context) {
- c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0")
- c.Header("Pragma", "no-cache")
- c.Header("Expires", "0")
- c.Next()
- }
-}
diff --git a/new-api/middleware/distributor.go b/new-api/middleware/distributor.go
deleted file mode 100644
index f6a88ee28cc9443c778c4598cc40143987fdcfac..0000000000000000000000000000000000000000
--- a/new-api/middleware/distributor.go
+++ /dev/null
@@ -1,327 +0,0 @@
-package middleware
-
-import (
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/ratio_setting"
- "one-api/types"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-type ModelRequest struct {
- Model string `json:"model"`
- Group string `json:"group,omitempty"`
-}
-
-func Distribute() func(c *gin.Context) {
- return func(c *gin.Context) {
- var channel *model.Channel
- channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
- modelRequest, shouldSelectChannel, err := getModelRequest(c)
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
- return
- }
- if ok {
- id, err := strconv.Atoi(channelId.(string))
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
- return
- }
- channel, err = model.GetChannelById(id, true)
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
- return
- }
- if channel.Status != common.ChannelStatusEnabled {
- abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
- return
- }
- } else {
- // Select a channel for the user
- // check token model mapping
- modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
- if modelLimitEnable {
- s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
- if !ok {
- // token model limit is empty, all models are not allowed
- abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
- return
- }
- var tokenModelLimit map[string]bool
- tokenModelLimit, ok = s.(map[string]bool)
- if !ok {
- tokenModelLimit = map[string]bool{}
- }
- matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
- if _, ok := tokenModelLimit[matchName]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
- return
- }
- }
-
- if shouldSelectChannel {
- if modelRequest.Model == "" {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
- return
- }
- var selectGroup string
- userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
- // check path is /pg/chat/completions
- if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
- playgroundRequest := &dto.PlayGroundRequest{}
- err = common.UnmarshalBodyReusable(c, playgroundRequest)
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
- return
- }
- if playgroundRequest.Group != "" {
- if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
- abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
- return
- }
- userGroup = playgroundRequest.Group
- }
- }
- channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
- if err != nil {
- showGroup := userGroup
- if userGroup == "auto" {
- showGroup = fmt.Sprintf("auto(%s)", selectGroup)
- }
- message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
- // 如果错误,但是渠道不为空,说明是数据库一致性问题
- //if channel != nil {
- // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
- // message = "数据库一致性已被破坏,请联系管理员"
- //}
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
- return
- }
- if channel == nil {
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
- return
- }
- }
- }
- common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
- SetupContextForSelectedChannel(c, channel, modelRequest.Model)
- c.Next()
- }
-}
-
-func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
- var modelRequest ModelRequest
- shouldSelectChannel := true
- var err error
- if strings.Contains(c.Request.URL.Path, "/mj/") {
- relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
- if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
- relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
- relayMode == relayconstant.RelayModeMidjourneyNotify ||
- relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
- shouldSelectChannel = false
- } else {
- midjourneyRequest := dto.MidjourneyRequest{}
- err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
- if err != nil {
- return nil, false, err
- }
- midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
- if mjErr != nil {
- return nil, false, fmt.Errorf(mjErr.Description)
- }
- if midjourneyModel == "" {
- if !success {
- return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
- } else {
- // task fetch, task fetch by condition, notify
- shouldSelectChannel = false
- }
- }
- modelRequest.Model = midjourneyModel
- }
- c.Set("relay_mode", relayMode)
- } else if strings.Contains(c.Request.URL.Path, "/suno/") {
- relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
- if relayMode == relayconstant.RelayModeSunoFetch ||
- relayMode == relayconstant.RelayModeSunoFetchByID {
- shouldSelectChannel = false
- } else {
- modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
- modelRequest.Model = modelName
- }
- c.Set("platform", string(constant.TaskPlatformSuno))
- c.Set("relay_mode", relayMode)
- } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
- relayMode := relayconstant.RelayModeUnknown
- if c.Request.Method == http.MethodPost {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
- relayMode = relayconstant.RelayModeVideoSubmit
- } else if c.Request.Method == http.MethodGet {
- relayMode = relayconstant.RelayModeVideoFetchByID
- shouldSelectChannel = false
- }
- if _, ok := c.Get("relay_mode"); !ok {
- c.Set("relay_mode", relayMode)
- }
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
- // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
- relayMode := relayconstant.RelayModeGemini
- modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
- if modelName != "" {
- modelRequest.Model = modelName
- }
- c.Set("relay_mode", relayMode)
- } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
- err = common.UnmarshalBodyReusable(c, &modelRequest)
- }
- if err != nil {
- return nil, false, errors.New("无效的请求, " + err.Error())
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
- //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
- modelRequest.Model = c.Query("model")
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "text-moderation-stable"
- }
- }
- if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- if modelRequest.Model == "" {
- modelRequest.Model = c.Param("model")
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
- //modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
- if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
- modelRequest.Model = c.PostForm("model")
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
- relayMode := relayconstant.RelayModeAudioSpeech
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
- relayMode = relayconstant.RelayModeAudioTranslation
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
- modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
- relayMode = relayconstant.RelayModeAudioTranscription
- }
- c.Set("relay_mode", relayMode)
- }
- if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
- // playground chat completions
- err = common.UnmarshalBodyReusable(c, &modelRequest)
- if err != nil {
- return nil, false, errors.New("无效的请求, " + err.Error())
- }
- common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
- }
- return &modelRequest, shouldSelectChannel, nil
-}
-
-func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
- c.Set("original_model", modelName) // for retry
- if channel == nil {
- return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
- }
- common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
- common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
- common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
- common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
- common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
- common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
- common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
- common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, channel.GetHeaderOverride())
- if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
- common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
- }
- common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
- common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
- common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
-
- key, index, newAPIError := channel.GetNextEnabledKey()
- if newAPIError != nil {
- return newAPIError
- }
- if channel.ChannelInfo.IsMultiKey {
- common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
- common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
- } else {
- // 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
- common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
- }
- // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
- common.SetContextKey(c, constant.ContextKeyChannelKey, key)
- common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
-
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)
-
- // TODO: api_version统一
- switch channel.Type {
- case constant.ChannelTypeAzure:
- c.Set("api_version", channel.Other)
- case constant.ChannelTypeVertexAi:
- c.Set("region", channel.Other)
- case constant.ChannelTypeXunfei:
- c.Set("api_version", channel.Other)
- case constant.ChannelTypeGemini:
- c.Set("api_version", channel.Other)
- case constant.ChannelTypeAli:
- c.Set("plugin", channel.Other)
- case constant.ChannelCloudflare:
- c.Set("api_version", channel.Other)
- case constant.ChannelTypeMokaAI:
- c.Set("api_version", channel.Other)
- case constant.ChannelTypeCoze:
- c.Set("bot_id", channel.Other)
- }
- return nil
-}
-
-// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
-// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
-// 输出: gemini-2.0-flash
-func extractModelNameFromGeminiPath(path string) string {
- // 查找 "/models/" 的位置
- modelsPrefix := "/models/"
- modelsIndex := strings.Index(path, modelsPrefix)
- if modelsIndex == -1 {
- return ""
- }
-
- // 从 "/models/" 之后开始提取
- startIndex := modelsIndex + len(modelsPrefix)
- if startIndex >= len(path) {
- return ""
- }
-
- // 查找 ":" 的位置,模型名在 ":" 之前
- colonIndex := strings.Index(path[startIndex:], ":")
- if colonIndex == -1 {
- // 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
- return path[startIndex:]
- }
-
- // 返回模型名部分
- return path[startIndex : startIndex+colonIndex]
-}
diff --git a/new-api/middleware/email-verification-rate-limit.go b/new-api/middleware/email-verification-rate-limit.go
deleted file mode 100644
index a3ef619813be3ff927ba04e9d689062614b046e2..0000000000000000000000000000000000000000
--- a/new-api/middleware/email-verification-rate-limit.go
+++ /dev/null
@@ -1,80 +0,0 @@
-package middleware
-
-import (
- "context"
- "fmt"
- "net/http"
- "one-api/common"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- EmailVerificationRateLimitMark = "EV"
- EmailVerificationMaxRequests = 2 // 30秒内最多2次
- EmailVerificationDuration = 30 // 30秒时间窗口
-)
-
-func redisEmailVerificationRateLimiter(c *gin.Context) {
- ctx := context.Background()
- rdb := common.RDB
- key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
-
- count, err := rdb.Incr(ctx, key).Result()
- if err != nil {
- // fallback
- memoryEmailVerificationRateLimiter(c)
- return
- }
-
- // 第一次设置键时设置过期时间
- if count == 1 {
- _ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
- }
-
- // 检查是否超出限制
- if count <= int64(EmailVerificationMaxRequests) {
- c.Next()
- return
- }
-
- // 获取剩余等待时间
- ttl, err := rdb.TTL(ctx, key).Result()
- waitSeconds := int64(EmailVerificationDuration)
- if err == nil && ttl > 0 {
- waitSeconds = int64(ttl.Seconds())
- }
-
- c.JSON(http.StatusTooManyRequests, gin.H{
- "success": false,
- "message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
- })
- c.Abort()
-}
-
-func memoryEmailVerificationRateLimiter(c *gin.Context) {
- key := EmailVerificationRateLimitMark + ":" + c.ClientIP()
-
- if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) {
- c.JSON(http.StatusTooManyRequests, gin.H{
- "success": false,
- "message": "发送过于频繁,请稍后再试",
- })
- c.Abort()
- return
- }
-
- c.Next()
-}
-
-func EmailVerificationRateLimit() gin.HandlerFunc {
- return func(c *gin.Context) {
- if common.RedisEnabled {
- redisEmailVerificationRateLimiter(c)
- } else {
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
- memoryEmailVerificationRateLimiter(c)
- }
- }
-}
diff --git a/new-api/middleware/gzip.go b/new-api/middleware/gzip.go
deleted file mode 100644
index bf333dbe266d89548d860c420726b722f4e7cef2..0000000000000000000000000000000000000000
--- a/new-api/middleware/gzip.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package middleware
-
-import (
- "compress/gzip"
- "github.com/andybalholm/brotli"
- "github.com/gin-gonic/gin"
- "io"
- "net/http"
-)
-
-func DecompressRequestMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- if c.Request.Body == nil || c.Request.Method == http.MethodGet {
- c.Next()
- return
- }
- switch c.GetHeader("Content-Encoding") {
- case "gzip":
- gzipReader, err := gzip.NewReader(c.Request.Body)
- if err != nil {
- c.AbortWithStatus(http.StatusBadRequest)
- return
- }
- defer gzipReader.Close()
-
- // Replace the request body with the decompressed data
- c.Request.Body = io.NopCloser(gzipReader)
- c.Request.Header.Del("Content-Encoding")
- case "br":
- reader := brotli.NewReader(c.Request.Body)
- c.Request.Body = io.NopCloser(reader)
- c.Request.Header.Del("Content-Encoding")
- }
-
- // Continue processing the request
- c.Next()
- }
-}
diff --git a/new-api/middleware/jimeng_adapter.go b/new-api/middleware/jimeng_adapter.go
deleted file mode 100644
index aafccc42dc1c793edf69954762680090d0a19052..0000000000000000000000000000000000000000
--- a/new-api/middleware/jimeng_adapter.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "encoding/json"
- "github.com/gin-gonic/gin"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- relayconstant "one-api/relay/constant"
-)
-
-func JimengRequestConvert() func(c *gin.Context) {
- return func(c *gin.Context) {
- action := c.Query("Action")
- if action == "" {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
- return
- }
-
- // Handle Jimeng official API request
- var originalReq map[string]interface{}
- if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
- return
- }
- model, _ := originalReq["req_key"].(string)
- prompt, _ := originalReq["prompt"].(string)
-
- unifiedReq := map[string]interface{}{
- "model": model,
- "prompt": prompt,
- "metadata": originalReq,
- }
-
- jsonData, err := json.Marshal(unifiedReq)
- if err != nil {
- abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
- return
- }
-
- // Update request body
- c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
- c.Set(common.KeyRequestBody, jsonData)
-
- if image, ok := originalReq["image"]; !ok || image == "" {
- c.Set("action", constant.TaskActionTextGenerate)
- }
-
- c.Request.URL.Path = "/v1/video/generations"
-
- if action == "CVSync2AsyncGetResult" {
- taskId, ok := originalReq["task_id"].(string)
- if !ok || taskId == "" {
- abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
- return
- }
- c.Request.URL.Path = "/v1/video/generations/" + taskId
- c.Request.Method = http.MethodGet
- c.Set("task_id", taskId)
- c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
- }
- c.Next()
- }
-}
diff --git a/new-api/middleware/kling_adapter.go b/new-api/middleware/kling_adapter.go
deleted file mode 100644
index 638b15d6ff6d9c7b019a4ef280b038e7ee4a164b..0000000000000000000000000000000000000000
--- a/new-api/middleware/kling_adapter.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "one-api/common"
- "one-api/constant"
-
- "github.com/gin-gonic/gin"
-)
-
-func KlingRequestConvert() func(c *gin.Context) {
- return func(c *gin.Context) {
- var originalReq map[string]interface{}
- if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
- c.Next()
- return
- }
-
- // Support both model_name and model fields
- model, _ := originalReq["model_name"].(string)
- if model == "" {
- model, _ = originalReq["model"].(string)
- }
- prompt, _ := originalReq["prompt"].(string)
-
- unifiedReq := map[string]interface{}{
- "model": model,
- "prompt": prompt,
- "metadata": originalReq,
- }
-
- jsonData, err := json.Marshal(unifiedReq)
- if err != nil {
- c.Next()
- return
- }
-
- // Rewrite request body and path
- c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
- c.Request.URL.Path = "/v1/video/generations"
- if image, ok := originalReq["image"]; !ok || image == "" {
- c.Set("action", constant.TaskActionTextGenerate)
- }
-
- // We have to reset the request body for the next handlers
- c.Set(common.KeyRequestBody, jsonData)
- c.Next()
- }
-}
diff --git a/new-api/middleware/logger.go b/new-api/middleware/logger.go
deleted file mode 100644
index 244286e6040558195f6524984ff466c617588cf8..0000000000000000000000000000000000000000
--- a/new-api/middleware/logger.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package middleware
-
-import (
- "fmt"
- "github.com/gin-gonic/gin"
- "one-api/common"
-)
-
-func SetUpLogger(server *gin.Engine) {
- server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
- var requestID string
- if param.Keys != nil {
- requestID = param.Keys[common.RequestIdKey].(string)
- }
- return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
- param.TimeStamp.Format("2006/01/02 - 15:04:05"),
- requestID,
- param.StatusCode,
- param.Latency,
- param.ClientIP,
- param.Method,
- param.Path,
- )
- }))
-}
diff --git a/new-api/middleware/model-rate-limit.go b/new-api/middleware/model-rate-limit.go
deleted file mode 100644
index 2ee77c560a4c197eaa8b80b9e6137e7638d2fd6f..0000000000000000000000000000000000000000
--- a/new-api/middleware/model-rate-limit.go
+++ /dev/null
@@ -1,199 +0,0 @@
-package middleware
-
-import (
- "context"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/common/limiter"
- "one-api/constant"
- "one-api/setting"
- "strconv"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/go-redis/redis/v8"
-)
-
-const (
- ModelRequestRateLimitCountMark = "MRRL"
- ModelRequestRateLimitSuccessCountMark = "MRRLS"
-)
-
-// 检查Redis中的请求限制
-func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
- // 如果maxCount为0,表示不限制
- if maxCount == 0 {
- return true, nil
- }
-
- // 获取当前计数
- length, err := rdb.LLen(ctx, key).Result()
- if err != nil {
- return false, err
- }
-
- // 如果未达到限制,允许请求
- if length < int64(maxCount) {
- return true, nil
- }
-
- // 检查时间窗口
- oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
- oldTime, err := time.Parse(timeFormat, oldTimeStr)
- if err != nil {
- return false, err
- }
-
- nowTimeStr := time.Now().Format(timeFormat)
- nowTime, err := time.Parse(timeFormat, nowTimeStr)
- if err != nil {
- return false, err
- }
- // 如果在时间窗口内已达到限制,拒绝请求
- subTime := nowTime.Sub(oldTime).Seconds()
- if int64(subTime) < duration {
- rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
- return false, nil
- }
-
- return true, nil
-}
-
-// 记录Redis请求
-func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
- // 如果maxCount为0,不记录请求
- if maxCount == 0 {
- return
- }
-
- now := time.Now().Format(timeFormat)
- rdb.LPush(ctx, key, now)
- rdb.LTrim(ctx, key, 0, int64(maxCount-1))
- rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
-}
-
-// Redis限流处理器
-func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
- return func(c *gin.Context) {
- userId := strconv.Itoa(c.GetInt("id"))
- ctx := context.Background()
- rdb := common.RDB
-
- // 1. 检查成功请求数限制
- successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
- allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
- if err != nil {
- fmt.Println("检查成功请求数限制失败:", err.Error())
- abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
- return
- }
- if !allowed {
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
- return
- }
-
- //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
- if totalMaxCount > 0 {
- totalKey := fmt.Sprintf("rateLimit:%s", userId)
- // 初始化
- tb := limiter.New(ctx, rdb)
- allowed, err = tb.Allow(
- ctx,
- totalKey,
- limiter.WithCapacity(int64(totalMaxCount)*duration),
- limiter.WithRate(int64(totalMaxCount)),
- limiter.WithRequested(duration),
- )
-
- if err != nil {
- fmt.Println("检查总请求数限制失败:", err.Error())
- abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
- return
- }
-
- if !allowed {
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
- }
- }
-
- // 4. 处理请求
- c.Next()
-
- // 5. 如果请求成功,记录成功请求
- if c.Writer.Status() < 400 {
- recordRedisRequest(ctx, rdb, successKey, successMaxCount)
- }
- }
-}
-
-// 内存限流处理器
-func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
- inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
-
- return func(c *gin.Context) {
- userId := strconv.Itoa(c.GetInt("id"))
- totalKey := ModelRequestRateLimitCountMark + userId
- successKey := ModelRequestRateLimitSuccessCountMark + userId
-
- // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
- if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- }
-
- // 2. 检查成功请求数限制
- // 使用一个临时key来检查限制,这样可以避免实际记录
- checkKey := successKey + "_check"
- if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- }
-
- // 3. 处理请求
- c.Next()
-
- // 4. 如果请求成功,记录到实际的成功请求计数中
- if c.Writer.Status() < 400 {
- inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
- }
- }
-}
-
-// ModelRequestRateLimit 模型请求限流中间件
-func ModelRequestRateLimit() func(c *gin.Context) {
- return func(c *gin.Context) {
- // 在每个请求时检查是否启用限流
- if !setting.ModelRequestRateLimitEnabled {
- c.Next()
- return
- }
-
- // 计算限流参数
- duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
- totalMaxCount := setting.ModelRequestRateLimitCount
- successMaxCount := setting.ModelRequestRateLimitSuccessCount
-
- // 获取分组
- group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
- if group == "" {
- group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
- }
-
- //获取分组的限流配置
- groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
- if found {
- totalMaxCount = groupTotalCount
- successMaxCount = groupSuccessCount
- }
-
- // 根据存储类型选择并执行限流处理器
- if common.RedisEnabled {
- redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
- } else {
- memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
- }
- }
-}
diff --git a/new-api/middleware/rate-limit.go b/new-api/middleware/rate-limit.go
deleted file mode 100644
index b00d1080e7b6a49a5c1b25e64e7941ad4427eff7..0000000000000000000000000000000000000000
--- a/new-api/middleware/rate-limit.go
+++ /dev/null
@@ -1,113 +0,0 @@
-package middleware
-
-import (
- "context"
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "time"
-)
-
-var timeFormat = "2006-01-02T15:04:05.000Z"
-
-var inMemoryRateLimiter common.InMemoryRateLimiter
-
-var defNext = func(c *gin.Context) {
- c.Next()
-}
-
-func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
- ctx := context.Background()
- rdb := common.RDB
- key := "rateLimit:" + mark + c.ClientIP()
- listLength, err := rdb.LLen(ctx, key).Result()
- if err != nil {
- fmt.Println(err.Error())
- c.Status(http.StatusInternalServerError)
- c.Abort()
- return
- }
- if listLength < int64(maxRequestNum) {
- rdb.LPush(ctx, key, time.Now().Format(timeFormat))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
- } else {
- oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
- oldTime, err := time.Parse(timeFormat, oldTimeStr)
- if err != nil {
- fmt.Println(err)
- c.Status(http.StatusInternalServerError)
- c.Abort()
- return
- }
- nowTimeStr := time.Now().Format(timeFormat)
- nowTime, err := time.Parse(timeFormat, nowTimeStr)
- if err != nil {
- fmt.Println(err)
- c.Status(http.StatusInternalServerError)
- c.Abort()
- return
- }
- // time.Since will return negative number!
- // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
- if int64(nowTime.Sub(oldTime).Seconds()) < duration {
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- } else {
- rdb.LPush(ctx, key, time.Now().Format(timeFormat))
- rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
- }
- }
-}
-
-func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
- key := mark + c.ClientIP()
- if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- }
-}
-
-func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
- if common.RedisEnabled {
- return func(c *gin.Context) {
- redisRateLimiter(c, maxRequestNum, duration, mark)
- }
- } else {
- // It's safe to call multi times.
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
- return func(c *gin.Context) {
- memoryRateLimiter(c, maxRequestNum, duration, mark)
- }
- }
-}
-
-func GlobalWebRateLimit() func(c *gin.Context) {
- if common.GlobalWebRateLimitEnable {
- return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
- }
- return defNext
-}
-
-func GlobalAPIRateLimit() func(c *gin.Context) {
- if common.GlobalApiRateLimitEnable {
- return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
- }
- return defNext
-}
-
-func CriticalRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
-}
-
-func DownloadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
-}
-
-func UploadRateLimit() func(c *gin.Context) {
- return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
-}
diff --git a/new-api/middleware/recover.go b/new-api/middleware/recover.go
deleted file mode 100644
index 56777babf8537069f2f45a09af29145cad9aa96f..0000000000000000000000000000000000000000
--- a/new-api/middleware/recover.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package middleware
-
-import (
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "runtime/debug"
-)
-
-func RelayPanicRecover() gin.HandlerFunc {
- return func(c *gin.Context) {
- defer func() {
- if err := recover(); err != nil {
- common.SysLog(fmt.Sprintf("panic detected: %v", err))
- common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
- "type": "new_api_panic",
- },
- })
- c.Abort()
- }
- }()
- c.Next()
- }
-}
diff --git a/new-api/middleware/request-id.go b/new-api/middleware/request-id.go
deleted file mode 100644
index cc75bb60508cdd6a78caeea3f3ed9cae37a5e1b7..0000000000000000000000000000000000000000
--- a/new-api/middleware/request-id.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package middleware
-
-import (
- "context"
- "github.com/gin-gonic/gin"
- "one-api/common"
-)
-
-func RequestId() func(c *gin.Context) {
- return func(c *gin.Context) {
- id := common.GetTimeString() + common.GetRandomString(8)
- c.Set(common.RequestIdKey, id)
- ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
- c.Request = c.Request.WithContext(ctx)
- c.Header(common.RequestIdKey, id)
- c.Next()
- }
-}
diff --git a/new-api/middleware/secure_verification.go b/new-api/middleware/secure_verification.go
deleted file mode 100644
index ae9c2b9ceab94437b810b24d1c312a5f7ea9d45b..0000000000000000000000000000000000000000
--- a/new-api/middleware/secure_verification.go
+++ /dev/null
@@ -1,131 +0,0 @@
-package middleware
-
-import (
- "net/http"
- "time"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
-)
-
-const (
- // SecureVerificationSessionKey 安全验证的 session key(与 controller 保持一致)
- SecureVerificationSessionKey = "secure_verified_at"
- // SecureVerificationTimeout 验证有效期(秒)
- SecureVerificationTimeout = 300 // 5分钟
-)
-
-// SecureVerificationRequired 安全验证中间件
-// 检查用户是否在有效时间内通过了安全验证
-// 如果未验证或验证已过期,返回 401 错误
-func SecureVerificationRequired() gin.HandlerFunc {
- return func(c *gin.Context) {
- // 检查用户是否已登录
- userId := c.GetInt("id")
- if userId == 0 {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "未登录",
- })
- c.Abort()
- return
- }
-
- // 检查 session 中的验证时间戳
- session := sessions.Default(c)
- verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
- if verifiedAtRaw == nil {
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "需要安全验证",
- "code": "VERIFICATION_REQUIRED",
- })
- c.Abort()
- return
- }
-
- verifiedAt, ok := verifiedAtRaw.(int64)
- if !ok {
- // session 数据格式错误
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "验证状态异常,请重新验证",
- "code": "VERIFICATION_INVALID",
- })
- c.Abort()
- return
- }
-
- // 检查验证是否过期
- elapsed := time.Now().Unix() - verifiedAt
- if elapsed >= SecureVerificationTimeout {
- // 验证已过期,清除 session
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
- c.JSON(http.StatusForbidden, gin.H{
- "success": false,
- "message": "验证已过期,请重新验证",
- "code": "VERIFICATION_EXPIRED",
- })
- c.Abort()
- return
- }
-
- // 验证有效,继续处理请求
- c.Next()
- }
-}
-
-// OptionalSecureVerification 可选的安全验证中间件
-// 如果用户已验证,则在 context 中设置标记,但不阻止请求继续
-// 用于某些需要区分是否已验证的场景
-func OptionalSecureVerification() gin.HandlerFunc {
- return func(c *gin.Context) {
- userId := c.GetInt("id")
- if userId == 0 {
- c.Set("secure_verified", false)
- c.Next()
- return
- }
-
- session := sessions.Default(c)
- verifiedAtRaw := session.Get(SecureVerificationSessionKey)
-
- if verifiedAtRaw == nil {
- c.Set("secure_verified", false)
- c.Next()
- return
- }
-
- verifiedAt, ok := verifiedAtRaw.(int64)
- if !ok {
- c.Set("secure_verified", false)
- c.Next()
- return
- }
-
- elapsed := time.Now().Unix() - verifiedAt
- if elapsed >= SecureVerificationTimeout {
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
- c.Set("secure_verified", false)
- c.Next()
- return
- }
-
- c.Set("secure_verified", true)
- c.Set("secure_verified_at", verifiedAt)
- c.Next()
- }
-}
-
-// ClearSecureVerification 清除安全验证状态
-// 用于用户登出或需要强制重新验证的场景
-func ClearSecureVerification(c *gin.Context) {
- session := sessions.Default(c)
- session.Delete(SecureVerificationSessionKey)
- _ = session.Save()
-}
diff --git a/new-api/middleware/stats.go b/new-api/middleware/stats.go
deleted file mode 100644
index fa1d61efb1cfcecac3d66b1063d52be1a0794942..0000000000000000000000000000000000000000
--- a/new-api/middleware/stats.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package middleware
-
-import (
- "sync/atomic"
-
- "github.com/gin-gonic/gin"
-)
-
-// HTTPStats 存储HTTP统计信息
-type HTTPStats struct {
- activeConnections int64
-}
-
-var globalStats = &HTTPStats{}
-
-// StatsMiddleware 统计中间件
-func StatsMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- // 增加活跃连接数
- atomic.AddInt64(&globalStats.activeConnections, 1)
-
- // 确保在请求结束时减少连接数
- defer func() {
- atomic.AddInt64(&globalStats.activeConnections, -1)
- }()
-
- c.Next()
- }
-}
-
-// StatsInfo 统计信息结构
-type StatsInfo struct {
- ActiveConnections int64 `json:"active_connections"`
-}
-
-// GetStats 获取统计信息
-func GetStats() StatsInfo {
- return StatsInfo{
- ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
- }
-}
diff --git a/new-api/middleware/turnstile-check.go b/new-api/middleware/turnstile-check.go
deleted file mode 100644
index 1fcd7ad22bd35746fffc4ff9a6e7f79821679a19..0000000000000000000000000000000000000000
--- a/new-api/middleware/turnstile-check.go
+++ /dev/null
@@ -1,80 +0,0 @@
-package middleware
-
-import (
- "encoding/json"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- "net/http"
- "net/url"
- "one-api/common"
-)
-
-type turnstileCheckResponse struct {
- Success bool `json:"success"`
-}
-
-func TurnstileCheck() gin.HandlerFunc {
- return func(c *gin.Context) {
- if common.TurnstileCheckEnabled {
- session := sessions.Default(c)
- turnstileChecked := session.Get("turnstile")
- if turnstileChecked != nil {
- c.Next()
- return
- }
- response := c.Query("turnstile")
- if response == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Turnstile token 为空",
- })
- c.Abort()
- return
- }
- rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
- "secret": {common.TurnstileSecretKey},
- "response": {response},
- "remoteip": {c.ClientIP()},
- })
- if err != nil {
- common.SysLog(err.Error())
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- c.Abort()
- return
- }
- defer rawRes.Body.Close()
- var res turnstileCheckResponse
- err = json.NewDecoder(rawRes.Body).Decode(&res)
- if err != nil {
- common.SysLog(err.Error())
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- c.Abort()
- return
- }
- if !res.Success {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "Turnstile 校验失败,请刷新重试!",
- })
- c.Abort()
- return
- }
- session.Set("turnstile", true)
- err = session.Save()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "message": "无法保存会话信息,请重试",
- "success": false,
- })
- return
- }
- }
- c.Next()
- }
-}
diff --git a/new-api/middleware/utils.go b/new-api/middleware/utils.go
deleted file mode 100644
index 7cf682ce20b99e4e7e6db9c083d1d0cfb4304a71..0000000000000000000000000000000000000000
--- a/new-api/middleware/utils.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package middleware
-
-import (
- "fmt"
- "github.com/gin-gonic/gin"
- "one-api/common"
- "one-api/logger"
-)
-
-func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
- codeStr := ""
- if len(code) > 0 {
- codeStr = code[0]
- }
- userId := c.GetInt("id")
- c.JSON(statusCode, gin.H{
- "error": gin.H{
- "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
- "type": "new_api_error",
- "code": codeStr,
- },
- })
- c.Abort()
- logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
-}
-
-func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
- c.JSON(statusCode, gin.H{
- "description": description,
- "type": "new_api_error",
- "code": code,
- })
- c.Abort()
- logger.LogError(c.Request.Context(), description)
-}
diff --git a/new-api/model/ability.go b/new-api/model/ability.go
deleted file mode 100644
index ec2f7ee4f6fa94f6e733979fc982527034e08827..0000000000000000000000000000000000000000
--- a/new-api/model/ability.go
+++ /dev/null
@@ -1,340 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "strings"
- "sync"
-
- "github.com/samber/lo"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
-)
-
-type Ability struct {
- Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
- Model string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
- ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
- Enabled bool `json:"enabled"`
- Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
- Weight uint `json:"weight" gorm:"default:0;index"`
- Tag *string `json:"tag" gorm:"index"`
-}
-
-type AbilityWithChannel struct {
- Ability
- ChannelType int `json:"channel_type"`
-}
-
-func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
- var abilities []AbilityWithChannel
- err := DB.Table("abilities").
- Select("abilities.*, channels.type as channel_type").
- Joins("left join channels on abilities.channel_id = channels.id").
- Where("abilities.enabled = ?", true).
- Scan(&abilities).Error
- return abilities, err
-}
-
-func GetGroupEnabledModels(group string) []string {
- var models []string
- // Find distinct models
- DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
- return models
-}
-
-func GetEnabledModels() []string {
- var models []string
- // Find distinct models
- DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
- return models
-}
-
-func GetAllEnableAbilities() []Ability {
- var abilities []Ability
- DB.Find(&abilities, "enabled = ?", true)
- return abilities
-}
-
-func getPriority(group string, model string, retry int) (int, error) {
-
- var priorities []int
- err := DB.Model(&Ability{}).
- Select("DISTINCT(priority)").
- Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
- Order("priority DESC"). // 按优先级降序排序
- Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
-
- if err != nil {
- // 处理错误
- return 0, err
- }
-
- if len(priorities) == 0 {
- // 如果没有查询到优先级,则返回错误
- return 0, errors.New("数据库一致性被破坏")
- }
-
- // 确定要使用的优先级
- var priorityToUse int
- if retry >= len(priorities) {
- // 如果重试次数大于优先级数,则使用最小的优先级
- priorityToUse = priorities[len(priorities)-1]
- } else {
- priorityToUse = priorities[retry]
- }
- return priorityToUse, nil
-}
-
-func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
- channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
- if retry != 0 {
- priority, err := getPriority(group, model, retry)
- if err != nil {
- return nil, err
- } else {
- channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
- }
- }
-
- return channelQuery, nil
-}
-
-func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
- var abilities []Ability
-
- var err error = nil
- channelQuery, err := getChannelQuery(group, model, retry)
- if err != nil {
- return nil, err
- }
- if common.UsingSQLite || common.UsingPostgreSQL {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- } else {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- }
- if err != nil {
- return nil, err
- }
- channel := Channel{}
- if len(abilities) > 0 {
- // Randomly choose one
- weightSum := uint(0)
- for _, ability_ := range abilities {
- weightSum += ability_.Weight + 10
- }
- // Randomly choose one
- weight := common.GetRandomInt(int(weightSum))
- for _, ability_ := range abilities {
- weight -= int(ability_.Weight) + 10
- //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
- if weight <= 0 {
- channel.Id = ability_.ChannelId
- break
- }
- }
- } else {
- return nil, nil
- }
- err = DB.First(&channel, "id = ?", channel.Id).Error
- return &channel, err
-}
-
-func (channel *Channel) AddAbilities(tx *gorm.DB) error {
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
- abilitySet := make(map[string]struct{})
- abilities := make([]Ability, 0, len(models_))
- for _, model := range models_ {
- for _, group := range groups_ {
- key := group + "|" + model
- if _, exists := abilitySet[key]; exists {
- continue
- }
- abilitySet[key] = struct{}{}
- ability := Ability{
- Group: group,
- Model: model,
- ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
- Priority: channel.Priority,
- Weight: uint(channel.GetWeight()),
- Tag: channel.Tag,
- }
- abilities = append(abilities, ability)
- }
- }
- if len(abilities) == 0 {
- return nil
- }
- // choose DB or provided tx
- useDB := DB
- if tx != nil {
- useDB = tx
- }
- for _, chunk := range lo.Chunk(abilities, 50) {
- err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (channel *Channel) DeleteAbilities() error {
- return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
-}
-
-// UpdateAbilities updates abilities of this channel.
-// Make sure the channel is completed before calling this function.
-func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
- isNewTx := false
- // 如果没有传入事务,创建新的事务
- if tx == nil {
- tx = DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- isNewTx = true
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
- }
-
- // First delete all abilities of this channel
- err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
- if err != nil {
- if isNewTx {
- tx.Rollback()
- }
- return err
- }
-
- // Then add new abilities
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
- abilitySet := make(map[string]struct{})
- abilities := make([]Ability, 0, len(models_))
- for _, model := range models_ {
- for _, group := range groups_ {
- key := group + "|" + model
- if _, exists := abilitySet[key]; exists {
- continue
- }
- abilitySet[key] = struct{}{}
- ability := Ability{
- Group: group,
- Model: model,
- ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
- Priority: channel.Priority,
- Weight: uint(channel.GetWeight()),
- Tag: channel.Tag,
- }
- abilities = append(abilities, ability)
- }
- }
-
- if len(abilities) > 0 {
- for _, chunk := range lo.Chunk(abilities, 50) {
- err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
- if err != nil {
- if isNewTx {
- tx.Rollback()
- }
- return err
- }
- }
- }
-
- // 如果是新创建的事务,需要提交
- if isNewTx {
- return tx.Commit().Error
- }
-
- return nil
-}
-
-func UpdateAbilityStatus(channelId int, status bool) error {
- return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
-}
-
-func UpdateAbilityStatusByTag(tag string, status bool) error {
- return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
-}
-
-func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
- ability := Ability{}
- if newTag != nil {
- ability.Tag = newTag
- }
- if priority != nil {
- ability.Priority = priority
- }
- if weight != nil {
- ability.Weight = *weight
- }
- return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
-}
-
-var fixLock = sync.Mutex{}
-
-func FixAbility() (int, int, error) {
- lock := fixLock.TryLock()
- if !lock {
- return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
- }
- defer fixLock.Unlock()
-
- // truncate abilities table
- if common.UsingSQLite {
- err := DB.Exec("DELETE FROM abilities").Error
- if err != nil {
- common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
- return 0, 0, err
- }
- } else {
- err := DB.Exec("TRUNCATE TABLE abilities").Error
- if err != nil {
- common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
- return 0, 0, err
- }
- }
- var channels []*Channel
- // Find all channels
- err := DB.Model(&Channel{}).Find(&channels).Error
- if err != nil {
- return 0, 0, err
- }
- if len(channels) == 0 {
- return 0, 0, nil
- }
- successCount := 0
- failCount := 0
- for _, chunk := range lo.Chunk(channels, 50) {
- ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
- // Delete all abilities of this channel
- err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
- if err != nil {
- common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
- failCount += len(chunk)
- continue
- }
- // Then add new abilities
- for _, channel := range chunk {
- err = channel.AddAbilities(nil)
- if err != nil {
- common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
- failCount++
- } else {
- successCount++
- }
- }
- }
- InitChannelCache()
- return successCount, failCount, nil
-}
diff --git a/new-api/model/channel.go b/new-api/model/channel.go
deleted file mode 100644
index f82542ffe47749e1bdbb5816fecdd275348a4d93..0000000000000000000000000000000000000000
--- a/new-api/model/channel.go
+++ /dev/null
@@ -1,992 +0,0 @@
-package model
-
-import (
- "database/sql/driver"
- "encoding/json"
- "errors"
- "fmt"
- "math/rand"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/types"
- "strings"
- "sync"
-
- "github.com/samber/lo"
- "gorm.io/gorm"
-)
-
-type Channel struct {
- Id int `json:"id"`
- Type int `json:"type" gorm:"default:0"`
- Key string `json:"key" gorm:"not null"`
- OpenAIOrganization *string `json:"openai_organization"`
- TestModel *string `json:"test_model"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index"`
- Weight *uint `json:"weight" gorm:"default:0"`
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- TestTime int64 `json:"test_time" gorm:"bigint"`
- ResponseTime int `json:"response_time"` // in milliseconds
- BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
- Other string `json:"other"`
- Balance float64 `json:"balance"` // in USD
- BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
- Models string `json:"models"`
- Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
- UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
- ModelMapping *string `json:"model_mapping" gorm:"type:text"`
- //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
- StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
- Priority *int64 `json:"priority" gorm:"bigint;default:0"`
- AutoBan *int `json:"auto_ban" gorm:"default:1"`
- OtherInfo string `json:"other_info"`
- Tag *string `json:"tag" gorm:"index"`
- Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
- ParamOverride *string `json:"param_override" gorm:"type:text"`
- HeaderOverride *string `json:"header_override" gorm:"type:text"`
- Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
- // add after v0.8.5
- ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
-
- OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
-
- // cache info
- Keys []string `json:"-" gorm:"-"`
-}
-
-type ChannelInfo struct {
- IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
- MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
- MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
- MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
- MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
- MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
- MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
-}
-
-// Value implements driver.Valuer interface
-func (c ChannelInfo) Value() (driver.Value, error) {
- return common.Marshal(&c)
-}
-
-// Scan implements sql.Scanner interface
-func (c *ChannelInfo) Scan(value interface{}) error {
- bytesValue, _ := value.([]byte)
- return common.Unmarshal(bytesValue, c)
-}
-
-func (channel *Channel) GetKeys() []string {
- if channel.Key == "" {
- return []string{}
- }
- if len(channel.Keys) > 0 {
- return channel.Keys
- }
- trimmed := strings.TrimSpace(channel.Key)
- // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
- if strings.HasPrefix(trimmed, "[") {
- var arr []json.RawMessage
- if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
- res := make([]string, len(arr))
- for i, v := range arr {
- res[i] = string(v)
- }
- return res
- }
- }
- // Otherwise, fall back to splitting by newline
- keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
- return keys
-}
-
-func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
- // If not in multi-key mode, return the original key string directly.
- if !channel.ChannelInfo.IsMultiKey {
- return channel.Key, 0, nil
- }
-
- // Obtain all keys (split by \n)
- keys := channel.GetKeys()
- if len(keys) == 0 {
- // No keys available, return error, should disable the channel
- return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
- }
-
- lock := GetChannelPollingLock(channel.Id)
- lock.Lock()
- defer lock.Unlock()
-
- statusList := channel.ChannelInfo.MultiKeyStatusList
- // helper to get key status, default to enabled when missing
- getStatus := func(idx int) int {
- if statusList == nil {
- return common.ChannelStatusEnabled
- }
- if status, ok := statusList[idx]; ok {
- return status
- }
- return common.ChannelStatusEnabled
- }
-
- // Collect indexes of enabled keys
- enabledIdx := make([]int, 0, len(keys))
- for i := range keys {
- if getStatus(i) == common.ChannelStatusEnabled {
- enabledIdx = append(enabledIdx, i)
- }
- }
- // If no specific status list or none enabled, fall back to first key
- if len(enabledIdx) == 0 {
- return keys[0], 0, nil
- }
-
- switch channel.ChannelInfo.MultiKeyMode {
- case constant.MultiKeyModeRandom:
- // Randomly pick one enabled key
- selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
- return keys[selectedIdx], selectedIdx, nil
- case constant.MultiKeyModePolling:
- // Use channel-specific lock to ensure thread-safe polling
-
- channelInfo, err := CacheGetChannelInfo(channel.Id)
- if err != nil {
- return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
- }
- //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
- defer func() {
- if common.DebugEnabled {
- println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
- }
- if !common.MemoryCacheEnabled {
- _ = channel.SaveChannelInfo()
- } else {
- // CacheUpdateChannel(channel)
- }
- }()
- // Start from the saved polling index and look for the next enabled key
- start := channelInfo.MultiKeyPollingIndex
- if start < 0 || start >= len(keys) {
- start = 0
- }
- for i := 0; i < len(keys); i++ {
- idx := (start + i) % len(keys)
- if getStatus(idx) == common.ChannelStatusEnabled {
- // update polling index for next call (point to the next position)
- channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
- return keys[idx], idx, nil
- }
- }
- // Fallback – should not happen, but return first enabled key
- return keys[enabledIdx[0]], enabledIdx[0], nil
- default:
- // Unknown mode, default to first enabled key (or original key string)
- return keys[enabledIdx[0]], enabledIdx[0], nil
- }
-}
-
-func (channel *Channel) SaveChannelInfo() error {
- return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
-}
-
-func (channel *Channel) GetModels() []string {
- if channel.Models == "" {
- return []string{}
- }
- return strings.Split(strings.Trim(channel.Models, ","), ",")
-}
-
-func (channel *Channel) GetGroups() []string {
- if channel.Group == "" {
- return []string{}
- }
- groups := strings.Split(strings.Trim(channel.Group, ","), ",")
- for i, group := range groups {
- groups[i] = strings.TrimSpace(group)
- }
- return groups
-}
-
-func (channel *Channel) GetOtherInfo() map[string]interface{} {
- otherInfo := make(map[string]interface{})
- if channel.OtherInfo != "" {
- err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
- }
- }
- return otherInfo
-}
-
-func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
- otherInfoBytes, err := json.Marshal(otherInfo)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
- return
- }
- channel.OtherInfo = string(otherInfoBytes)
-}
-
-func (channel *Channel) GetTag() string {
- if channel.Tag == nil {
- return ""
- }
- return *channel.Tag
-}
-
-func (channel *Channel) SetTag(tag string) {
- channel.Tag = &tag
-}
-
-func (channel *Channel) GetAutoBan() bool {
- if channel.AutoBan == nil {
- return false
- }
- return *channel.AutoBan == 1
-}
-
-func (channel *Channel) Save() error {
- return DB.Save(channel).Error
-}
-
-func (channel *Channel) SaveWithoutKey() error {
- return DB.Omit("key").Save(channel).Error
-}
-
-func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
- var channels []*Channel
- var err error
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
- if selectAll {
- err = DB.Order(order).Find(&channels).Error
- } else {
- err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
- }
- return channels, err
-}
-
-func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
- var channels []*Channel
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
- err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error
- return channels, err
-}
-
-func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
- var channels []*Channel
- modelsCol := "`models`"
-
- // 如果是 PostgreSQL,使用双引号
- if common.UsingPostgreSQL {
- modelsCol = `"models"`
- }
-
- baseURLCol := "`base_url`"
- // 如果是 PostgreSQL,使用双引号
- if common.UsingPostgreSQL {
- baseURLCol = `"base_url"`
- }
-
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
-
- // 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit("key")
-
- // 构造WHERE子句
- var whereClause string
- var args []interface{}
- if group != "" && group != "null" {
- var groupCondition string
- if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
- } else {
- // sqlite, PostgreSQL
- groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
- }
- whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
- args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
- } else {
- whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
- args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
- }
-
- // 执行查询
- err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
- if err != nil {
- return nil, err
- }
- return channels, nil
-}
-
-func GetChannelById(id int, selectAll bool) (*Channel, error) {
- channel := &Channel{Id: id}
- var err error = nil
- if selectAll {
- err = DB.First(channel, "id = ?", id).Error
- } else {
- err = DB.Omit("key").First(channel, "id = ?", id).Error
- }
- if err != nil {
- return nil, err
- }
- if channel == nil {
- return nil, errors.New("channel not found")
- }
- return channel, nil
-}
-
-func BatchInsertChannels(channels []Channel) error {
- if len(channels) == 0 {
- return nil
- }
- tx := DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
-
- for _, chunk := range lo.Chunk(channels, 50) {
- if err := tx.Create(&chunk).Error; err != nil {
- tx.Rollback()
- return err
- }
- for _, channel_ := range chunk {
- if err := channel_.AddAbilities(tx); err != nil {
- tx.Rollback()
- return err
- }
- }
- }
- return tx.Commit().Error
-}
-
-func BatchDeleteChannels(ids []int) error {
- if len(ids) == 0 {
- return nil
- }
- // 使用事务 分批删除channel表和abilities表
- tx := DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- for _, chunk := range lo.Chunk(ids, 200) {
- if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
- tx.Rollback()
- return err
- }
- if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
- tx.Rollback()
- return err
- }
- }
- return tx.Commit().Error
-}
-
-func (channel *Channel) GetPriority() int64 {
- if channel.Priority == nil {
- return 0
- }
- return *channel.Priority
-}
-
-func (channel *Channel) GetWeight() int {
- if channel.Weight == nil {
- return 0
- }
- return int(*channel.Weight)
-}
-
-func (channel *Channel) GetBaseURL() string {
- if channel.BaseURL == nil {
- return ""
- }
- url := *channel.BaseURL
- if url == "" {
- url = constant.ChannelBaseURLs[channel.Type]
- }
- return url
-}
-
-func (channel *Channel) GetModelMapping() string {
- if channel.ModelMapping == nil {
- return ""
- }
- return *channel.ModelMapping
-}
-
-func (channel *Channel) GetStatusCodeMapping() string {
- if channel.StatusCodeMapping == nil {
- return ""
- }
- return *channel.StatusCodeMapping
-}
-
-func (channel *Channel) Insert() error {
- var err error
- err = DB.Create(channel).Error
- if err != nil {
- return err
- }
- err = channel.AddAbilities(nil)
- return err
-}
-
-func (channel *Channel) Update() error {
- // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
- if channel.ChannelInfo.IsMultiKey {
- var keyStr string
- if channel.Key != "" {
- keyStr = channel.Key
- } else {
- // If key is not provided, read the existing key from the database
- if existing, err := GetChannelById(channel.Id, true); err == nil {
- keyStr = existing.Key
- }
- }
- // Parse the key list (supports newline separation or JSON array)
- keys := []string{}
- if keyStr != "" {
- trimmed := strings.TrimSpace(keyStr)
- if strings.HasPrefix(trimmed, "[") {
- var arr []json.RawMessage
- if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
- keys = make([]string, len(arr))
- for i, v := range arr {
- keys[i] = string(v)
- }
- }
- }
- if len(keys) == 0 { // fallback to newline split
- keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
- }
- }
- channel.ChannelInfo.MultiKeySize = len(keys)
- // Clean up status data that exceeds the new key count to prevent index out of range
- if channel.ChannelInfo.MultiKeyStatusList != nil {
- for idx := range channel.ChannelInfo.MultiKeyStatusList {
- if idx >= channel.ChannelInfo.MultiKeySize {
- delete(channel.ChannelInfo.MultiKeyStatusList, idx)
- }
- }
- }
- }
- var err error
- err = DB.Model(channel).Updates(channel).Error
- if err != nil {
- return err
- }
- DB.Model(channel).First(channel, "id = ?", channel.Id)
- err = channel.UpdateAbilities(nil)
- return err
-}
-
-func (channel *Channel) UpdateResponseTime(responseTime int64) {
- err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
- TestTime: common.GetTimestamp(),
- ResponseTime: int(responseTime),
- }).Error
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
- }
-}
-
-func (channel *Channel) UpdateBalance(balance float64) {
- err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
- BalanceUpdatedTime: common.GetTimestamp(),
- Balance: balance,
- }).Error
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
- }
-}
-
-func (channel *Channel) Delete() error {
- var err error
- err = DB.Delete(channel).Error
- if err != nil {
- return err
- }
- err = channel.DeleteAbilities()
- return err
-}
-
-var channelStatusLock sync.Mutex
-
-// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
-var channelPollingLocks sync.Map
-
-// GetChannelPollingLock returns or creates a mutex for the given channel ID
-func GetChannelPollingLock(channelId int) *sync.Mutex {
- if lock, exists := channelPollingLocks.Load(channelId); exists {
- return lock.(*sync.Mutex)
- }
- // Create new lock for this channel
- newLock := &sync.Mutex{}
- actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
- return actual.(*sync.Mutex)
-}
-
-// CleanupChannelPollingLocks removes locks for channels that no longer exist
-// This is optional and can be called periodically to prevent memory leaks
-func CleanupChannelPollingLocks() {
- var activeChannelIds []int
- DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
-
- activeChannelSet := make(map[int]bool)
- for _, id := range activeChannelIds {
- activeChannelSet[id] = true
- }
-
- channelPollingLocks.Range(func(key, value interface{}) bool {
- channelId := key.(int)
- if !activeChannelSet[channelId] {
- channelPollingLocks.Delete(channelId)
- }
- return true
- })
-}
-
-func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
- keys := channel.GetKeys()
- if len(keys) == 0 {
- channel.Status = status
- } else {
- var keyIndex int
- for i, key := range keys {
- if key == usingKey {
- keyIndex = i
- break
- }
- }
- if channel.ChannelInfo.MultiKeyStatusList == nil {
- channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
- }
- if status == common.ChannelStatusEnabled {
- delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
- } else {
- channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
- if channel.ChannelInfo.MultiKeyDisabledReason == nil {
- channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
- }
- if channel.ChannelInfo.MultiKeyDisabledTime == nil {
- channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
- }
- channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
- channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
- }
- if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
- channel.Status = common.ChannelStatusAutoDisabled
- info := channel.GetOtherInfo()
- info["status_reason"] = "All keys are disabled"
- info["status_time"] = common.GetTimestamp()
- channel.SetOtherInfo(info)
- }
- }
-}
-
-func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
- if common.MemoryCacheEnabled {
- channelStatusLock.Lock()
- defer channelStatusLock.Unlock()
-
- channelCache, _ := CacheGetChannel(channelId)
- if channelCache == nil {
- return false
- }
- if channelCache.ChannelInfo.IsMultiKey {
- // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
- pollingLock := GetChannelPollingLock(channelId)
- pollingLock.Lock()
- // 如果是多Key模式,更新缓存中的状态
- handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
- pollingLock.Unlock()
- //CacheUpdateChannel(channelCache)
- //return true
- } else {
- // 如果缓存渠道存在,且状态已是目标状态,直接返回
- if channelCache.Status == status {
- return false
- }
- CacheUpdateChannelStatus(channelId, status)
- }
- }
-
- shouldUpdateAbilities := false
- defer func() {
- if shouldUpdateAbilities {
- err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
- }
- }
- }()
- channel, err := GetChannelById(channelId, true)
- if err != nil {
- return false
- } else {
- if channel.Status == status {
- return false
- }
-
- if channel.ChannelInfo.IsMultiKey {
- beforeStatus := channel.Status
- // Protect map writes with the same per-channel lock used by readers
- pollingLock := GetChannelPollingLock(channelId)
- pollingLock.Lock()
- handlerMultiKeyUpdate(channel, usingKey, status, reason)
- pollingLock.Unlock()
- if beforeStatus != channel.Status {
- shouldUpdateAbilities = true
- }
- } else {
- info := channel.GetOtherInfo()
- info["status_reason"] = reason
- info["status_time"] = common.GetTimestamp()
- channel.SetOtherInfo(info)
- channel.Status = status
- shouldUpdateAbilities = true
- }
- err = channel.SaveWithoutKey()
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
- return false
- }
- }
- return true
-}
-
-func EnableChannelByTag(tag string) error {
- err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
- if err != nil {
- return err
- }
- err = UpdateAbilityStatusByTag(tag, true)
- return err
-}
-
-func DisableChannelByTag(tag string) error {
- err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
- if err != nil {
- return err
- }
- err = UpdateAbilityStatusByTag(tag, false)
- return err
-}
-
-func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
- updateData := Channel{}
- shouldReCreateAbilities := false
- updatedTag := tag
- // 如果 newTag 不为空且不等于 tag,则更新 tag
- if newTag != nil && *newTag != tag {
- updateData.Tag = newTag
- updatedTag = *newTag
- }
- if modelMapping != nil && *modelMapping != "" {
- updateData.ModelMapping = modelMapping
- }
- if models != nil && *models != "" {
- shouldReCreateAbilities = true
- updateData.Models = *models
- }
- if group != nil && *group != "" {
- shouldReCreateAbilities = true
- updateData.Group = *group
- }
- if priority != nil {
- updateData.Priority = priority
- }
- if weight != nil {
- updateData.Weight = weight
- }
-
- err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
- if err != nil {
- return err
- }
- if shouldReCreateAbilities {
- channels, err := GetChannelsByTag(updatedTag, false)
- if err == nil {
- for _, channel := range channels {
- err = channel.UpdateAbilities(nil)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
- }
- }
- }
- } else {
- err := UpdateAbilityByTag(tag, newTag, priority, weight)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func UpdateChannelUsedQuota(id int, quota int) {
- if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
- return
- }
- updateChannelUsedQuota(id, quota)
-}
-
-func updateChannelUsedQuota(id int, quota int) {
- err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
- }
-}
-
-func DeleteChannelByStatus(status int64) (int64, error) {
- result := DB.Where("status = ?", status).Delete(&Channel{})
- return result.RowsAffected, result.Error
-}
-
-func DeleteDisabledChannel() (int64, error) {
- result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
- return result.RowsAffected, result.Error
-}
-
-func GetPaginatedTags(offset int, limit int) ([]*string, error) {
- var tags []*string
- err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
- return tags, err
-}
-
-func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
- var tags []*string
- modelsCol := "`models`"
-
- // 如果是 PostgreSQL,使用双引号
- if common.UsingPostgreSQL {
- modelsCol = `"models"`
- }
-
- baseURLCol := "`base_url`"
- // 如果是 PostgreSQL,使用双引号
- if common.UsingPostgreSQL {
- baseURLCol = `"base_url"`
- }
-
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
-
- // 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit("key")
-
- // 构造WHERE子句
- var whereClause string
- var args []interface{}
- if group != "" && group != "null" {
- var groupCondition string
- if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
- } else {
- // sqlite, PostgreSQL
- groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
- }
- whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
- args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
- } else {
- whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
- args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
- }
-
- subQuery := baseQuery.Where(whereClause, args...).
- Select("tag").
- Where("tag != ''").
- Order(order)
-
- err := DB.Table("(?) as sub", subQuery).
- Select("DISTINCT tag").
- Find(&tags).Error
-
- if err != nil {
- return nil, err
- }
-
- return tags, nil
-}
-
-func (channel *Channel) ValidateSettings() error {
- channelParams := &dto.ChannelSettings{}
- if channel.Setting != nil && *channel.Setting != "" {
- err := common.Unmarshal([]byte(*channel.Setting), channelParams)
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (channel *Channel) GetSetting() dto.ChannelSettings {
- setting := dto.ChannelSettings{}
- if channel.Setting != nil && *channel.Setting != "" {
- err := common.Unmarshal([]byte(*channel.Setting), &setting)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
- channel.Setting = nil // 清空设置以避免后续错误
- _ = channel.Save() // 保存修改
- }
- }
- return setting
-}
-
-func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
- settingBytes, err := common.Marshal(setting)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
- return
- }
- channel.Setting = common.GetPointer[string](string(settingBytes))
-}
-
-func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
- setting := dto.ChannelOtherSettings{}
- if channel.OtherSettings != "" {
- err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
- channel.OtherSettings = "{}" // 清空设置以避免后续错误
- _ = channel.Save() // 保存修改
- }
- }
- return setting
-}
-
-func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
- settingBytes, err := common.Marshal(setting)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
- return
- }
- channel.OtherSettings = string(settingBytes)
-}
-
-func (channel *Channel) GetParamOverride() map[string]interface{} {
- paramOverride := make(map[string]interface{})
- if channel.ParamOverride != nil && *channel.ParamOverride != "" {
- err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
- }
- }
- return paramOverride
-}
-
-func (channel *Channel) GetHeaderOverride() map[string]interface{} {
- headerOverride := make(map[string]interface{})
- if channel.HeaderOverride != nil && *channel.HeaderOverride != "" {
- err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err))
- }
- }
- return headerOverride
-}
-
-func GetChannelsByIds(ids []int) ([]*Channel, error) {
- var channels []*Channel
- err := DB.Where("id in (?)", ids).Find(&channels).Error
- return channels, err
-}
-
-func BatchSetChannelTag(ids []int, tag *string) error {
- // 开启事务
- tx := DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
-
- // 更新标签
- err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
- if err != nil {
- tx.Rollback()
- return err
- }
-
- // update ability status
- channels, err := GetChannelsByIds(ids)
- if err != nil {
- tx.Rollback()
- return err
- }
-
- for _, channel := range channels {
- err = channel.UpdateAbilities(tx)
- if err != nil {
- tx.Rollback()
- return err
- }
- }
-
- // 提交事务
- return tx.Commit().Error
-}
-
-// CountAllChannels returns total channels in DB
-func CountAllChannels() (int64, error) {
- var total int64
- err := DB.Model(&Channel{}).Count(&total).Error
- return total, err
-}
-
-// CountAllTags returns number of non-empty distinct tags
-func CountAllTags() (int64, error) {
- var total int64
- err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
- return total, err
-}
-
-// Get channels of specified type with pagination
-func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
- var channels []*Channel
- order := "priority desc"
- if idSort {
- order = "id desc"
- }
- err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
- return channels, err
-}
-
-// Count channels of specific type
-func CountChannelsByType(channelType int) (int64, error) {
- var count int64
- err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
- return count, err
-}
-
-// Return map[type]count for all channels
-func CountChannelsGroupByType() (map[int64]int64, error) {
- type result struct {
- Type int64 `gorm:"column:type"`
- Count int64 `gorm:"column:count"`
- }
- var results []result
- err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
- if err != nil {
- return nil, err
- }
- counts := make(map[int64]int64)
- for _, r := range results {
- counts[r.Type] = r.Count
- }
- return counts, nil
-}
diff --git a/new-api/model/channel_cache.go b/new-api/model/channel_cache.go
deleted file mode 100644
index f5bab66514151576b1642b8eba64fdbeb0ff1bab..0000000000000000000000000000000000000000
--- a/new-api/model/channel_cache.go
+++ /dev/null
@@ -1,284 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "math/rand"
- "one-api/common"
- "one-api/constant"
- "one-api/setting"
- "one-api/setting/ratio_setting"
- "sort"
- "strings"
- "sync"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-var group2model2channels map[string]map[string][]int // enabled channel
-var channelsIDM map[int]*Channel // all channels include disabled
-var channelSyncLock sync.RWMutex
-
-func InitChannelCache() {
- if !common.MemoryCacheEnabled {
- return
- }
- newChannelId2channel := make(map[int]*Channel)
- var channels []*Channel
- DB.Find(&channels)
- for _, channel := range channels {
- newChannelId2channel[channel.Id] = channel
- }
- var abilities []*Ability
- DB.Find(&abilities)
- groups := make(map[string]bool)
- for _, ability := range abilities {
- groups[ability.Group] = true
- }
- newGroup2model2channels := make(map[string]map[string][]int)
- for group := range groups {
- newGroup2model2channels[group] = make(map[string][]int)
- }
- for _, channel := range channels {
- if channel.Status != common.ChannelStatusEnabled {
- continue // skip disabled channels
- }
- groups := strings.Split(channel.Group, ",")
- for _, group := range groups {
- models := strings.Split(channel.Models, ",")
- for _, model := range models {
- if _, ok := newGroup2model2channels[group][model]; !ok {
- newGroup2model2channels[group][model] = make([]int, 0)
- }
- newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
- }
- }
- }
-
- // sort by priority
- for group, model2channels := range newGroup2model2channels {
- for model, channels := range model2channels {
- sort.Slice(channels, func(i, j int) bool {
- return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
- })
- newGroup2model2channels[group][model] = channels
- }
- }
-
- channelSyncLock.Lock()
- group2model2channels = newGroup2model2channels
- //channelsIDM = newChannelId2channel
- for i, channel := range newChannelId2channel {
- if channel.ChannelInfo.IsMultiKey {
- channel.Keys = channel.GetKeys()
- if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
- if oldChannel, ok := channelsIDM[i]; ok {
- // 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
- if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
- channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
- }
- }
- }
- }
- }
- channelsIDM = newChannelId2channel
- channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
-}
-
-func SyncChannelCache(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
- InitChannelCache()
- }
-}
-
-func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
- var channel *Channel
- var err error
- selectGroup := group
- if group == "auto" {
- if len(setting.AutoGroups) == 0 {
- return nil, selectGroup, errors.New("auto groups is not enabled")
- }
- for _, autoGroup := range setting.AutoGroups {
- if common.DebugEnabled {
- println("autoGroup:", autoGroup)
- }
- channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
- if channel == nil {
- continue
- } else {
- c.Set("auto_group", autoGroup)
- selectGroup = autoGroup
- if common.DebugEnabled {
- println("selectGroup:", selectGroup)
- }
- break
- }
- }
- } else {
- channel, err = getRandomSatisfiedChannel(group, model, retry)
- if err != nil {
- return nil, group, err
- }
- }
- return channel, selectGroup, nil
-}
-
-func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
- // if memory cache is disabled, get channel directly from database
- if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model, retry)
- }
-
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
-
- // First, try to find channels with the exact model name.
- channels := group2model2channels[group][model]
-
- // If no channels found, try to find channels with the normalized model name.
- if len(channels) == 0 {
- normalizedModel := ratio_setting.FormatMatchingModelName(model)
- channels = group2model2channels[group][normalizedModel]
- }
-
- if len(channels) == 0 {
- return nil, nil
- }
-
- if len(channels) == 1 {
- if channel, ok := channelsIDM[channels[0]]; ok {
- return channel, nil
- }
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
- }
-
- uniquePriorities := make(map[int]bool)
- for _, channelId := range channels {
- if channel, ok := channelsIDM[channelId]; ok {
- uniquePriorities[int(channel.GetPriority())] = true
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
- var sortedUniquePriorities []int
- for priority := range uniquePriorities {
- sortedUniquePriorities = append(sortedUniquePriorities, priority)
- }
- sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
-
- if retry >= len(uniquePriorities) {
- retry = len(uniquePriorities) - 1
- }
- targetPriority := int64(sortedUniquePriorities[retry])
-
- // get the priority for the given retry number
- var targetChannels []*Channel
- for _, channelId := range channels {
- if channel, ok := channelsIDM[channelId]; ok {
- if channel.GetPriority() == targetPriority {
- targetChannels = append(targetChannels, channel)
- }
- } else {
- return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
- }
- }
-
- // 平滑系数
- smoothingFactor := 10
- // Calculate the total weight of all channels up to endIdx
- totalWeight := 0
- for _, channel := range targetChannels {
- totalWeight += channel.GetWeight() + smoothingFactor
- }
- // Generate a random value in the range [0, totalWeight)
- randomWeight := rand.Intn(totalWeight)
-
- // Find a channel based on its weight
- for _, channel := range targetChannels {
- randomWeight -= channel.GetWeight() + smoothingFactor
- if randomWeight < 0 {
- return channel, nil
- }
- }
- // return null if no channel is not found
- return nil, errors.New("channel not found")
-}
-
-func CacheGetChannel(id int) (*Channel, error) {
- if !common.MemoryCacheEnabled {
- return GetChannelById(id, true)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
-
- c, ok := channelsIDM[id]
- if !ok {
- return nil, fmt.Errorf("渠道# %d,已不存在", id)
- }
- return c, nil
-}
-
-func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
- if !common.MemoryCacheEnabled {
- channel, err := GetChannelById(id, true)
- if err != nil {
- return nil, err
- }
- return &channel.ChannelInfo, nil
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
-
- c, ok := channelsIDM[id]
- if !ok {
- return nil, fmt.Errorf("渠道# %d,已不存在", id)
- }
- return &c.ChannelInfo, nil
-}
-
-func CacheUpdateChannelStatus(id int, status int) {
- if !common.MemoryCacheEnabled {
- return
- }
- channelSyncLock.Lock()
- defer channelSyncLock.Unlock()
- if channel, ok := channelsIDM[id]; ok {
- channel.Status = status
- }
- if status != common.ChannelStatusEnabled {
- // delete the channel from group2model2channels
- for group, model2channels := range group2model2channels {
- for model, channels := range model2channels {
- for i, channelId := range channels {
- if channelId == id {
- // remove the channel from the slice
- group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
- break
- }
- }
- }
- }
- }
-}
-
-func CacheUpdateChannel(channel *Channel) {
- if !common.MemoryCacheEnabled {
- return
- }
- channelSyncLock.Lock()
- defer channelSyncLock.Unlock()
- if channel == nil {
- return
- }
-
- println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
-
- println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
- channelsIDM[channel.Id] = channel
- println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
-}
diff --git a/new-api/model/log.go b/new-api/model/log.go
deleted file mode 100644
index f13bd853ab8fd725ae2988fbddc82f51fdcaab49..0000000000000000000000000000000000000000
--- a/new-api/model/log.go
+++ /dev/null
@@ -1,408 +0,0 @@
-package model
-
-import (
- "context"
- "fmt"
- "one-api/common"
- "one-api/logger"
- "one-api/types"
- "os"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
-)
-
-type Log struct {
- Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
- UserId int `json:"user_id" gorm:"index"`
- CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
- Type int `json:"type" gorm:"index:idx_created_at_type"`
- Content string `json:"content"`
- Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
- TokenName string `json:"token_name" gorm:"index;default:''"`
- ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
- Quota int `json:"quota" gorm:"default:0"`
- PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
- CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
- UseTime int `json:"use_time" gorm:"default:0"`
- IsStream bool `json:"is_stream"`
- ChannelId int `json:"channel" gorm:"index"`
- ChannelName string `json:"channel_name" gorm:"->"`
- TokenId int `json:"token_id" gorm:"default:0;index"`
- Group string `json:"group" gorm:"index"`
- Ip string `json:"ip" gorm:"index;default:''"`
- Other string `json:"other"`
-}
-
-const (
- LogTypeUnknown = iota
- LogTypeTopup
- LogTypeConsume
- LogTypeManage
- LogTypeSystem
- LogTypeError
-)
-
-func formatUserLogs(logs []*Log) {
- for i := range logs {
- logs[i].ChannelName = ""
- var otherMap map[string]interface{}
- otherMap, _ = common.StrToMap(logs[i].Other)
- if otherMap != nil {
- // delete admin
- delete(otherMap, "admin_info")
- }
- logs[i].Other = common.MapToJsonStr(otherMap)
- logs[i].Id = logs[i].Id % 1024
- }
-}
-
-func GetLogByKey(key string) (logs []*Log, err error) {
- if os.Getenv("LOG_SQL_DSN") != "" {
- var tk Token
- if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
- return nil, err
- }
- err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
- } else {
- err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
- }
- formatUserLogs(logs)
- return logs, err
-}
-
-func RecordLog(userId int, logType int, content string) {
- if logType == LogTypeConsume && !common.LogConsumeEnabled {
- return
- }
- username, _ := GetUsernameById(userId, false)
- log := &Log{
- UserId: userId,
- Username: username,
- CreatedAt: common.GetTimestamp(),
- Type: logType,
- Content: content,
- }
- err := LOG_DB.Create(log).Error
- if err != nil {
- common.SysLog("failed to record log: " + err.Error())
- }
-}
-
-func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
- isStream bool, group string, other map[string]interface{}) {
- logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
- username := c.GetString("username")
- otherStr := common.MapToJsonStr(other)
- // 判断是否需要记录 IP
- needRecordIp := false
- if settingMap, err := GetUserSetting(userId, false); err == nil {
- if settingMap.RecordIpLog {
- needRecordIp = true
- }
- }
- log := &Log{
- UserId: userId,
- Username: username,
- CreatedAt: common.GetTimestamp(),
- Type: LogTypeError,
- Content: content,
- PromptTokens: 0,
- CompletionTokens: 0,
- TokenName: tokenName,
- ModelName: modelName,
- Quota: 0,
- ChannelId: channelId,
- TokenId: tokenId,
- UseTime: useTimeSeconds,
- IsStream: isStream,
- Group: group,
- Ip: func() string {
- if needRecordIp {
- return c.ClientIP()
- }
- return ""
- }(),
- Other: otherStr,
- }
- err := LOG_DB.Create(log).Error
- if err != nil {
- logger.LogError(c, "failed to record log: "+err.Error())
- }
-}
-
-type RecordConsumeLogParams struct {
- ChannelId int `json:"channel_id"`
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- ModelName string `json:"model_name"`
- TokenName string `json:"token_name"`
- Quota int `json:"quota"`
- Content string `json:"content"`
- TokenId int `json:"token_id"`
- UseTimeSeconds int `json:"use_time_seconds"`
- IsStream bool `json:"is_stream"`
- Group string `json:"group"`
- Other map[string]interface{} `json:"other"`
-}
-
-func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
- if !common.LogConsumeEnabled {
- return
- }
- logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
- username := c.GetString("username")
- otherStr := common.MapToJsonStr(params.Other)
- // 判断是否需要记录 IP
- needRecordIp := false
- if settingMap, err := GetUserSetting(userId, false); err == nil {
- if settingMap.RecordIpLog {
- needRecordIp = true
- }
- }
- log := &Log{
- UserId: userId,
- Username: username,
- CreatedAt: common.GetTimestamp(),
- Type: LogTypeConsume,
- Content: params.Content,
- PromptTokens: params.PromptTokens,
- CompletionTokens: params.CompletionTokens,
- TokenName: params.TokenName,
- ModelName: params.ModelName,
- Quota: params.Quota,
- ChannelId: params.ChannelId,
- TokenId: params.TokenId,
- UseTime: params.UseTimeSeconds,
- IsStream: params.IsStream,
- Group: params.Group,
- Ip: func() string {
- if needRecordIp {
- return c.ClientIP()
- }
- return ""
- }(),
- Other: otherStr,
- }
- err := LOG_DB.Create(log).Error
- if err != nil {
- logger.LogError(c, "failed to record log: "+err.Error())
- }
- if common.DataExportEnabled {
- gopool.Go(func() {
- LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
- })
- }
-}
-
-func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
- var tx *gorm.DB
- if logType == LogTypeUnknown {
- tx = LOG_DB
- } else {
- tx = LOG_DB.Where("logs.type = ?", logType)
- }
-
- if modelName != "" {
- tx = tx.Where("logs.model_name like ?", modelName)
- }
- if username != "" {
- tx = tx.Where("logs.username = ?", username)
- }
- if tokenName != "" {
- tx = tx.Where("logs.token_name = ?", tokenName)
- }
- if startTimestamp != 0 {
- tx = tx.Where("logs.created_at >= ?", startTimestamp)
- }
- if endTimestamp != 0 {
- tx = tx.Where("logs.created_at <= ?", endTimestamp)
- }
- if channel != 0 {
- tx = tx.Where("logs.channel_id = ?", channel)
- }
- if group != "" {
- tx = tx.Where("logs."+logGroupCol+" = ?", group)
- }
- err = tx.Model(&Log{}).Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
- if err != nil {
- return nil, 0, err
- }
-
- channelIds := types.NewSet[int]()
- for _, log := range logs {
- if log.ChannelId != 0 {
- channelIds.Add(log.ChannelId)
- }
- }
-
- if channelIds.Len() > 0 {
- var channels []struct {
- Id int `gorm:"column:id"`
- Name string `gorm:"column:name"`
- }
- if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
- return logs, total, err
- }
- channelMap := make(map[int]string, len(channels))
- for _, channel := range channels {
- channelMap[channel.Id] = channel.Name
- }
- for i := range logs {
- logs[i].ChannelName = channelMap[logs[i].ChannelId]
- }
- }
-
- return logs, total, err
-}
-
-func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
- var tx *gorm.DB
- if logType == LogTypeUnknown {
- tx = LOG_DB.Where("logs.user_id = ?", userId)
- } else {
- tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
- }
-
- if modelName != "" {
- tx = tx.Where("logs.model_name like ?", modelName)
- }
- if tokenName != "" {
- tx = tx.Where("logs.token_name = ?", tokenName)
- }
- if startTimestamp != 0 {
- tx = tx.Where("logs.created_at >= ?", startTimestamp)
- }
- if endTimestamp != 0 {
- tx = tx.Where("logs.created_at <= ?", endTimestamp)
- }
- if group != "" {
- tx = tx.Where("logs."+logGroupCol+" = ?", group)
- }
- err = tx.Model(&Log{}).Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
- if err != nil {
- return nil, 0, err
- }
-
- formatUserLogs(logs)
- return logs, total, err
-}
-
-func SearchAllLogs(keyword string) (logs []*Log, err error) {
- err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
- return logs, err
-}
-
-func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
- err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
- formatUserLogs(logs)
- return logs, err
-}
-
-type Stat struct {
- Quota int `json:"quota"`
- Rpm int `json:"rpm"`
- Tpm int `json:"tpm"`
-}
-
-func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
- tx := LOG_DB.Table("logs").Select("sum(quota) quota")
-
- // 为rpm和tpm创建单独的查询
- rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
-
- if username != "" {
- tx = tx.Where("username = ?", username)
- rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
- }
- if tokenName != "" {
- tx = tx.Where("token_name = ?", tokenName)
- rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
- }
- if startTimestamp != 0 {
- tx = tx.Where("created_at >= ?", startTimestamp)
- }
- if endTimestamp != 0 {
- tx = tx.Where("created_at <= ?", endTimestamp)
- }
- if modelName != "" {
- tx = tx.Where("model_name like ?", modelName)
- rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
- }
- if channel != 0 {
- tx = tx.Where("channel_id = ?", channel)
- rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
- }
- if group != "" {
- tx = tx.Where(logGroupCol+" = ?", group)
- rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
- }
-
- tx = tx.Where("type = ?", LogTypeConsume)
- rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
-
- // 只统计最近60秒的rpm和tpm
- rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
-
- // 执行查询
- tx.Scan(&stat)
- rpmTpmQuery.Scan(&stat)
-
- return stat
-}
-
-func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
- tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
- if username != "" {
- tx = tx.Where("username = ?", username)
- }
- if tokenName != "" {
- tx = tx.Where("token_name = ?", tokenName)
- }
- if startTimestamp != 0 {
- tx = tx.Where("created_at >= ?", startTimestamp)
- }
- if endTimestamp != 0 {
- tx = tx.Where("created_at <= ?", endTimestamp)
- }
- if modelName != "" {
- tx = tx.Where("model_name = ?", modelName)
- }
- tx.Where("type = ?", LogTypeConsume).Scan(&token)
- return token
-}
-
-func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
- var total int64 = 0
-
- for {
- if nil != ctx.Err() {
- return total, ctx.Err()
- }
-
- result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
- if nil != result.Error {
- return total, result.Error
- }
-
- total += result.RowsAffected
-
- if result.RowsAffected < int64(limit) {
- break
- }
- }
-
- return total, nil
-}
diff --git a/new-api/model/main.go b/new-api/model/main.go
deleted file mode 100644
index a6f46a7fdf622f60452cec28b725485d52316e8c..0000000000000000000000000000000000000000
--- a/new-api/model/main.go
+++ /dev/null
@@ -1,477 +0,0 @@
-package model
-
-import (
- "fmt"
- "log"
- "one-api/common"
- "one-api/constant"
- "os"
- "strings"
- "sync"
- "time"
-
- "github.com/glebarez/sqlite"
- "gorm.io/driver/mysql"
- "gorm.io/driver/postgres"
- "gorm.io/gorm"
-)
-
-var commonGroupCol string
-var commonKeyCol string
-var commonTrueVal string
-var commonFalseVal string
-
-var logKeyCol string
-var logGroupCol string
-
-func initCol() {
- // init common column names
- if common.UsingPostgreSQL {
- commonGroupCol = `"group"`
- commonKeyCol = `"key"`
- commonTrueVal = "true"
- commonFalseVal = "false"
- } else {
- commonGroupCol = "`group`"
- commonKeyCol = "`key`"
- commonTrueVal = "1"
- commonFalseVal = "0"
- }
- if os.Getenv("LOG_SQL_DSN") != "" {
- switch common.LogSqlType {
- case common.DatabaseTypePostgreSQL:
- logGroupCol = `"group"`
- logKeyCol = `"key"`
- default:
- logGroupCol = commonGroupCol
- logKeyCol = commonKeyCol
- }
- } else {
- // LOG_SQL_DSN 为空时,日志数据库与主数据库相同
- if common.UsingPostgreSQL {
- logGroupCol = `"group"`
- logKeyCol = `"key"`
- } else {
- logGroupCol = commonGroupCol
- logKeyCol = commonKeyCol
- }
- }
- // log sql type and database type
- //common.SysLog("Using Log SQL Type: " + common.LogSqlType)
-}
-
-var DB *gorm.DB
-
-var LOG_DB *gorm.DB
-
-func createRootAccountIfNeed() error {
- var user User
- //if user.Status != common.UserStatusEnabled {
- if err := DB.First(&user).Error; err != nil {
- common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
- hashedPassword, err := common.Password2Hash("123456")
- if err != nil {
- return err
- }
- rootUser := User{
- Username: "root",
- Password: hashedPassword,
- Role: common.RoleRootUser,
- Status: common.UserStatusEnabled,
- DisplayName: "Root User",
- AccessToken: nil,
- Quota: 100000000,
- }
- DB.Create(&rootUser)
- }
- return nil
-}
-
-func CheckSetup() {
- setup := GetSetup()
- if setup == nil {
- // No setup record exists, check if we have a root user
- if RootUserExists() {
- common.SysLog("system is not initialized, but root user exists")
- // Create setup record
- newSetup := Setup{
- Version: common.Version,
- InitializedAt: time.Now().Unix(),
- }
- err := DB.Create(&newSetup).Error
- if err != nil {
- common.SysLog("failed to create setup record: " + err.Error())
- }
- constant.Setup = true
- } else {
- common.SysLog("system is not initialized and no root user exists")
- constant.Setup = false
- }
- } else {
- // Setup record exists, system is initialized
- common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
- constant.Setup = true
- }
-}
-
-func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
- defer func() {
- initCol()
- }()
- dsn := os.Getenv(envName)
- if dsn != "" {
- if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
- // Use PostgreSQL
- common.SysLog("using PostgreSQL as database")
- if !isLog {
- common.UsingPostgreSQL = true
- } else {
- common.LogSqlType = common.DatabaseTypePostgreSQL
- }
- return gorm.Open(postgres.New(postgres.Config{
- DSN: dsn,
- PreferSimpleProtocol: true, // disables implicit prepared statement usage
- }), &gorm.Config{
- PrepareStmt: true, // precompile SQL
- })
- }
- if strings.HasPrefix(dsn, "local") {
- common.SysLog("SQL_DSN not set, using SQLite as database")
- if !isLog {
- common.UsingSQLite = true
- } else {
- common.LogSqlType = common.DatabaseTypeSQLite
- }
- return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
- PrepareStmt: true, // precompile SQL
- })
- }
- // Use MySQL
- common.SysLog("using MySQL as database")
- // check parseTime
- if !strings.Contains(dsn, "parseTime") {
- if strings.Contains(dsn, "?") {
- dsn += "&parseTime=true"
- } else {
- dsn += "?parseTime=true"
- }
- }
- if !isLog {
- common.UsingMySQL = true
- } else {
- common.LogSqlType = common.DatabaseTypeMySQL
- }
- return gorm.Open(mysql.Open(dsn), &gorm.Config{
- PrepareStmt: true, // precompile SQL
- })
- }
- // Use SQLite
- common.SysLog("SQL_DSN not set, using SQLite as database")
- common.UsingSQLite = true
- return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
- PrepareStmt: true, // precompile SQL
- })
-}
-
-func InitDB() (err error) {
- db, err := chooseDB("SQL_DSN", false)
- if err == nil {
- if common.DebugEnabled {
- db = db.Debug()
- }
- DB = db
- // MySQL charset/collation startup check: ensure Chinese-capable charset
- if common.UsingMySQL {
- if err := checkMySQLChineseSupport(DB); err != nil {
- panic(err)
- }
- }
- sqlDB, err := DB.DB()
- if err != nil {
- return err
- }
- sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
- sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
- sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
-
- if !common.IsMasterNode {
- return nil
- }
- if common.UsingMySQL {
- //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
- }
- common.SysLog("database migration started")
- err = migrateDB()
- return err
- } else {
- common.FatalLog(err)
- }
- return err
-}
-
-func InitLogDB() (err error) {
- if os.Getenv("LOG_SQL_DSN") == "" {
- LOG_DB = DB
- return
- }
- db, err := chooseDB("LOG_SQL_DSN", true)
- if err == nil {
- if common.DebugEnabled {
- db = db.Debug()
- }
- LOG_DB = db
- // If log DB is MySQL, also ensure Chinese-capable charset
- if common.LogSqlType == common.DatabaseTypeMySQL {
- if err := checkMySQLChineseSupport(LOG_DB); err != nil {
- panic(err)
- }
- }
- sqlDB, err := LOG_DB.DB()
- if err != nil {
- return err
- }
- sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
- sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
- sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
-
- if !common.IsMasterNode {
- return nil
- }
- common.SysLog("database migration started")
- err = migrateLOGDB()
- return err
- } else {
- common.FatalLog(err)
- }
- return err
-}
-
-func migrateDB() error {
- err := DB.AutoMigrate(
- &Channel{},
- &Token{},
- &User{},
- &PasskeyCredential{},
- &Option{},
- &Redemption{},
- &Ability{},
- &Log{},
- &Midjourney{},
- &TopUp{},
- &QuotaData{},
- &Task{},
- &Model{},
- &Vendor{},
- &PrefillGroup{},
- &Setup{},
- &TwoFA{},
- &TwoFABackupCode{},
- )
- if err != nil {
- return err
- }
- return nil
-}
-
-func migrateDBFast() error {
-
- var wg sync.WaitGroup
-
- migrations := []struct {
- model interface{}
- name string
- }{
- {&Channel{}, "Channel"},
- {&Token{}, "Token"},
- {&User{}, "User"},
- {&PasskeyCredential{}, "PasskeyCredential"},
- {&Option{}, "Option"},
- {&Redemption{}, "Redemption"},
- {&Ability{}, "Ability"},
- {&Log{}, "Log"},
- {&Midjourney{}, "Midjourney"},
- {&TopUp{}, "TopUp"},
- {&QuotaData{}, "QuotaData"},
- {&Task{}, "Task"},
- {&Model{}, "Model"},
- {&Vendor{}, "Vendor"},
- {&PrefillGroup{}, "PrefillGroup"},
- {&Setup{}, "Setup"},
- {&TwoFA{}, "TwoFA"},
- {&TwoFABackupCode{}, "TwoFABackupCode"},
- }
- // 动态计算migration数量,确保errChan缓冲区足够大
- errChan := make(chan error, len(migrations))
-
- for _, m := range migrations {
- wg.Add(1)
- go func(model interface{}, name string) {
- defer wg.Done()
- if err := DB.AutoMigrate(model); err != nil {
- errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
- }
- }(m.model, m.name)
- }
-
- // Wait for all migrations to complete
- wg.Wait()
- close(errChan)
-
- // Check for any errors
- for err := range errChan {
- if err != nil {
- return err
- }
- }
- common.SysLog("database migrated")
- return nil
-}
-
-func migrateLOGDB() error {
- var err error
- if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
- return err
- }
- return nil
-}
-
-func closeDB(db *gorm.DB) error {
- sqlDB, err := db.DB()
- if err != nil {
- return err
- }
- err = sqlDB.Close()
- return err
-}
-
-func CloseDB() error {
- if LOG_DB != DB {
- err := closeDB(LOG_DB)
- if err != nil {
- return err
- }
- }
- return closeDB(DB)
-}
-
-// checkMySQLChineseSupport ensures the MySQL connection and current schema
-// default charset/collation can store Chinese characters. It allows common
-// Chinese-capable charsets (utf8mb4, utf8, gbk, big5, gb18030) and panics otherwise.
-func checkMySQLChineseSupport(db *gorm.DB) error {
- // 仅检测:当前库默认字符集/排序规则 + 各表的排序规则(隐含字符集)
-
- // Read current schema defaults
- var schemaCharset, schemaCollation string
- err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
- if err != nil {
- return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
- }
-
- toLower := func(s string) string { return strings.ToLower(s) }
- // Allowed charsets that can store Chinese text
- allowedCharsets := map[string]string{
- "utf8mb4": "utf8mb4_",
- "utf8": "utf8_",
- "gbk": "gbk_",
- "big5": "big5_",
- "gb18030": "gb18030_",
- }
- isChineseCapable := func(cs, cl string) bool {
- csLower := toLower(cs)
- clLower := toLower(cl)
- if prefix, ok := allowedCharsets[csLower]; ok {
- if clLower == "" {
- return true
- }
- return strings.HasPrefix(clLower, prefix)
- }
- // 如果仅提供了排序规则,尝试按排序规则前缀判断
- for _, prefix := range allowedCharsets {
- if strings.HasPrefix(clLower, prefix) {
- return true
- }
- }
- return false
- }
-
- // 1) 当前库默认值必须支持中文
- if !isChineseCapable(schemaCharset, schemaCollation) {
- return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
- schemaCharset, schemaCollation, schemaCharset, schemaCollation)
- }
-
- // 2) 所有物理表的排序规则(隐含字符集)必须支持中文
- type tableInfo struct {
- Name string
- Collation *string
- }
- var tables []tableInfo
- if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
- return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
- }
-
- var badTables []string
- for _, t := range tables {
- // NULL 或空表示继承库默认设置,已在上面校验库默认,视为通过
- if t.Collation == nil || *t.Collation == "" {
- continue
- }
- cl := *t.Collation
- // 仅凭排序规则判断是否中文可用
- ok := false
- lower := strings.ToLower(cl)
- for _, prefix := range allowedCharsets {
- if strings.HasPrefix(lower, prefix) {
- ok = true
- break
- }
- }
- if !ok {
- badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl))
- }
- }
-
- if len(badTables) > 0 {
- // 限制输出数量以避免日志过长
- maxShow := 20
- shown := badTables
- if len(shown) > maxShow {
- shown = shown[:maxShow]
- }
- return fmt.Errorf(
- "存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
- maxShow, shown, maxShow, shown,
- )
- }
- return nil
-}
-
-var (
- lastPingTime time.Time
- pingMutex sync.Mutex
-)
-
-func PingDB() error {
- pingMutex.Lock()
- defer pingMutex.Unlock()
-
- if time.Since(lastPingTime) < time.Second*10 {
- return nil
- }
-
- sqlDB, err := DB.DB()
- if err != nil {
- log.Printf("Error getting sql.DB from GORM: %v", err)
- return err
- }
-
- err = sqlDB.Ping()
- if err != nil {
- log.Printf("Error pinging DB: %v", err)
- return err
- }
-
- lastPingTime = time.Now()
- common.SysLog("Database pinged successfully")
- return nil
-}
diff --git a/new-api/model/midjourney.go b/new-api/model/midjourney.go
deleted file mode 100644
index 363a9162cee6ba7d489850f3a8bcfca3646ff5ef..0000000000000000000000000000000000000000
--- a/new-api/model/midjourney.go
+++ /dev/null
@@ -1,207 +0,0 @@
-package model
-
-type Midjourney struct {
- Id int `json:"id"`
- Code int `json:"code"`
- UserId int `json:"user_id" gorm:"index"`
- Action string `json:"action" gorm:"type:varchar(40);index"`
- MjId string `json:"mj_id" gorm:"index"`
- Prompt string `json:"prompt"`
- PromptEn string `json:"prompt_en"`
- Description string `json:"description"`
- State string `json:"state"`
- SubmitTime int64 `json:"submit_time" gorm:"index"`
- StartTime int64 `json:"start_time" gorm:"index"`
- FinishTime int64 `json:"finish_time" gorm:"index"`
- ImageUrl string `json:"image_url"`
- VideoUrl string `json:"video_url"`
- VideoUrls string `json:"video_urls"`
- Status string `json:"status" gorm:"type:varchar(20);index"`
- Progress string `json:"progress" gorm:"type:varchar(30);index"`
- FailReason string `json:"fail_reason"`
- ChannelId int `json:"channel_id"`
- Quota int `json:"quota"`
- Buttons string `json:"buttons"`
- Properties string `json:"properties"`
-}
-
-// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
-type TaskQueryParams struct {
- ChannelID string
- MjID string
- StartTimestamp string
- EndTimestamp string
-}
-
-func GetAllUserTask(userId int, startIdx int, num int, queryParams TaskQueryParams) []*Midjourney {
- var tasks []*Midjourney
- var err error
-
- // 初始化查询构建器
- query := DB.Where("user_id = ?", userId)
-
- if queryParams.MjID != "" {
- query = query.Where("mj_id = ?", queryParams.MjID)
- }
- if queryParams.StartTimestamp != "" {
- // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != "" {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
-
- // 获取数据
- err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
- if err != nil {
- return nil
- }
-
- return tasks
-}
-
-func GetAllTasks(startIdx int, num int, queryParams TaskQueryParams) []*Midjourney {
- var tasks []*Midjourney
- var err error
-
- // 初始化查询构建器
- query := DB
-
- // 添加过滤条件
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.MjID != "" {
- query = query.Where("mj_id = ?", queryParams.MjID)
- }
- if queryParams.StartTimestamp != "" {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != "" {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
-
- // 获取数据
- err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
- if err != nil {
- return nil
- }
-
- return tasks
-}
-
-func GetAllUnFinishTasks() []*Midjourney {
- var tasks []*Midjourney
- var err error
- // get all tasks progress is not 100%
- err = DB.Where("progress != ?", "100%").Find(&tasks).Error
- if err != nil {
- return nil
- }
- return tasks
-}
-
-func GetByOnlyMJId(mjId string) *Midjourney {
- var mj *Midjourney
- var err error
- err = DB.Where("mj_id = ?", mjId).First(&mj).Error
- if err != nil {
- return nil
- }
- return mj
-}
-
-func GetByMJId(userId int, mjId string) *Midjourney {
- var mj *Midjourney
- var err error
- err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
- if err != nil {
- return nil
- }
- return mj
-}
-
-func GetByMJIds(userId int, mjIds []string) []*Midjourney {
- var mj []*Midjourney
- var err error
- err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
- if err != nil {
- return nil
- }
- return mj
-}
-
-func GetMjByuId(id int) *Midjourney {
- var mj *Midjourney
- var err error
- err = DB.Where("id = ?", id).First(&mj).Error
- if err != nil {
- return nil
- }
- return mj
-}
-
-func UpdateProgress(id int, progress string) error {
- return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
-func (midjourney *Midjourney) Insert() error {
- var err error
- err = DB.Create(midjourney).Error
- return err
-}
-
-func (midjourney *Midjourney) Update() error {
- var err error
- err = DB.Save(midjourney).Error
- return err
-}
-
-func MjBulkUpdate(mjIds []string, params map[string]any) error {
- return DB.Model(&Midjourney{}).
- Where("mj_id in (?)", mjIds).
- Updates(params).Error
-}
-
-func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
- return DB.Model(&Midjourney{}).
- Where("id in (?)", taskIDs).
- Updates(params).Error
-}
-
-// CountAllTasks returns total midjourney tasks for admin query
-func CountAllTasks(queryParams TaskQueryParams) int64 {
- var total int64
- query := DB.Model(&Midjourney{})
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.MjID != "" {
- query = query.Where("mj_id = ?", queryParams.MjID)
- }
- if queryParams.StartTimestamp != "" {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != "" {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- _ = query.Count(&total).Error
- return total
-}
-
-// CountAllUserTask returns total midjourney tasks for user
-func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
- var total int64
- query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
- if queryParams.MjID != "" {
- query = query.Where("mj_id = ?", queryParams.MjID)
- }
- if queryParams.StartTimestamp != "" {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != "" {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- _ = query.Count(&total).Error
- return total
-}
diff --git a/new-api/model/missing_models.go b/new-api/model/missing_models.go
deleted file mode 100644
index 2b8ac4cdc792afe09ce0759bde79da9997f76c23..0000000000000000000000000000000000000000
--- a/new-api/model/missing_models.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package model
-
-// GetMissingModels returns model names that are referenced in the system
-func GetMissingModels() ([]string, error) {
- // 1. 获取所有已启用模型(去重)
- models := GetEnabledModels()
- if len(models) == 0 {
- return []string{}, nil
- }
-
- // 2. 查询已有的元数据模型名
- var existing []string
- if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
- return nil, err
- }
-
- existingSet := make(map[string]struct{}, len(existing))
- for _, e := range existing {
- existingSet[e] = struct{}{}
- }
-
- // 3. 收集缺失模型
- var missing []string
- for _, name := range models {
- if _, ok := existingSet[name]; !ok {
- missing = append(missing, name)
- }
- }
- return missing, nil
-}
diff --git a/new-api/model/model_extra.go b/new-api/model/model_extra.go
deleted file mode 100644
index ce55e94ebabcc737f865733e96531e1a4fa2f688..0000000000000000000000000000000000000000
--- a/new-api/model/model_extra.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package model
-
-func GetModelEnableGroups(modelName string) []string {
- // 确保缓存最新
- GetPricing()
-
- if modelName == "" {
- return make([]string, 0)
- }
-
- modelEnableGroupsLock.RLock()
- groups, ok := modelEnableGroups[modelName]
- modelEnableGroupsLock.RUnlock()
- if !ok {
- return make([]string, 0)
- }
- return groups
-}
-
-// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存)
-func GetModelQuotaTypes(modelName string) []int {
- GetPricing()
-
- modelEnableGroupsLock.RLock()
- quota, ok := modelQuotaTypeMap[modelName]
- modelEnableGroupsLock.RUnlock()
- if !ok {
- return []int{}
- }
- return []int{quota}
-}
diff --git a/new-api/model/model_meta.go b/new-api/model/model_meta.go
deleted file mode 100644
index c80cfd40cea614ba2a85674450c47bb8dc01325a..0000000000000000000000000000000000000000
--- a/new-api/model/model_meta.go
+++ /dev/null
@@ -1,147 +0,0 @@
-package model
-
-import (
- "one-api/common"
- "strconv"
-
- "gorm.io/gorm"
-)
-
-const (
- NameRuleExact = iota
- NameRulePrefix
- NameRuleContains
- NameRuleSuffix
-)
-
-type BoundChannel struct {
- Name string `json:"name"`
- Type int `json:"type"`
-}
-
-type Model struct {
- Id int `json:"id"`
- ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
- Description string `json:"description,omitempty" gorm:"type:text"`
- Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
- Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
- VendorID int `json:"vendor_id,omitempty" gorm:"index"`
- Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
- Status int `json:"status" gorm:"default:1"`
- SyncOfficial int `json:"sync_official" gorm:"default:1"`
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
-
- BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
- EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
- QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
- NameRule int `json:"name_rule" gorm:"default:0"`
-
- MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
- MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
-}
-
-func (mi *Model) Insert() error {
- now := common.GetTimestamp()
- mi.CreatedTime = now
- mi.UpdatedTime = now
- return DB.Create(mi).Error
-}
-
-func IsModelNameDuplicated(id int, name string) (bool, error) {
- if name == "" {
- return false, nil
- }
- var cnt int64
- err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
- return cnt > 0, err
-}
-
-func (mi *Model) Update() error {
- mi.UpdatedTime = common.GetTimestamp()
- return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
- Model(&Model{}).
- Where("id = ?", mi.Id).
- Omit("created_time").
- Select("*").
- Updates(mi).Error
-}
-
-func (mi *Model) Delete() error {
- return DB.Delete(mi).Error
-}
-
-func GetVendorModelCounts() (map[int64]int64, error) {
- var stats []struct {
- VendorID int64
- Count int64
- }
- if err := DB.Model(&Model{}).
- Select("vendor_id as vendor_id, count(*) as count").
- Group("vendor_id").
- Scan(&stats).Error; err != nil {
- return nil, err
- }
- m := make(map[int64]int64, len(stats))
- for _, s := range stats {
- m[s.VendorID] = s.Count
- }
- return m, nil
-}
-
-func GetAllModels(offset int, limit int) ([]*Model, error) {
- var models []*Model
- err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
- return models, err
-}
-
-func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
- result := make(map[string][]BoundChannel)
- if len(modelNames) == 0 {
- return result, nil
- }
- type row struct {
- Model string
- Name string
- Type int
- }
- var rows []row
- err := DB.Table("channels").
- Select("abilities.model as model, channels.name as name, channels.type as type").
- Joins("JOIN abilities ON abilities.channel_id = channels.id").
- Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
- Distinct().
- Scan(&rows).Error
- if err != nil {
- return nil, err
- }
- for _, r := range rows {
- result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
- }
- return result, nil
-}
-
-func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
- var models []*Model
- db := DB.Model(&Model{})
- if keyword != "" {
- like := "%" + keyword + "%"
- db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
- }
- if vendor != "" {
- if vid, err := strconv.Atoi(vendor); err == nil {
- db = db.Where("models.vendor_id = ?", vid)
- } else {
- db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
- }
- }
- var total int64
- if err := db.Count(&total).Error; err != nil {
- return nil, 0, err
- }
- if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
- return nil, 0, err
- }
- return models, total, nil
-}
diff --git a/new-api/model/option.go b/new-api/model/option.go
deleted file mode 100644
index 6fb59627a196a3f9b806dd5598b74e51607c9811..0000000000000000000000000000000000000000
--- a/new-api/model/option.go
+++ /dev/null
@@ -1,457 +0,0 @@
-package model
-
-import (
- "one-api/common"
- "one-api/setting"
- "one-api/setting/config"
- "one-api/setting/operation_setting"
- "one-api/setting/ratio_setting"
- "one-api/setting/system_setting"
- "strconv"
- "strings"
- "time"
-)
-
-type Option struct {
- Key string `json:"key" gorm:"primaryKey"`
- Value string `json:"value"`
-}
-
-func AllOption() ([]*Option, error) {
- var options []*Option
- var err error
- err = DB.Find(&options).Error
- return options, err
-}
-
-func InitOptionMap() {
- common.OptionMapRWMutex.Lock()
- common.OptionMap = make(map[string]string)
-
- // 添加原有的系统配置
- common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
- common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
- common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
- common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
- common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
- common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
- common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
- common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
- common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled)
- common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
- common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
- common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
- common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
- common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
- common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
- common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
- common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
- common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
- common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
- common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
- common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
- common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
- common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
- common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
- common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
- common.OptionMap["SMTPServer"] = ""
- common.OptionMap["SMTPFrom"] = ""
- common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
- common.OptionMap["SMTPAccount"] = ""
- common.OptionMap["SMTPToken"] = ""
- common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
- common.OptionMap["Notice"] = ""
- common.OptionMap["About"] = ""
- common.OptionMap["HomePageContent"] = ""
- common.OptionMap["Footer"] = common.Footer
- common.OptionMap["SystemName"] = common.SystemName
- common.OptionMap["Logo"] = common.Logo
- common.OptionMap["ServerAddress"] = ""
- common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
- common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
- common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
- common.OptionMap["PayAddress"] = ""
- common.OptionMap["CustomCallbackAddress"] = ""
- common.OptionMap["EpayId"] = ""
- common.OptionMap["EpayKey"] = ""
- common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
- common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
- common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
- common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
- common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
- common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
- common.OptionMap["StripePriceId"] = setting.StripePriceId
- common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
- common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
- common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
- common.OptionMap["Chats"] = setting.Chats2JsonString()
- common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
- common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
- common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
- common.OptionMap["GitHubClientId"] = ""
- common.OptionMap["GitHubClientSecret"] = ""
- common.OptionMap["TelegramBotToken"] = ""
- common.OptionMap["TelegramBotName"] = ""
- common.OptionMap["WeChatServerAddress"] = ""
- common.OptionMap["WeChatServerToken"] = ""
- common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
- common.OptionMap["TurnstileSiteKey"] = ""
- common.OptionMap["TurnstileSecretKey"] = ""
- common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
- common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
- common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
- common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
- common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
- common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
- common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
- common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
- common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
- common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
- common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
- common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
- common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
- common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
- common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
- common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
- common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
- common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
- common.OptionMap["TopUpLink"] = common.TopUpLink
- //common.OptionMap["ChatLink"] = common.ChatLink
- //common.OptionMap["ChatLink2"] = common.ChatLink2
- common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
- common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
- common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
- common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
- common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
- common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
- common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
- common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
- common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
- common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
- common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
- common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
- common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
- common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
- common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
- common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
- common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
- common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
- common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
- common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
-
- // 自动添加所有注册的模型配置
- modelConfigs := config.GlobalConfig.ExportAllConfigs()
- for k, v := range modelConfigs {
- common.OptionMap[k] = v
- }
-
- common.OptionMapRWMutex.Unlock()
- loadOptionsFromDatabase()
-}
-
-func loadOptionsFromDatabase() {
- options, _ := AllOption()
- for _, option := range options {
- err := updateOptionMap(option.Key, option.Value)
- if err != nil {
- common.SysLog("failed to update option map: " + err.Error())
- }
- }
-}
-
-func SyncOptions(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing options from database")
- loadOptionsFromDatabase()
- }
-}
-
-func UpdateOption(key string, value string) error {
- // Save to database first
- option := Option{
- Key: key,
- }
- // https://gorm.io/docs/update.html#Save-All-Fields
- DB.FirstOrCreate(&option, Option{Key: key})
- option.Value = value
- // Save is a combination function.
- // If save value does not contain primary key, it will execute Create,
- // otherwise it will execute Update (with all fields).
- DB.Save(&option)
- // Update OptionMap
- return updateOptionMap(key, value)
-}
-
-func updateOptionMap(key string, value string) (err error) {
- common.OptionMapRWMutex.Lock()
- defer common.OptionMapRWMutex.Unlock()
- common.OptionMap[key] = value
-
- // 检查是否是模型配置 - 使用更规范的方式处理
- if handleConfigUpdate(key, value) {
- return nil // 已由配置系统处理
- }
-
- // 处理传统配置项...
- if strings.HasSuffix(key, "Permission") {
- intValue, _ := strconv.Atoi(value)
- switch key {
- case "FileUploadPermission":
- common.FileUploadPermission = intValue
- case "FileDownloadPermission":
- common.FileDownloadPermission = intValue
- case "ImageUploadPermission":
- common.ImageUploadPermission = intValue
- case "ImageDownloadPermission":
- common.ImageDownloadPermission = intValue
- }
- }
- if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
- boolValue := value == "true"
- switch key {
- case "PasswordRegisterEnabled":
- common.PasswordRegisterEnabled = boolValue
- case "PasswordLoginEnabled":
- common.PasswordLoginEnabled = boolValue
- case "EmailVerificationEnabled":
- common.EmailVerificationEnabled = boolValue
- case "GitHubOAuthEnabled":
- common.GitHubOAuthEnabled = boolValue
- case "LinuxDOOAuthEnabled":
- common.LinuxDOOAuthEnabled = boolValue
- case "WeChatAuthEnabled":
- common.WeChatAuthEnabled = boolValue
- case "TelegramOAuthEnabled":
- common.TelegramOAuthEnabled = boolValue
- case "TurnstileCheckEnabled":
- common.TurnstileCheckEnabled = boolValue
- case "RegisterEnabled":
- common.RegisterEnabled = boolValue
- case "EmailDomainRestrictionEnabled":
- common.EmailDomainRestrictionEnabled = boolValue
- case "EmailAliasRestrictionEnabled":
- common.EmailAliasRestrictionEnabled = boolValue
- case "AutomaticDisableChannelEnabled":
- common.AutomaticDisableChannelEnabled = boolValue
- case "AutomaticEnableChannelEnabled":
- common.AutomaticEnableChannelEnabled = boolValue
- case "LogConsumeEnabled":
- common.LogConsumeEnabled = boolValue
- case "DisplayInCurrencyEnabled":
- common.DisplayInCurrencyEnabled = boolValue
- case "DisplayTokenStatEnabled":
- common.DisplayTokenStatEnabled = boolValue
- case "DrawingEnabled":
- common.DrawingEnabled = boolValue
- case "TaskEnabled":
- common.TaskEnabled = boolValue
- case "DataExportEnabled":
- common.DataExportEnabled = boolValue
- case "DefaultCollapseSidebar":
- common.DefaultCollapseSidebar = boolValue
- case "MjNotifyEnabled":
- setting.MjNotifyEnabled = boolValue
- case "MjAccountFilterEnabled":
- setting.MjAccountFilterEnabled = boolValue
- case "MjModeClearEnabled":
- setting.MjModeClearEnabled = boolValue
- case "MjForwardUrlEnabled":
- setting.MjForwardUrlEnabled = boolValue
- case "MjActionCheckSuccessEnabled":
- setting.MjActionCheckSuccessEnabled = boolValue
- case "CheckSensitiveEnabled":
- setting.CheckSensitiveEnabled = boolValue
- case "DemoSiteEnabled":
- operation_setting.DemoSiteEnabled = boolValue
- case "SelfUseModeEnabled":
- operation_setting.SelfUseModeEnabled = boolValue
- case "CheckSensitiveOnPromptEnabled":
- setting.CheckSensitiveOnPromptEnabled = boolValue
- case "ModelRequestRateLimitEnabled":
- setting.ModelRequestRateLimitEnabled = boolValue
- case "StopOnSensitiveEnabled":
- setting.StopOnSensitiveEnabled = boolValue
- case "SMTPSSLEnabled":
- common.SMTPSSLEnabled = boolValue
- case "WorkerAllowHttpImageRequestEnabled":
- system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
- case "DefaultUseAutoGroup":
- setting.DefaultUseAutoGroup = boolValue
- case "ExposeRatioEnabled":
- ratio_setting.SetExposeRatioEnabled(boolValue)
- }
- }
- switch key {
- case "EmailDomainWhitelist":
- common.EmailDomainWhitelist = strings.Split(value, ",")
- case "SMTPServer":
- common.SMTPServer = value
- case "SMTPPort":
- intValue, _ := strconv.Atoi(value)
- common.SMTPPort = intValue
- case "SMTPAccount":
- common.SMTPAccount = value
- case "SMTPFrom":
- common.SMTPFrom = value
- case "SMTPToken":
- common.SMTPToken = value
- case "ServerAddress":
- system_setting.ServerAddress = value
- case "WorkerUrl":
- system_setting.WorkerUrl = value
- case "WorkerValidKey":
- system_setting.WorkerValidKey = value
- case "PayAddress":
- operation_setting.PayAddress = value
- case "Chats":
- err = setting.UpdateChatsByJsonString(value)
- case "AutoGroups":
- err = setting.UpdateAutoGroupsByJsonString(value)
- case "CustomCallbackAddress":
- operation_setting.CustomCallbackAddress = value
- case "EpayId":
- operation_setting.EpayId = value
- case "EpayKey":
- operation_setting.EpayKey = value
- case "Price":
- operation_setting.Price, _ = strconv.ParseFloat(value, 64)
- case "USDExchangeRate":
- operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
- case "MinTopUp":
- operation_setting.MinTopUp, _ = strconv.Atoi(value)
- case "StripeApiSecret":
- setting.StripeApiSecret = value
- case "StripeWebhookSecret":
- setting.StripeWebhookSecret = value
- case "StripePriceId":
- setting.StripePriceId = value
- case "StripeUnitPrice":
- setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
- case "StripeMinTopUp":
- setting.StripeMinTopUp, _ = strconv.Atoi(value)
- case "StripePromotionCodesEnabled":
- setting.StripePromotionCodesEnabled = value == "true"
- case "TopupGroupRatio":
- err = common.UpdateTopupGroupRatioByJSONString(value)
- case "GitHubClientId":
- common.GitHubClientId = value
- case "GitHubClientSecret":
- common.GitHubClientSecret = value
- case "LinuxDOClientId":
- common.LinuxDOClientId = value
- case "LinuxDOClientSecret":
- common.LinuxDOClientSecret = value
- case "LinuxDOMinimumTrustLevel":
- common.LinuxDOMinimumTrustLevel, _ = strconv.Atoi(value)
- case "Footer":
- common.Footer = value
- case "SystemName":
- common.SystemName = value
- case "Logo":
- common.Logo = value
- case "WeChatServerAddress":
- common.WeChatServerAddress = value
- case "WeChatServerToken":
- common.WeChatServerToken = value
- case "WeChatAccountQRCodeImageURL":
- common.WeChatAccountQRCodeImageURL = value
- case "TelegramBotToken":
- common.TelegramBotToken = value
- case "TelegramBotName":
- common.TelegramBotName = value
- case "TurnstileSiteKey":
- common.TurnstileSiteKey = value
- case "TurnstileSecretKey":
- common.TurnstileSecretKey = value
- case "QuotaForNewUser":
- common.QuotaForNewUser, _ = strconv.Atoi(value)
- case "QuotaForInviter":
- common.QuotaForInviter, _ = strconv.Atoi(value)
- case "QuotaForInvitee":
- common.QuotaForInvitee, _ = strconv.Atoi(value)
- case "QuotaRemindThreshold":
- common.QuotaRemindThreshold, _ = strconv.Atoi(value)
- case "PreConsumedQuota":
- common.PreConsumedQuota, _ = strconv.Atoi(value)
- case "ModelRequestRateLimitCount":
- setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
- case "ModelRequestRateLimitDurationMinutes":
- setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
- case "ModelRequestRateLimitSuccessCount":
- setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
- case "ModelRequestRateLimitGroup":
- err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
- case "RetryTimes":
- common.RetryTimes, _ = strconv.Atoi(value)
- case "DataExportInterval":
- common.DataExportInterval, _ = strconv.Atoi(value)
- case "DataExportDefaultTime":
- common.DataExportDefaultTime = value
- case "ModelRatio":
- err = ratio_setting.UpdateModelRatioByJSONString(value)
- case "GroupRatio":
- err = ratio_setting.UpdateGroupRatioByJSONString(value)
- case "GroupGroupRatio":
- err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
- case "UserUsableGroups":
- err = setting.UpdateUserUsableGroupsByJSONString(value)
- case "CompletionRatio":
- err = ratio_setting.UpdateCompletionRatioByJSONString(value)
- case "ModelPrice":
- err = ratio_setting.UpdateModelPriceByJSONString(value)
- case "CacheRatio":
- err = ratio_setting.UpdateCacheRatioByJSONString(value)
- case "ImageRatio":
- err = ratio_setting.UpdateImageRatioByJSONString(value)
- case "AudioRatio":
- err = ratio_setting.UpdateAudioRatioByJSONString(value)
- case "AudioCompletionRatio":
- err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
- case "TopUpLink":
- common.TopUpLink = value
- //case "ChatLink":
- // common.ChatLink = value
- //case "ChatLink2":
- // common.ChatLink2 = value
- case "ChannelDisableThreshold":
- common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
- case "QuotaPerUnit":
- common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
- case "SensitiveWords":
- setting.SensitiveWordsFromString(value)
- case "AutomaticDisableKeywords":
- operation_setting.AutomaticDisableKeywordsFromString(value)
- case "StreamCacheQueueLength":
- setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
- case "PayMethods":
- err = operation_setting.UpdatePayMethodsByJsonString(value)
- }
- return err
-}
-
-// handleConfigUpdate 处理分层配置更新,返回是否已处理
-func handleConfigUpdate(key, value string) bool {
- parts := strings.SplitN(key, ".", 2)
- if len(parts) != 2 {
- return false // 不是分层配置
- }
-
- configName := parts[0]
- configKey := parts[1]
-
- // 获取配置对象
- cfg := config.GlobalConfig.Get(configName)
- if cfg == nil {
- return false // 未注册的配置
- }
-
- // 更新配置
- configMap := map[string]string{
- configKey: value,
- }
- config.UpdateConfigFromMap(cfg, configMap)
-
- return true // 已处理
-}
diff --git a/new-api/model/passkey.go b/new-api/model/passkey.go
deleted file mode 100644
index c2556c450ecb383130d4d268d8cc5ac1320c27d7..0000000000000000000000000000000000000000
--- a/new-api/model/passkey.go
+++ /dev/null
@@ -1,209 +0,0 @@
-package model
-
-import (
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "one-api/common"
- "strings"
- "time"
-
- "github.com/go-webauthn/webauthn/protocol"
- "github.com/go-webauthn/webauthn/webauthn"
- "gorm.io/gorm"
-)
-
-var (
- ErrPasskeyNotFound = errors.New("passkey credential not found")
- ErrFriendlyPasskeyNotFound = errors.New("Passkey 验证失败,请重试或联系管理员")
-)
-
-type PasskeyCredential struct {
- ID int `json:"id" gorm:"primaryKey"`
- UserID int `json:"user_id" gorm:"uniqueIndex;not null"`
- CredentialID string `json:"credential_id" gorm:"type:varchar(512);uniqueIndex;not null"` // base64 encoded
- PublicKey string `json:"public_key" gorm:"type:text;not null"` // base64 encoded
- AttestationType string `json:"attestation_type" gorm:"type:varchar(255)"`
- AAGUID string `json:"aaguid" gorm:"type:varchar(512)"` // base64 encoded
- SignCount uint32 `json:"sign_count" gorm:"default:0"`
- CloneWarning bool `json:"clone_warning"`
- UserPresent bool `json:"user_present"`
- UserVerified bool `json:"user_verified"`
- BackupEligible bool `json:"backup_eligible"`
- BackupState bool `json:"backup_state"`
- Transports string `json:"transports" gorm:"type:text"`
- Attachment string `json:"attachment" gorm:"type:varchar(32)"`
- LastUsedAt *time.Time `json:"last_used_at"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
-}
-
-func (p *PasskeyCredential) TransportList() []protocol.AuthenticatorTransport {
- if p == nil || strings.TrimSpace(p.Transports) == "" {
- return nil
- }
- var transports []string
- if err := json.Unmarshal([]byte(p.Transports), &transports); err != nil {
- return nil
- }
- result := make([]protocol.AuthenticatorTransport, 0, len(transports))
- for _, transport := range transports {
- result = append(result, protocol.AuthenticatorTransport(transport))
- }
- return result
-}
-
-func (p *PasskeyCredential) SetTransports(list []protocol.AuthenticatorTransport) {
- if len(list) == 0 {
- p.Transports = ""
- return
- }
- stringList := make([]string, len(list))
- for i, transport := range list {
- stringList[i] = string(transport)
- }
- encoded, err := json.Marshal(stringList)
- if err != nil {
- return
- }
- p.Transports = string(encoded)
-}
-
-func (p *PasskeyCredential) ToWebAuthnCredential() webauthn.Credential {
- flags := webauthn.CredentialFlags{
- UserPresent: p.UserPresent,
- UserVerified: p.UserVerified,
- BackupEligible: p.BackupEligible,
- BackupState: p.BackupState,
- }
-
- credID, _ := base64.StdEncoding.DecodeString(p.CredentialID)
- pubKey, _ := base64.StdEncoding.DecodeString(p.PublicKey)
- aaguid, _ := base64.StdEncoding.DecodeString(p.AAGUID)
-
- return webauthn.Credential{
- ID: credID,
- PublicKey: pubKey,
- AttestationType: p.AttestationType,
- Transport: p.TransportList(),
- Flags: flags,
- Authenticator: webauthn.Authenticator{
- AAGUID: aaguid,
- SignCount: p.SignCount,
- CloneWarning: p.CloneWarning,
- Attachment: protocol.AuthenticatorAttachment(p.Attachment),
- },
- }
-}
-
-func NewPasskeyCredentialFromWebAuthn(userID int, credential *webauthn.Credential) *PasskeyCredential {
- if credential == nil {
- return nil
- }
- passkey := &PasskeyCredential{
- UserID: userID,
- CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
- PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
- AttestationType: credential.AttestationType,
- AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
- SignCount: credential.Authenticator.SignCount,
- CloneWarning: credential.Authenticator.CloneWarning,
- UserPresent: credential.Flags.UserPresent,
- UserVerified: credential.Flags.UserVerified,
- BackupEligible: credential.Flags.BackupEligible,
- BackupState: credential.Flags.BackupState,
- Attachment: string(credential.Authenticator.Attachment),
- }
- passkey.SetTransports(credential.Transport)
- return passkey
-}
-
-func (p *PasskeyCredential) ApplyValidatedCredential(credential *webauthn.Credential) {
- if credential == nil || p == nil {
- return
- }
- p.CredentialID = base64.StdEncoding.EncodeToString(credential.ID)
- p.PublicKey = base64.StdEncoding.EncodeToString(credential.PublicKey)
- p.AttestationType = credential.AttestationType
- p.AAGUID = base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID)
- p.SignCount = credential.Authenticator.SignCount
- p.CloneWarning = credential.Authenticator.CloneWarning
- p.UserPresent = credential.Flags.UserPresent
- p.UserVerified = credential.Flags.UserVerified
- p.BackupEligible = credential.Flags.BackupEligible
- p.BackupState = credential.Flags.BackupState
- p.Attachment = string(credential.Authenticator.Attachment)
- p.SetTransports(credential.Transport)
-}
-
-func GetPasskeyByUserID(userID int) (*PasskeyCredential, error) {
- if userID == 0 {
- common.SysLog("GetPasskeyByUserID: empty user ID")
- return nil, ErrFriendlyPasskeyNotFound
- }
- var credential PasskeyCredential
- if err := DB.Where("user_id = ?", userID).First(&credential).Error; err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- // 未找到记录是正常情况(用户未绑定),返回 ErrPasskeyNotFound 而不记录日志
- return nil, ErrPasskeyNotFound
- }
- // 只有真正的数据库错误才记录日志
- common.SysLog(fmt.Sprintf("GetPasskeyByUserID: database error for user %d: %v", userID, err))
- return nil, ErrFriendlyPasskeyNotFound
- }
- return &credential, nil
-}
-
-func GetPasskeyByCredentialID(credentialID []byte) (*PasskeyCredential, error) {
- if len(credentialID) == 0 {
- common.SysLog("GetPasskeyByCredentialID: empty credential ID")
- return nil, ErrFriendlyPasskeyNotFound
- }
-
- credIDStr := base64.StdEncoding.EncodeToString(credentialID)
- var credential PasskeyCredential
- if err := DB.Where("credential_id = ?", credIDStr).First(&credential).Error; err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: passkey not found for credential ID length %d", len(credentialID)))
- return nil, ErrFriendlyPasskeyNotFound
- }
- common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: database error for credential ID: %v", err))
- return nil, ErrFriendlyPasskeyNotFound
- }
-
- return &credential, nil
-}
-
-func UpsertPasskeyCredential(credential *PasskeyCredential) error {
- if credential == nil {
- common.SysLog("UpsertPasskeyCredential: nil credential provided")
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- return DB.Transaction(func(tx *gorm.DB) error {
- // 使用Unscoped()进行硬删除,避免唯一索引冲突
- if err := tx.Unscoped().Where("user_id = ?", credential.UserID).Delete(&PasskeyCredential{}).Error; err != nil {
- common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to delete existing credential for user %d: %v", credential.UserID, err))
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- if err := tx.Create(credential).Error; err != nil {
- common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to create credential for user %d: %v", credential.UserID, err))
- return fmt.Errorf("Passkey 保存失败,请重试")
- }
- return nil
- })
-}
-
-func DeletePasskeyByUserID(userID int) error {
- if userID == 0 {
- common.SysLog("DeletePasskeyByUserID: empty user ID")
- return fmt.Errorf("删除失败,请重试")
- }
- // 使用Unscoped()进行硬删除,避免唯一索引冲突
- if err := DB.Unscoped().Where("user_id = ?", userID).Delete(&PasskeyCredential{}).Error; err != nil {
- common.SysLog(fmt.Sprintf("DeletePasskeyByUserID: failed to delete passkey for user %d: %v", userID, err))
- return fmt.Errorf("删除失败,请重试")
- }
- return nil
-}
diff --git a/new-api/model/prefill_group.go b/new-api/model/prefill_group.go
deleted file mode 100644
index e88b6e985db083860697b547e7059ded062a343f..0000000000000000000000000000000000000000
--- a/new-api/model/prefill_group.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package model
-
-import (
- "database/sql/driver"
- "encoding/json"
- "one-api/common"
-
- "gorm.io/gorm"
-)
-
-// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
-// Name 字段保持唯一,用于在前端下拉框中展示。
-// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。
-// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例:
-// ["gpt-4o", "gpt-3.5-turbo"]
-// 设计遵循 3NF,避免冗余,提供灵活扩展能力。
-
-// JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取
-type JSONValue json.RawMessage
-
-// Value 实现 driver.Valuer 接口,用于数据库写入
-func (j JSONValue) Value() (driver.Value, error) {
- if j == nil {
- return nil, nil
- }
- return []byte(j), nil
-}
-
-// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
-func (j *JSONValue) Scan(value interface{}) error {
- switch v := value.(type) {
- case nil:
- *j = nil
- return nil
- case []byte:
- // 拷贝底层字节,避免保留底层缓冲区
- b := make([]byte, len(v))
- copy(b, v)
- *j = JSONValue(b)
- return nil
- case string:
- *j = JSONValue([]byte(v))
- return nil
- default:
- // 其他类型尝试序列化为 JSON
- b, err := json.Marshal(v)
- if err != nil {
- return err
- }
- *j = JSONValue(b)
- return nil
- }
-}
-
-// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
-func (j JSONValue) MarshalJSON() ([]byte, error) {
- if j == nil {
- return []byte("null"), nil
- }
- return j, nil
-}
-
-// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
-func (j *JSONValue) UnmarshalJSON(data []byte) error {
- if data == nil {
- *j = nil
- return nil
- }
- b := make([]byte, len(data))
- copy(b, data)
- *j = JSONValue(b)
- return nil
-}
-
-type PrefillGroup struct {
- Id int `json:"id"`
- Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
- Type string `json:"type" gorm:"size:32;index;not null"`
- Items JSONValue `json:"items" gorm:"type:json"`
- Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
-}
-
-// Insert 新建组
-func (g *PrefillGroup) Insert() error {
- now := common.GetTimestamp()
- g.CreatedTime = now
- g.UpdatedTime = now
- return DB.Create(g).Error
-}
-
-// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
-func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
- if name == "" {
- return false, nil
- }
- var cnt int64
- err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
- return cnt > 0, err
-}
-
-// Update 更新组
-func (g *PrefillGroup) Update() error {
- g.UpdatedTime = common.GetTimestamp()
- return DB.Save(g).Error
-}
-
-// DeleteByID 根据 ID 删除组
-func DeletePrefillGroupByID(id int) error {
- return DB.Delete(&PrefillGroup{}, id).Error
-}
-
-// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
-func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
- var groups []*PrefillGroup
- query := DB.Model(&PrefillGroup{})
- if groupType != "" {
- query = query.Where("type = ?", groupType)
- }
- if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
- return nil, err
- }
- return groups, nil
-}
diff --git a/new-api/model/pricing.go b/new-api/model/pricing.go
deleted file mode 100644
index 4cb82ba9e2bde406ac93d1be24611e40edddef5a..0000000000000000000000000000000000000000
--- a/new-api/model/pricing.go
+++ /dev/null
@@ -1,312 +0,0 @@
-package model
-
-import (
- "encoding/json"
- "fmt"
- "strings"
-
- "one-api/common"
- "one-api/constant"
- "one-api/setting/ratio_setting"
- "one-api/types"
- "sync"
- "time"
-)
-
-type Pricing struct {
- ModelName string `json:"model_name"`
- Description string `json:"description,omitempty"`
- Icon string `json:"icon,omitempty"`
- Tags string `json:"tags,omitempty"`
- VendorID int `json:"vendor_id,omitempty"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- OwnerBy string `json:"owner_by"`
- CompletionRatio float64 `json:"completion_ratio"`
- EnableGroup []string `json:"enable_groups"`
- SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
-}
-
-type PricingVendor struct {
- ID int `json:"id"`
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- Icon string `json:"icon,omitempty"`
-}
-
-var (
- pricingMap []Pricing
- vendorsList []PricingVendor
- supportedEndpointMap map[string]common.EndpointInfo
- lastGetPricingTime time.Time
- updatePricingLock sync.Mutex
-
- // 缓存映射:模型名 -> 启用分组 / 计费类型
- modelEnableGroups = make(map[string][]string)
- modelQuotaTypeMap = make(map[string]int)
- modelEnableGroupsLock = sync.RWMutex{}
-)
-
-var (
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- modelSupportEndpointsLock = sync.RWMutex{}
-)
-
-func GetPricing() []Pricing {
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- updatePricingLock.Lock()
- defer updatePricingLock.Unlock()
- // Double check after acquiring the lock
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- modelSupportEndpointsLock.Lock()
- defer modelSupportEndpointsLock.Unlock()
- updatePricing()
- }
- }
- return pricingMap
-}
-
-// GetVendors 返回当前定价接口使用到的供应商信息
-func GetVendors() []PricingVendor {
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- // 保证先刷新一次
- GetPricing()
- }
- return vendorsList
-}
-
-func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
- if model == "" {
- return make([]constant.EndpointType, 0)
- }
- modelSupportEndpointsLock.RLock()
- defer modelSupportEndpointsLock.RUnlock()
- if endpoints, ok := modelSupportEndpointTypes[model]; ok {
- return endpoints
- }
- return make([]constant.EndpointType, 0)
-}
-
-func updatePricing() {
- //modelRatios := common.GetModelRatios()
- enableAbilities, err := GetAllEnableAbilityWithChannels()
- if err != nil {
- common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
- return
- }
- // 预加载模型元数据与供应商一次,避免循环查询
- var allMeta []Model
- _ = DB.Find(&allMeta).Error
- metaMap := make(map[string]*Model)
- prefixList := make([]*Model, 0)
- suffixList := make([]*Model, 0)
- containsList := make([]*Model, 0)
- for i := range allMeta {
- m := &allMeta[i]
- if m.NameRule == NameRuleExact {
- metaMap[m.ModelName] = m
- } else {
- switch m.NameRule {
- case NameRulePrefix:
- prefixList = append(prefixList, m)
- case NameRuleSuffix:
- suffixList = append(suffixList, m)
- case NameRuleContains:
- containsList = append(containsList, m)
- }
- }
- }
-
- // 将非精确规则模型匹配到 metaMap
- for _, m := range prefixList {
- for _, pricingModel := range enableAbilities {
- if strings.HasPrefix(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
- for _, m := range suffixList {
- for _, pricingModel := range enableAbilities {
- if strings.HasSuffix(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
- for _, m := range containsList {
- for _, pricingModel := range enableAbilities {
- if strings.Contains(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
-
- // 预加载供应商
- var vendors []Vendor
- _ = DB.Find(&vendors).Error
- vendorMap := make(map[int]*Vendor)
- for i := range vendors {
- vendorMap[vendors[i].Id] = &vendors[i]
- }
-
- // 初始化默认供应商映射
- initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
-
- // 构建对前端友好的供应商列表
- vendorsList = make([]PricingVendor, 0, len(vendorMap))
- for _, v := range vendorMap {
- vendorsList = append(vendorsList, PricingVendor{
- ID: v.Id,
- Name: v.Name,
- Description: v.Description,
- Icon: v.Icon,
- })
- }
-
- modelGroupsMap := make(map[string]*types.Set[string])
-
- for _, ability := range enableAbilities {
- groups, ok := modelGroupsMap[ability.Model]
- if !ok {
- groups = types.NewSet[string]()
- modelGroupsMap[ability.Model] = groups
- }
- groups.Add(ability.Group)
- }
-
- //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
- modelSupportEndpointsStr := make(map[string][]string)
-
- // 先根据已有能力填充原生端点
- for _, ability := range enableAbilities {
- endpoints := modelSupportEndpointsStr[ability.Model]
- channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
- for _, channelType := range channelTypes {
- if !common.StringsContains(endpoints, string(channelType)) {
- endpoints = append(endpoints, string(channelType))
- }
- }
- modelSupportEndpointsStr[ability.Model] = endpoints
- }
-
- // 再补充模型自定义端点
- for modelName, meta := range metaMap {
- if strings.TrimSpace(meta.Endpoints) == "" {
- continue
- }
- var raw map[string]interface{}
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
- endpoints := modelSupportEndpointsStr[modelName]
- for k := range raw {
- if !common.StringsContains(endpoints, k) {
- endpoints = append(endpoints, k)
- }
- }
- modelSupportEndpointsStr[modelName] = endpoints
- }
- }
-
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- for model, endpoints := range modelSupportEndpointsStr {
- supportedEndpoints := make([]constant.EndpointType, 0)
- for _, endpointStr := range endpoints {
- endpointType := constant.EndpointType(endpointStr)
- supportedEndpoints = append(supportedEndpoints, endpointType)
- }
- modelSupportEndpointTypes[model] = supportedEndpoints
- }
-
- // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
- supportedEndpointMap = make(map[string]common.EndpointInfo)
- // 1. 默认端点
- for _, endpoints := range modelSupportEndpointTypes {
- for _, et := range endpoints {
- if info, ok := common.GetDefaultEndpointInfo(et); ok {
- if _, exists := supportedEndpointMap[string(et)]; !exists {
- supportedEndpointMap[string(et)] = info
- }
- }
- }
- }
- // 2. 自定义端点(models 表)覆盖默认
- for _, meta := range metaMap {
- if strings.TrimSpace(meta.Endpoints) == "" {
- continue
- }
- var raw map[string]interface{}
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
- for k, v := range raw {
- switch val := v.(type) {
- case string:
- supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
- case map[string]interface{}:
- ep := common.EndpointInfo{Method: "POST"}
- if p, ok := val["path"].(string); ok {
- ep.Path = p
- }
- if m, ok := val["method"].(string); ok {
- ep.Method = strings.ToUpper(m)
- }
- supportedEndpointMap[k] = ep
- default:
- // ignore unsupported types
- }
- }
- }
- }
-
- pricingMap = make([]Pricing, 0)
- for model, groups := range modelGroupsMap {
- pricing := Pricing{
- ModelName: model,
- EnableGroup: groups.Items(),
- SupportedEndpointTypes: modelSupportEndpointTypes[model],
- }
-
- // 补充模型元数据(描述、标签、供应商、状态)
- if meta, ok := metaMap[model]; ok {
- // 若模型被禁用(status!=1),则直接跳过,不返回给前端
- if meta.Status != 1 {
- continue
- }
- pricing.Description = meta.Description
- pricing.Icon = meta.Icon
- pricing.Tags = meta.Tags
- pricing.VendorID = meta.VendorID
- }
- modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
- if findPrice {
- pricing.ModelPrice = modelPrice
- pricing.QuotaType = 1
- } else {
- modelRatio, _, _ := ratio_setting.GetModelRatio(model)
- pricing.ModelRatio = modelRatio
- pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
- pricing.QuotaType = 0
- }
- pricingMap = append(pricingMap, pricing)
- }
-
- // 刷新缓存映射,供高并发快速查询
- modelEnableGroupsLock.Lock()
- modelEnableGroups = make(map[string][]string)
- modelQuotaTypeMap = make(map[string]int)
- for _, p := range pricingMap {
- modelEnableGroups[p.ModelName] = p.EnableGroup
- modelQuotaTypeMap[p.ModelName] = p.QuotaType
- }
- modelEnableGroupsLock.Unlock()
-
- lastGetPricingTime = time.Now()
-}
-
-// GetSupportedEndpointMap 返回全局端点到路径的映射
-func GetSupportedEndpointMap() map[string]common.EndpointInfo {
- return supportedEndpointMap
-}
diff --git a/new-api/model/pricing_default.go b/new-api/model/pricing_default.go
deleted file mode 100644
index 976aefb1e49a04c16a6fce96e6c5678dc91f5824..0000000000000000000000000000000000000000
--- a/new-api/model/pricing_default.go
+++ /dev/null
@@ -1,128 +0,0 @@
-package model
-
-import (
- "strings"
-)
-
-// 简化的供应商映射规则
-var defaultVendorRules = map[string]string{
- "gpt": "OpenAI",
- "dall-e": "OpenAI",
- "whisper": "OpenAI",
- "o1": "OpenAI",
- "o3": "OpenAI",
- "claude": "Anthropic",
- "gemini": "Google",
- "moonshot": "Moonshot",
- "kimi": "Moonshot",
- "chatglm": "智谱",
- "glm-": "智谱",
- "qwen": "阿里巴巴",
- "deepseek": "DeepSeek",
- "abab": "MiniMax",
- "ernie": "百度",
- "spark": "讯飞",
- "hunyuan": "腾讯",
- "command": "Cohere",
- "@cf/": "Cloudflare",
- "360": "360",
- "yi": "零一万物",
- "jina": "Jina",
- "mistral": "Mistral",
- "grok": "xAI",
- "llama": "Meta",
- "doubao": "字节跳动",
- "kling": "快手",
- "jimeng": "即梦",
- "vidu": "Vidu",
-}
-
-// 供应商默认图标映射
-var defaultVendorIcons = map[string]string{
- "OpenAI": "OpenAI",
- "Anthropic": "Claude.Color",
- "Google": "Gemini.Color",
- "Moonshot": "Moonshot",
- "智谱": "Zhipu.Color",
- "阿里巴巴": "Qwen.Color",
- "DeepSeek": "DeepSeek.Color",
- "MiniMax": "Minimax.Color",
- "百度": "Wenxin.Color",
- "讯飞": "Spark.Color",
- "腾讯": "Hunyuan.Color",
- "Cohere": "Cohere.Color",
- "Cloudflare": "Cloudflare.Color",
- "360": "Ai360.Color",
- "零一万物": "Yi.Color",
- "Jina": "Jina",
- "Mistral": "Mistral.Color",
- "xAI": "XAI",
- "Meta": "Ollama",
- "字节跳动": "Doubao.Color",
- "快手": "Kling.Color",
- "即梦": "Jimeng.Color",
- "Vidu": "Vidu",
- "微软": "AzureAI",
- "Microsoft": "AzureAI",
- "Azure": "AzureAI",
-}
-
-// initDefaultVendorMapping 简化的默认供应商映射
-func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) {
- for _, ability := range enableAbilities {
- modelName := ability.Model
- if _, exists := metaMap[modelName]; exists {
- continue
- }
-
- // 匹配供应商
- vendorID := 0
- modelLower := strings.ToLower(modelName)
- for pattern, vendorName := range defaultVendorRules {
- if strings.Contains(modelLower, pattern) {
- vendorID = getOrCreateVendor(vendorName, vendorMap)
- break
- }
- }
-
- // 创建模型元数据
- metaMap[modelName] = &Model{
- ModelName: modelName,
- VendorID: vendorID,
- Status: 1,
- NameRule: NameRuleExact,
- }
- }
-}
-
-// 查找或创建供应商
-func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int {
- // 查找现有供应商
- for id, vendor := range vendorMap {
- if vendor.Name == vendorName {
- return id
- }
- }
-
- // 创建新供应商
- newVendor := &Vendor{
- Name: vendorName,
- Status: 1,
- Icon: getDefaultVendorIcon(vendorName),
- }
-
- if err := newVendor.Insert(); err != nil {
- return 0
- }
-
- vendorMap[newVendor.Id] = newVendor
- return newVendor.Id
-}
-
-// 获取供应商默认图标
-func getDefaultVendorIcon(vendorName string) string {
- if icon, exists := defaultVendorIcons[vendorName]; exists {
- return icon
- }
- return ""
-}
diff --git a/new-api/model/pricing_refresh.go b/new-api/model/pricing_refresh.go
deleted file mode 100644
index 362b5b0a298a2ef07c1c3e9793a1be7ba69709fe..0000000000000000000000000000000000000000
--- a/new-api/model/pricing_refresh.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package model
-
-// RefreshPricing 强制立即重新计算与定价相关的缓存。
-// 该方法用于需要最新数据的内部管理 API,
-// 因此会绕过默认的 1 分钟延迟刷新。
-func RefreshPricing() {
- updatePricingLock.Lock()
- defer updatePricingLock.Unlock()
-
- modelSupportEndpointsLock.Lock()
- defer modelSupportEndpointsLock.Unlock()
-
- updatePricing()
-}
diff --git a/new-api/model/redemption.go b/new-api/model/redemption.go
deleted file mode 100644
index 6e48bc61ff8d5fab6db7bde8a63fec1efaf6ba93..0000000000000000000000000000000000000000
--- a/new-api/model/redemption.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "one-api/logger"
- "strconv"
-
- "gorm.io/gorm"
-)
-
-type Redemption struct {
- Id int `json:"id"`
- UserId int `json:"user_id"`
- Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index"`
- Quota int `json:"quota" gorm:"default:100"`
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
- Count int `json:"count" gorm:"-:all"` // only for api request
- UsedUserId int `json:"used_user_id"`
- DeletedAt gorm.DeletedAt `gorm:"index"`
- ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
-}
-
-func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
- // 开始事务
- tx := DB.Begin()
- if tx.Error != nil {
- return nil, 0, tx.Error
- }
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
-
- // 获取总数
- err = tx.Model(&Redemption{}).Count(&total).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // 获取分页数据
- err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // 提交事务
- if err = tx.Commit().Error; err != nil {
- return nil, 0, err
- }
-
- return redemptions, total, nil
-}
-
-func SearchRedemptions(keyword string, startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
- tx := DB.Begin()
- if tx.Error != nil {
- return nil, 0, tx.Error
- }
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
-
- // Build query based on keyword type
- query := tx.Model(&Redemption{})
-
- // Only try to convert to ID if the string represents a valid integer
- if id, err := strconv.Atoi(keyword); err == nil {
- query = query.Where("id = ? OR name LIKE ?", id, keyword+"%")
- } else {
- query = query.Where("name LIKE ?", keyword+"%")
- }
-
- // Get total count
- err = query.Count(&total).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // Get paginated data
- err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- if err = tx.Commit().Error; err != nil {
- return nil, 0, err
- }
-
- return redemptions, total, nil
-}
-
-func GetRedemptionById(id int) (*Redemption, error) {
- if id == 0 {
- return nil, errors.New("id 为空!")
- }
- redemption := Redemption{Id: id}
- var err error = nil
- err = DB.First(&redemption, "id = ?", id).Error
- return &redemption, err
-}
-
-func Redeem(key string, userId int) (quota int, err error) {
- if key == "" {
- return 0, errors.New("未提供兑换码")
- }
- if userId == 0 {
- return 0, errors.New("无效的 user id")
- }
- redemption := &Redemption{}
-
- keyCol := "`key`"
- if common.UsingPostgreSQL {
- keyCol = `"key"`
- }
- common.RandomSleep()
- err = DB.Transaction(func(tx *gorm.DB) error {
- err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
- if err != nil {
- return errors.New("无效的兑换码")
- }
- if redemption.Status != common.RedemptionCodeStatusEnabled {
- return errors.New("该兑换码已被使用")
- }
- if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
- return errors.New("该兑换码已过期")
- }
- err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
- if err != nil {
- return err
- }
- redemption.RedeemedTime = common.GetTimestamp()
- redemption.Status = common.RedemptionCodeStatusUsed
- redemption.UsedUserId = userId
- err = tx.Save(redemption).Error
- return err
- })
- if err != nil {
- return 0, errors.New("兑换失败," + err.Error())
- }
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
- return redemption.Quota, nil
-}
-
-func (redemption *Redemption) Insert() error {
- var err error
- err = DB.Create(redemption).Error
- return err
-}
-
-func (redemption *Redemption) SelectUpdate() error {
- // This can update zero values
- return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error
-}
-
-// Update Make sure your token's fields is completed, because this will update non-zero values
-func (redemption *Redemption) Update() error {
- var err error
- err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
- return err
-}
-
-func (redemption *Redemption) Delete() error {
- var err error
- err = DB.Delete(redemption).Error
- return err
-}
-
-func DeleteRedemptionById(id int) (err error) {
- if id == 0 {
- return errors.New("id 为空!")
- }
- redemption := Redemption{Id: id}
- err = DB.Where(redemption).First(&redemption).Error
- if err != nil {
- return err
- }
- return redemption.Delete()
-}
-
-func DeleteInvalidRedemptions() (int64, error) {
- now := common.GetTimestamp()
- result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
- return result.RowsAffected, result.Error
-}
diff --git a/new-api/model/setup.go b/new-api/model/setup.go
deleted file mode 100644
index daf4d32c7456e32e9a375210e014a77d02750098..0000000000000000000000000000000000000000
--- a/new-api/model/setup.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package model
-
-type Setup struct {
- ID uint `json:"id" gorm:"primaryKey"`
- Version string `json:"version" gorm:"type:varchar(50);not null"`
- InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
-}
-
-func GetSetup() *Setup {
- var setup Setup
- err := DB.First(&setup).Error
- if err != nil {
- return nil
- }
- return &setup
-}
diff --git a/new-api/model/task.go b/new-api/model/task.go
deleted file mode 100644
index e490112be1dc989810d360f1b75b8c02f7ef35fd..0000000000000000000000000000000000000000
--- a/new-api/model/task.go
+++ /dev/null
@@ -1,365 +0,0 @@
-package model
-
-import (
- "database/sql/driver"
- "encoding/json"
- "one-api/constant"
- commonRelay "one-api/relay/common"
- "time"
-)
-
-type TaskStatus string
-
-const (
- TaskStatusNotStart TaskStatus = "NOT_START"
- TaskStatusSubmitted = "SUBMITTED"
- TaskStatusQueued = "QUEUED"
- TaskStatusInProgress = "IN_PROGRESS"
- TaskStatusFailure = "FAILURE"
- TaskStatusSuccess = "SUCCESS"
- TaskStatusUnknown = "UNKNOWN"
-)
-
-type Task struct {
- ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
- CreatedAt int64 `json:"created_at" gorm:"index"`
- UpdatedAt int64 `json:"updated_at"`
- TaskID string `json:"task_id" gorm:"type:varchar(191);index"` // 第三方id,不一定有/ song id\ Task id
- Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
- UserId int `json:"user_id" gorm:"index"`
- ChannelId int `json:"channel_id" gorm:"index"`
- Quota int `json:"quota"`
- Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
- Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
- FailReason string `json:"fail_reason"`
- SubmitTime int64 `json:"submit_time" gorm:"index"`
- StartTime int64 `json:"start_time" gorm:"index"`
- FinishTime int64 `json:"finish_time" gorm:"index"`
- Progress string `json:"progress" gorm:"type:varchar(20);index"`
- Properties Properties `json:"properties" gorm:"type:json"`
-
- Data json.RawMessage `json:"data" gorm:"type:json"`
-}
-
-func (t *Task) SetData(data any) {
- b, _ := json.Marshal(data)
- t.Data = json.RawMessage(b)
-}
-
-func (t *Task) GetData(v any) error {
- err := json.Unmarshal(t.Data, &v)
- return err
-}
-
-type Properties struct {
- Input string `json:"input"`
-}
-
-func (m *Properties) Scan(val interface{}) error {
- bytesValue, _ := val.([]byte)
- return json.Unmarshal(bytesValue, m)
-}
-
-func (m Properties) Value() (driver.Value, error) {
- return json.Marshal(m)
-}
-
-// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
-type SyncTaskQueryParams struct {
- Platform constant.TaskPlatform
- ChannelID string
- TaskID string
- UserID string
- Action string
- Status string
- StartTimestamp int64
- EndTimestamp int64
- UserIDs []int
-}
-
-func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
- t := &Task{
- UserId: relayInfo.UserId,
- SubmitTime: time.Now().Unix(),
- Status: TaskStatusNotStart,
- Progress: "0%",
- ChannelId: relayInfo.ChannelId,
- Platform: platform,
- }
- return t
-}
-
-func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
- var tasks []*Task
- var err error
-
- // 初始化查询构建器
- query := DB.Where("user_id = ?", userId)
-
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.Platform != "" {
- query = query.Where("platform = ?", queryParams.Platform)
- }
- if queryParams.StartTimestamp != 0 {
- // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
-
- // 获取数据
- err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
- if err != nil {
- return nil
- }
-
- return tasks
-}
-
-func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
- var tasks []*Task
- var err error
-
- // 初始化查询构建器
- query := DB
-
- // 添加过滤条件
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.Platform != "" {
- query = query.Where("platform = ?", queryParams.Platform)
- }
- if queryParams.UserID != "" {
- query = query.Where("user_id = ?", queryParams.UserID)
- }
- if len(queryParams.UserIDs) != 0 {
- query = query.Where("user_id in (?)", queryParams.UserIDs)
- }
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.StartTimestamp != 0 {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
-
- // 获取数据
- err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
- if err != nil {
- return nil
- }
-
- return tasks
-}
-
-func GetAllUnFinishSyncTasks(limit int) []*Task {
- var tasks []*Task
- var err error
- // get all tasks progress is not 100%
- err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
- if err != nil {
- return nil
- }
- return tasks
-}
-
-func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
- if taskId == "" {
- return nil, false, nil
- }
- var task *Task
- var err error
- err = DB.Where("task_id = ?", taskId).First(&task).Error
- exist, err := RecordExist(err)
- if err != nil {
- return nil, false, err
- }
- return task, exist, err
-}
-
-func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
- if taskId == "" {
- return nil, false, nil
- }
- var task *Task
- var err error
- err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
- First(&task).Error
- exist, err := RecordExist(err)
- if err != nil {
- return nil, false, err
- }
- return task, exist, err
-}
-
-func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
- if len(taskIds) == 0 {
- return nil, nil
- }
- var task []*Task
- var err error
- err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
- Find(&task).Error
- if err != nil {
- return nil, err
- }
- return task, nil
-}
-
-func TaskUpdateProgress(id int64, progress string) error {
- return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
-func (Task *Task) Insert() error {
- var err error
- err = DB.Create(Task).Error
- return err
-}
-
-func (Task *Task) Update() error {
- var err error
- err = DB.Save(Task).Error
- return err
-}
-
-func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
- if len(TaskIds) == 0 {
- return nil
- }
- return DB.Model(&Task{}).
- Where("task_id in (?)", TaskIds).
- Updates(params).Error
-}
-
-func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
- if len(taskIDs) == 0 {
- return nil
- }
- return DB.Model(&Task{}).
- Where("id in (?)", taskIDs).
- Updates(params).Error
-}
-
-func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
- if len(ids) == 0 {
- return nil
- }
- return DB.Model(&Task{}).
- Where("id in (?)", ids).
- Updates(params).Error
-}
-
-type TaskQuotaUsage struct {
- Mode string `json:"mode"`
- Count float64 `json:"count"`
-}
-
-func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
- query := DB.Model(Task{})
- // 添加过滤条件
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.UserID != "" {
- query = query.Where("user_id = ?", queryParams.UserID)
- }
- if len(queryParams.UserIDs) != 0 {
- query = query.Where("user_id in (?)", queryParams.UserIDs)
- }
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.StartTimestamp != 0 {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
- return stat, err
-}
-
-// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
-func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
- var total int64
- query := DB.Model(&Task{})
- if queryParams.ChannelID != "" {
- query = query.Where("channel_id = ?", queryParams.ChannelID)
- }
- if queryParams.Platform != "" {
- query = query.Where("platform = ?", queryParams.Platform)
- }
- if queryParams.UserID != "" {
- query = query.Where("user_id = ?", queryParams.UserID)
- }
- if len(queryParams.UserIDs) != 0 {
- query = query.Where("user_id in (?)", queryParams.UserIDs)
- }
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.StartTimestamp != 0 {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- _ = query.Count(&total).Error
- return total
-}
-
-// TaskCountAllUserTask returns total tasks for given user
-func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
- var total int64
- query := DB.Model(&Task{}).Where("user_id = ?", userId)
- if queryParams.TaskID != "" {
- query = query.Where("task_id = ?", queryParams.TaskID)
- }
- if queryParams.Action != "" {
- query = query.Where("action = ?", queryParams.Action)
- }
- if queryParams.Status != "" {
- query = query.Where("status = ?", queryParams.Status)
- }
- if queryParams.Platform != "" {
- query = query.Where("platform = ?", queryParams.Platform)
- }
- if queryParams.StartTimestamp != 0 {
- query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
- }
- if queryParams.EndTimestamp != 0 {
- query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
- }
- _ = query.Count(&total).Error
- return total
-}
diff --git a/new-api/model/token.go b/new-api/model/token.go
deleted file mode 100644
index a42d832d51d3d2076ed76c95854d6e966f93dadc..0000000000000000000000000000000000000000
--- a/new-api/model/token.go
+++ /dev/null
@@ -1,363 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "strings"
-
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
-)
-
-type Token struct {
- Id int `json:"id"`
- UserId int `json:"user_id" gorm:"index"`
- Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index" `
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
- ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
- RemainQuota int `json:"remain_quota" gorm:"default:0"`
- UnlimitedQuota bool `json:"unlimited_quota"`
- ModelLimitsEnabled bool `json:"model_limits_enabled"`
- ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
- AllowIps *string `json:"allow_ips" gorm:"default:''"`
- UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
- Group string `json:"group" gorm:"default:''"`
- DeletedAt gorm.DeletedAt `gorm:"index"`
-}
-
-func (token *Token) Clean() {
- token.Key = ""
-}
-
-func (token *Token) GetIpLimitsMap() map[string]any {
- // delete empty spaces
- //split with \n
- ipLimitsMap := make(map[string]any)
- if token.AllowIps == nil {
- return ipLimitsMap
- }
- cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
- if cleanIps == "" {
- return ipLimitsMap
- }
- ips := strings.Split(cleanIps, "\n")
- for _, ip := range ips {
- ip = strings.TrimSpace(ip)
- ip = strings.ReplaceAll(ip, ",", "")
- if common.IsIP(ip) {
- ipLimitsMap[ip] = true
- }
- }
- return ipLimitsMap
-}
-
-func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
- var tokens []*Token
- var err error
- err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
- return tokens, err
-}
-
-func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
- if token != "" {
- token = strings.Trim(token, "sk-")
- }
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
- return tokens, err
-}
-
-func ValidateUserToken(key string) (token *Token, err error) {
- if key == "" {
- return nil, errors.New("未提供令牌")
- }
- token, err = GetTokenByKey(key, false)
- if err == nil {
- if token.Status == common.TokenStatusExhausted {
- keyPrefix := key[:3]
- keySuffix := key[len(key)-3:]
- return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
- } else if token.Status == common.TokenStatusExpired {
- return token, errors.New("该令牌已过期")
- }
- if token.Status != common.TokenStatusEnabled {
- return token, errors.New("该令牌状态不可用")
- }
- if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
- if !common.RedisEnabled {
- token.Status = common.TokenStatusExpired
- err := token.SelectUpdate()
- if err != nil {
- common.SysLog("failed to update token status" + err.Error())
- }
- }
- return token, errors.New("该令牌已过期")
- }
- if !token.UnlimitedQuota && token.RemainQuota <= 0 {
- if !common.RedisEnabled {
- // in this case, we can make sure the token is exhausted
- token.Status = common.TokenStatusExhausted
- err := token.SelectUpdate()
- if err != nil {
- common.SysLog("failed to update token status" + err.Error())
- }
- }
- keyPrefix := key[:3]
- keySuffix := key[len(key)-3:]
- return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
- }
- return token, nil
- }
- return nil, errors.New("无效的令牌")
-}
-
-func GetTokenByIds(id int, userId int) (*Token, error) {
- if id == 0 || userId == 0 {
- return nil, errors.New("id 或 userId 为空!")
- }
- token := Token{Id: id, UserId: userId}
- var err error = nil
- err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
- return &token, err
-}
-
-func GetTokenById(id int) (*Token, error) {
- if id == 0 {
- return nil, errors.New("id 为空!")
- }
- token := Token{Id: id}
- var err error = nil
- err = DB.First(&token, "id = ?", id).Error
- if shouldUpdateRedis(true, err) {
- gopool.Go(func() {
- if err := cacheSetToken(token); err != nil {
- common.SysLog("failed to update user status cache: " + err.Error())
- }
- })
- }
- return &token, err
-}
-
-func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) && token != nil {
- gopool.Go(func() {
- if err := cacheSetToken(*token); err != nil {
- common.SysLog("failed to update user status cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- // Try Redis first
- token, err := cacheGetTokenByKey(key)
- if err == nil {
- return token, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
- return token, err
-}
-
-func (token *Token) Insert() error {
- var err error
- err = DB.Create(token).Error
- return err
-}
-
-// Update Make sure your token's fields is completed, because this will update non-zero values
-func (token *Token) Update() (err error) {
- defer func() {
- if shouldUpdateRedis(true, err) {
- gopool.Go(func() {
- err := cacheSetToken(*token)
- if err != nil {
- common.SysLog("failed to update token cache: " + err.Error())
- }
- })
- }
- }()
- err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
- "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
- return err
-}
-
-func (token *Token) SelectUpdate() (err error) {
- defer func() {
- if shouldUpdateRedis(true, err) {
- gopool.Go(func() {
- err := cacheSetToken(*token)
- if err != nil {
- common.SysLog("failed to update token cache: " + err.Error())
- }
- })
- }
- }()
- // This can update zero values
- return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
-}
-
-func (token *Token) Delete() (err error) {
- defer func() {
- if shouldUpdateRedis(true, err) {
- gopool.Go(func() {
- err := cacheDeleteToken(token.Key)
- if err != nil {
- common.SysLog("failed to delete token cache: " + err.Error())
- }
- })
- }
- }()
- err = DB.Delete(token).Error
- return err
-}
-
-func (token *Token) IsModelLimitsEnabled() bool {
- return token.ModelLimitsEnabled
-}
-
-func (token *Token) GetModelLimits() []string {
- if token.ModelLimits == "" {
- return []string{}
- }
- return strings.Split(token.ModelLimits, ",")
-}
-
-func (token *Token) GetModelLimitsMap() map[string]bool {
- limits := token.GetModelLimits()
- limitsMap := make(map[string]bool)
- for _, limit := range limits {
- limitsMap[limit] = true
- }
- return limitsMap
-}
-
-func DisableModelLimits(tokenId int) error {
- token, err := GetTokenById(tokenId)
- if err != nil {
- return err
- }
- token.ModelLimitsEnabled = false
- token.ModelLimits = ""
- return token.Update()
-}
-
-func DeleteTokenById(id int, userId int) (err error) {
- // Why we need userId here? In case user want to delete other's token.
- if id == 0 || userId == 0 {
- return errors.New("id 或 userId 为空!")
- }
- token := Token{Id: id, UserId: userId}
- err = DB.Where(token).First(&token).Error
- if err != nil {
- return err
- }
- return token.Delete()
-}
-
-func IncreaseTokenQuota(id int, key string, quota int) (err error) {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- if common.RedisEnabled {
- gopool.Go(func() {
- err := cacheIncrTokenQuota(key, int64(quota))
- if err != nil {
- common.SysLog("failed to increase token quota: " + err.Error())
- }
- })
- }
- if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
- return nil
- }
- return increaseTokenQuota(id, quota)
-}
-
-func increaseTokenQuota(id int, quota int) (err error) {
- err = DB.Model(&Token{}).Where("id = ?", id).Updates(
- map[string]interface{}{
- "remain_quota": gorm.Expr("remain_quota + ?", quota),
- "used_quota": gorm.Expr("used_quota - ?", quota),
- "accessed_time": common.GetTimestamp(),
- },
- ).Error
- return err
-}
-
-func DecreaseTokenQuota(id int, key string, quota int) (err error) {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- if common.RedisEnabled {
- gopool.Go(func() {
- err := cacheDecrTokenQuota(key, int64(quota))
- if err != nil {
- common.SysLog("failed to decrease token quota: " + err.Error())
- }
- })
- }
- if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
- return nil
- }
- return decreaseTokenQuota(id, quota)
-}
-
-func decreaseTokenQuota(id int, quota int) (err error) {
- err = DB.Model(&Token{}).Where("id = ?", id).Updates(
- map[string]interface{}{
- "remain_quota": gorm.Expr("remain_quota - ?", quota),
- "used_quota": gorm.Expr("used_quota + ?", quota),
- "accessed_time": common.GetTimestamp(),
- },
- ).Error
- return err
-}
-
-// CountUserTokens returns total number of tokens for the given user, used for pagination
-func CountUserTokens(userId int) (int64, error) {
- var total int64
- err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
- return total, err
-}
-
-// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
-func BatchDeleteTokens(ids []int, userId int) (int, error) {
- if len(ids) == 0 {
- return 0, errors.New("ids 不能为空!")
- }
-
- tx := DB.Begin()
-
- var tokens []Token
- if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
- tx.Rollback()
- return 0, err
- }
-
- if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
- tx.Rollback()
- return 0, err
- }
-
- if err := tx.Commit().Error; err != nil {
- return 0, err
- }
-
- if common.RedisEnabled {
- gopool.Go(func() {
- for _, t := range tokens {
- _ = cacheDeleteToken(t.Key)
- }
- })
- }
-
- return len(tokens), nil
-}
diff --git a/new-api/model/token_cache.go b/new-api/model/token_cache.go
deleted file mode 100644
index 42367357516390e54555b8f5e0fa86f8032e5e23..0000000000000000000000000000000000000000
--- a/new-api/model/token_cache.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package model
-
-import (
- "fmt"
- "one-api/common"
- "one-api/constant"
- "time"
-)
-
-func cacheSetToken(token Token) error {
- key := common.GenerateHMAC(token.Key)
- token.Clean()
- err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
- if err != nil {
- return err
- }
- return nil
-}
-
-func cacheDeleteToken(key string) error {
- key = common.GenerateHMAC(key)
- err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
- if err != nil {
- return err
- }
- return nil
-}
-
-func cacheIncrTokenQuota(key string, increment int64) error {
- key = common.GenerateHMAC(key)
- err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
- if err != nil {
- return err
- }
- return nil
-}
-
-func cacheDecrTokenQuota(key string, decrement int64) error {
- return cacheIncrTokenQuota(key, -decrement)
-}
-
-func cacheSetTokenField(key string, field string, value string) error {
- key = common.GenerateHMAC(key)
- err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
- if err != nil {
- return err
- }
- return nil
-}
-
-// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取
-func cacheGetTokenByKey(key string) (*Token, error) {
- hmacKey := common.GenerateHMAC(key)
- if !common.RedisEnabled {
- return nil, fmt.Errorf("redis is not enabled")
- }
- var token Token
- err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
- if err != nil {
- return nil, err
- }
- token.Key = key
- return &token, nil
-}
diff --git a/new-api/model/topup.go b/new-api/model/topup.go
deleted file mode 100644
index 6c60bab07f1ff1c1b8391dc5f6ab48d5e0a7843b..0000000000000000000000000000000000000000
--- a/new-api/model/topup.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "one-api/logger"
-
- "gorm.io/gorm"
-)
-
-type TopUp struct {
- Id int `json:"id"`
- UserId int `json:"user_id" gorm:"index"`
- Amount int64 `json:"amount"`
- Money float64 `json:"money"`
- TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
- CreateTime int64 `json:"create_time"`
- CompleteTime int64 `json:"complete_time"`
- Status string `json:"status"`
-}
-
-func (topUp *TopUp) Insert() error {
- var err error
- err = DB.Create(topUp).Error
- return err
-}
-
-func (topUp *TopUp) Update() error {
- var err error
- err = DB.Save(topUp).Error
- return err
-}
-
-func GetTopUpById(id int) *TopUp {
- var topUp *TopUp
- var err error
- err = DB.Where("id = ?", id).First(&topUp).Error
- if err != nil {
- return nil
- }
- return topUp
-}
-
-func GetTopUpByTradeNo(tradeNo string) *TopUp {
- var topUp *TopUp
- var err error
- err = DB.Where("trade_no = ?", tradeNo).First(&topUp).Error
- if err != nil {
- return nil
- }
- return topUp
-}
-
-func Recharge(referenceId string, customerId string) (err error) {
- if referenceId == "" {
- return errors.New("未提供支付单号")
- }
-
- var quota float64
- topUp := &TopUp{}
-
- refCol := "`trade_no`"
- if common.UsingPostgreSQL {
- refCol = `"trade_no"`
- }
-
- err = DB.Transaction(func(tx *gorm.DB) error {
- err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
- if err != nil {
- return errors.New("充值订单不存在")
- }
-
- if topUp.Status != common.TopUpStatusPending {
- return errors.New("充值订单状态错误")
- }
-
- topUp.CompleteTime = common.GetTimestamp()
- topUp.Status = common.TopUpStatusSuccess
- err = tx.Save(topUp).Error
- if err != nil {
- return err
- }
-
- quota = topUp.Money * common.QuotaPerUnit
- err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
- if err != nil {
- return err
- }
-
- return nil
- })
-
- if err != nil {
- return errors.New("充值失败," + err.Error())
- }
-
- RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
-
- return nil
-}
diff --git a/new-api/model/twofa.go b/new-api/model/twofa.go
deleted file mode 100644
index 53a663bf7d6cbac3cbcf19fa8d2eea500a36be2d..0000000000000000000000000000000000000000
--- a/new-api/model/twofa.go
+++ /dev/null
@@ -1,322 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "time"
-
- "gorm.io/gorm"
-)
-
-var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
-
-// TwoFA 用户2FA设置表
-type TwoFA struct {
- Id int `json:"id" gorm:"primaryKey"`
- UserId int `json:"user_id" gorm:"unique;not null;index"`
- Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
- IsEnabled bool `json:"is_enabled"`
- FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
- LockedUntil *time.Time `json:"locked_until,omitempty"`
- LastUsedAt *time.Time `json:"last_used_at,omitempty"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
-}
-
-// TwoFABackupCode 备用码使用记录表
-type TwoFABackupCode struct {
- Id int `json:"id" gorm:"primaryKey"`
- UserId int `json:"user_id" gorm:"not null;index"`
- CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
- IsUsed bool `json:"is_used"`
- UsedAt *time.Time `json:"used_at,omitempty"`
- CreatedAt time.Time `json:"created_at"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
-}
-
-// GetTwoFAByUserId 根据用户ID获取2FA设置
-func GetTwoFAByUserId(userId int) (*TwoFA, error) {
- if userId == 0 {
- return nil, errors.New("用户ID不能为空")
- }
-
- var twoFA TwoFA
- err := DB.Where("user_id = ?", userId).First(&twoFA).Error
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return nil, nil // 返回nil表示未设置2FA
- }
- return nil, err
- }
-
- return &twoFA, nil
-}
-
-// IsTwoFAEnabled 检查用户是否启用了2FA
-func IsTwoFAEnabled(userId int) bool {
- twoFA, err := GetTwoFAByUserId(userId)
- if err != nil || twoFA == nil {
- return false
- }
- return twoFA.IsEnabled
-}
-
-// CreateTwoFA 创建2FA设置
-func (t *TwoFA) Create() error {
- // 检查用户是否已存在2FA设置
- existing, err := GetTwoFAByUserId(t.UserId)
- if err != nil {
- return err
- }
- if existing != nil {
- return errors.New("用户已存在2FA设置")
- }
-
- // 验证用户存在
- var user User
- if err := DB.First(&user, t.UserId).Error; err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return errors.New("用户不存在")
- }
- return err
- }
-
- return DB.Create(t).Error
-}
-
-// Update 更新2FA设置
-func (t *TwoFA) Update() error {
- if t.Id == 0 {
- return errors.New("2FA记录ID不能为空")
- }
- return DB.Save(t).Error
-}
-
-// Delete 删除2FA设置
-func (t *TwoFA) Delete() error {
- if t.Id == 0 {
- return errors.New("2FA记录ID不能为空")
- }
-
- // 使用事务确保原子性
- return DB.Transaction(func(tx *gorm.DB) error {
- // 同时删除相关的备用码记录(硬删除)
- if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
- return err
- }
-
- // 硬删除2FA记录
- return tx.Unscoped().Delete(t).Error
- })
-}
-
-// ResetFailedAttempts 重置失败尝试次数
-func (t *TwoFA) ResetFailedAttempts() error {
- t.FailedAttempts = 0
- t.LockedUntil = nil
- return t.Update()
-}
-
-// IncrementFailedAttempts 增加失败尝试次数
-func (t *TwoFA) IncrementFailedAttempts() error {
- t.FailedAttempts++
-
- // 检查是否需要锁定
- if t.FailedAttempts >= common.MaxFailAttempts {
- lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
- t.LockedUntil = &lockUntil
- }
-
- return t.Update()
-}
-
-// IsLocked 检查账户是否被锁定
-func (t *TwoFA) IsLocked() bool {
- if t.LockedUntil == nil {
- return false
- }
- return time.Now().Before(*t.LockedUntil)
-}
-
-// CreateBackupCodes 创建备用码
-func CreateBackupCodes(userId int, codes []string) error {
- return DB.Transaction(func(tx *gorm.DB) error {
- // 先删除现有的备用码
- if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
- return err
- }
-
- // 创建新的备用码记录
- for _, code := range codes {
- hashedCode, err := common.HashBackupCode(code)
- if err != nil {
- return err
- }
-
- backupCode := TwoFABackupCode{
- UserId: userId,
- CodeHash: hashedCode,
- IsUsed: false,
- }
-
- if err := tx.Create(&backupCode).Error; err != nil {
- return err
- }
- }
-
- return nil
- })
-}
-
-// ValidateBackupCode 验证并使用备用码
-func ValidateBackupCode(userId int, code string) (bool, error) {
- if !common.ValidateBackupCode(code) {
- return false, errors.New("验证码或备用码不正确")
- }
-
- normalizedCode := common.NormalizeBackupCode(code)
-
- // 查找未使用的备用码
- var backupCodes []TwoFABackupCode
- if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
- return false, err
- }
-
- // 验证备用码
- for _, bc := range backupCodes {
- if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
- // 标记为已使用
- now := time.Now()
- bc.IsUsed = true
- bc.UsedAt = &now
-
- if err := DB.Save(&bc).Error; err != nil {
- return false, err
- }
-
- return true, nil
- }
- }
-
- return false, nil
-}
-
-// GetUnusedBackupCodeCount 获取未使用的备用码数量
-func GetUnusedBackupCodeCount(userId int) (int, error) {
- var count int64
- err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
- return int(count), err
-}
-
-// DisableTwoFA 禁用用户的2FA
-func DisableTwoFA(userId int) error {
- twoFA, err := GetTwoFAByUserId(userId)
- if err != nil {
- return err
- }
- if twoFA == nil {
- return ErrTwoFANotEnabled
- }
-
- // 删除2FA设置和备用码
- return twoFA.Delete()
-}
-
-// EnableTwoFA 启用2FA
-func (t *TwoFA) Enable() error {
- t.IsEnabled = true
- t.FailedAttempts = 0
- t.LockedUntil = nil
- return t.Update()
-}
-
-// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
-func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
- // 检查是否被锁定
- if t.IsLocked() {
- return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
- }
-
- // 验证TOTP码
- if !common.ValidateTOTPCode(t.Secret, code) {
- // 增加失败次数
- if err := t.IncrementFailedAttempts(); err != nil {
- common.SysLog("更新2FA失败次数失败: " + err.Error())
- }
- return false, nil
- }
-
- // 验证成功,重置失败次数并更新最后使用时间
- now := time.Now()
- t.FailedAttempts = 0
- t.LockedUntil = nil
- t.LastUsedAt = &now
-
- if err := t.Update(); err != nil {
- common.SysLog("更新2FA使用记录失败: " + err.Error())
- }
-
- return true, nil
-}
-
-// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
-func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
- // 检查是否被锁定
- if t.IsLocked() {
- return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
- }
-
- // 验证备用码
- valid, err := ValidateBackupCode(t.UserId, code)
- if err != nil {
- return false, err
- }
-
- if !valid {
- // 增加失败次数
- if err := t.IncrementFailedAttempts(); err != nil {
- common.SysLog("更新2FA失败次数失败: " + err.Error())
- }
- return false, nil
- }
-
- // 验证成功,重置失败次数并更新最后使用时间
- now := time.Now()
- t.FailedAttempts = 0
- t.LockedUntil = nil
- t.LastUsedAt = &now
-
- if err := t.Update(); err != nil {
- common.SysLog("更新2FA使用记录失败: " + err.Error())
- }
-
- return true, nil
-}
-
-// GetTwoFAStats 获取2FA统计信息(管理员使用)
-func GetTwoFAStats() (map[string]interface{}, error) {
- var totalUsers, enabledUsers int64
-
- // 总用户数
- if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
- return nil, err
- }
-
- // 启用2FA的用户数
- if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
- return nil, err
- }
-
- enabledRate := float64(0)
- if totalUsers > 0 {
- enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
- }
-
- return map[string]interface{}{
- "total_users": totalUsers,
- "enabled_users": enabledUsers,
- "enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
- }, nil
-}
diff --git a/new-api/model/usedata.go b/new-api/model/usedata.go
deleted file mode 100644
index 7404cdd6df6e78a2e5518fb6359e1a9ee1abce21..0000000000000000000000000000000000000000
--- a/new-api/model/usedata.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package model
-
-import (
- "fmt"
- "gorm.io/gorm"
- "one-api/common"
- "sync"
- "time"
-)
-
-// QuotaData 柱状图数据
-type QuotaData struct {
- Id int `json:"id"`
- UserID int `json:"user_id" gorm:"index"`
- Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
- ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
- CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
- TokenUsed int `json:"token_used" gorm:"default:0"`
- Count int `json:"count" gorm:"default:0"`
- Quota int `json:"quota" gorm:"default:0"`
-}
-
-func UpdateQuotaData() {
- for {
- if common.DataExportEnabled {
- common.SysLog("正在更新数据看板数据...")
- SaveQuotaDataCache()
- }
- time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
- }
-}
-
-var CacheQuotaData = make(map[string]*QuotaData)
-var CacheQuotaDataLock = sync.Mutex{}
-
-func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
- key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
- quotaData, ok := CacheQuotaData[key]
- if ok {
- quotaData.Count += 1
- quotaData.Quota += quota
- quotaData.TokenUsed += tokenUsed
- } else {
- quotaData = &QuotaData{
- UserID: userId,
- Username: username,
- ModelName: modelName,
- CreatedAt: createdAt,
- Count: 1,
- Quota: quota,
- TokenUsed: tokenUsed,
- }
- }
- CacheQuotaData[key] = quotaData
-}
-
-func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
- // 只精确到小时
- createdAt = createdAt - (createdAt % 3600)
-
- CacheQuotaDataLock.Lock()
- defer CacheQuotaDataLock.Unlock()
- logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed)
-}
-
-func SaveQuotaDataCache() {
- CacheQuotaDataLock.Lock()
- defer CacheQuotaDataLock.Unlock()
- size := len(CacheQuotaData)
- // 如果缓存中有数据,就保存到数据库中
- // 1. 先查询数据库中是否有数据
- // 2. 如果有数据,就更新数据
- // 3. 如果没有数据,就插入数据
- for _, quotaData := range CacheQuotaData {
- quotaDataDB := &QuotaData{}
- DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
- quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
- if quotaDataDB.Id > 0 {
- //quotaDataDB.Count += quotaData.Count
- //quotaDataDB.Quota += quotaData.Quota
- //DB.Table("quota_data").Save(quotaDataDB)
- increaseQuotaData(quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.Count, quotaData.Quota, quotaData.CreatedAt, quotaData.TokenUsed)
- } else {
- DB.Table("quota_data").Create(quotaData)
- }
- }
- CacheQuotaData = make(map[string]*QuotaData)
- common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
-}
-
-func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
- err := DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
- userId, username, modelName, createdAt).Updates(map[string]interface{}{
- "count": gorm.Expr("count + ?", count),
- "quota": gorm.Expr("quota + ?", quota),
- "token_used": gorm.Expr("token_used + ?", tokenUsed),
- }).Error
- if err != nil {
- common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
- }
-}
-
-func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
- var quotaDatas []*QuotaData
- // 从quota_data表中查询数据
- err = DB.Table("quota_data").Where("username = ? and created_at >= ? and created_at <= ?", username, startTime, endTime).Find("aDatas).Error
- return quotaDatas, err
-}
-
-func GetQuotaDataByUserId(userId int, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
- var quotaDatas []*QuotaData
- // 从quota_data表中查询数据
- err = DB.Table("quota_data").Where("user_id = ? and created_at >= ? and created_at <= ?", userId, startTime, endTime).Find("aDatas).Error
- return quotaDatas, err
-}
-
-func GetAllQuotaDates(startTime int64, endTime int64, username string) (quotaData []*QuotaData, err error) {
- if username != "" {
- return GetQuotaDataByUsername(username, startTime, endTime)
- }
- var quotaDatas []*QuotaData
- // 从quota_data表中查询数据
- // only select model_name, sum(count) as count, sum(quota) as quota, model_name, created_at from quota_data group by model_name, created_at;
- //err = DB.Table("quota_data").Where("created_at >= ? and created_at <= ?", startTime, endTime).Find("aDatas).Error
- err = DB.Table("quota_data").Select("model_name, sum(count) as count, sum(quota) as quota, sum(token_used) as token_used, created_at").Where("created_at >= ? and created_at <= ?", startTime, endTime).Group("model_name, created_at").Find("aDatas).Error
- return quotaDatas, err
-}
diff --git a/new-api/model/user.go b/new-api/model/user.go
deleted file mode 100644
index a1c0d2347ad2cdd0c8cd5c57647627adfc03cf29..0000000000000000000000000000000000000000
--- a/new-api/model/user.go
+++ /dev/null
@@ -1,917 +0,0 @@
-package model
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "strconv"
- "strings"
-
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
-)
-
-// User if you add sensitive fields, don't forget to clean them in setupLogin function.
-// Otherwise, the sensitive information will be saved on local storage in plain text!
-type User struct {
- Id int `json:"id"`
- Username string `json:"username" gorm:"unique;index" validate:"max=20"`
- Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
- OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
- DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
- Role int `json:"role" gorm:"type:int;default:1"` // admin, common
- Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
- Email string `json:"email" gorm:"index" validate:"max=50"`
- GitHubId string `json:"github_id" gorm:"column:github_id;index"`
- OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
- WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
- TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
- VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
- AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
- Quota int `json:"quota" gorm:"type:int;default:0"`
- UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
- RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
- Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
- AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
- AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
- AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
- AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
- InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
- DeletedAt gorm.DeletedAt `gorm:"index"`
- LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
- Setting string `json:"setting" gorm:"type:text;column:setting"`
- Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
- StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
-}
-
-func (user *User) ToBaseUser() *UserBase {
- cache := &UserBase{
- Id: user.Id,
- Group: user.Group,
- Quota: user.Quota,
- Status: user.Status,
- Username: user.Username,
- Setting: user.Setting,
- Email: user.Email,
- }
- return cache
-}
-
-func (user *User) GetAccessToken() string {
- if user.AccessToken == nil {
- return ""
- }
- return *user.AccessToken
-}
-
-func (user *User) SetAccessToken(token string) {
- user.AccessToken = &token
-}
-
-func (user *User) GetSetting() dto.UserSetting {
- setting := dto.UserSetting{}
- if user.Setting != "" {
- err := json.Unmarshal([]byte(user.Setting), &setting)
- if err != nil {
- common.SysLog("failed to unmarshal setting: " + err.Error())
- }
- }
- return setting
-}
-
-func (user *User) SetSetting(setting dto.UserSetting) {
- settingBytes, err := json.Marshal(setting)
- if err != nil {
- common.SysLog("failed to marshal setting: " + err.Error())
- return
- }
- user.Setting = string(settingBytes)
-}
-
-// 根据用户角色生成默认的边栏配置
-func generateDefaultSidebarConfigForRole(userRole int) string {
- defaultConfig := map[string]interface{}{}
-
- // 聊天区域 - 所有用户都可以访问
- defaultConfig["chat"] = map[string]interface{}{
- "enabled": true,
- "playground": true,
- "chat": true,
- }
-
- // 控制台区域 - 所有用户都可以访问
- defaultConfig["console"] = map[string]interface{}{
- "enabled": true,
- "detail": true,
- "token": true,
- "log": true,
- "midjourney": true,
- "task": true,
- }
-
- // 个人中心区域 - 所有用户都可以访问
- defaultConfig["personal"] = map[string]interface{}{
- "enabled": true,
- "topup": true,
- "personal": true,
- }
-
- // 管理员区域 - 根据角色决定
- if userRole == common.RoleAdminUser {
- // 管理员可以访问管理员区域,但不能访问系统设置
- defaultConfig["admin"] = map[string]interface{}{
- "enabled": true,
- "channel": true,
- "models": true,
- "redemption": true,
- "user": true,
- "setting": false, // 管理员不能访问系统设置
- }
- } else if userRole == common.RoleRootUser {
- // 超级管理员可以访问所有功能
- defaultConfig["admin"] = map[string]interface{}{
- "enabled": true,
- "channel": true,
- "models": true,
- "redemption": true,
- "user": true,
- "setting": true,
- }
- }
- // 普通用户不包含admin区域
-
- // 转换为JSON字符串
- configBytes, err := json.Marshal(defaultConfig)
- if err != nil {
- common.SysLog("生成默认边栏配置失败: " + err.Error())
- return ""
- }
-
- return string(configBytes)
-}
-
-// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
-func CheckUserExistOrDeleted(username string, email string) (bool, error) {
- var user User
-
- // err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
- // check email if empty
- var err error
- if email == "" {
- err = DB.Unscoped().First(&user, "username = ?", username).Error
- } else {
- err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
- }
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- // not exist, return false, nil
- return false, nil
- }
- // other error, return false, err
- return false, err
- }
- // exist, return true, nil
- return true, nil
-}
-
-func GetMaxUserId() int {
- var user User
- DB.Unscoped().Last(&user)
- return user.Id
-}
-
-func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
- // Start transaction
- tx := DB.Begin()
- if tx.Error != nil {
- return nil, 0, tx.Error
- }
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
-
- // Get total count within transaction
- err = tx.Unscoped().Model(&User{}).Count(&total).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // Get paginated users within same transaction
- err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // Commit transaction
- if err = tx.Commit().Error; err != nil {
- return nil, 0, err
- }
-
- return users, total, nil
-}
-
-func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
- var users []*User
- var total int64
- var err error
-
- // 开始事务
- tx := DB.Begin()
- if tx.Error != nil {
- return nil, 0, tx.Error
- }
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
-
- // 构建基础查询
- query := tx.Unscoped().Model(&User{})
-
- // 构建搜索条件
- likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
-
- // 尝试将关键字转换为整数ID
- keywordInt, err := strconv.Atoi(keyword)
- if err == nil {
- // 如果是数字,同时搜索ID和其他字段
- likeCondition = "id = ? OR " + likeCondition
- if group != "" {
- query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
- keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
- } else {
- query = query.Where(likeCondition,
- keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
- }
- } else {
- // 非数字关键字,只搜索字符串字段
- if group != "" {
- query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
- "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
- } else {
- query = query.Where(likeCondition,
- "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
- }
- }
-
- // 获取总数
- err = query.Count(&total).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // 获取分页数据
- err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
- if err != nil {
- tx.Rollback()
- return nil, 0, err
- }
-
- // 提交事务
- if err = tx.Commit().Error; err != nil {
- return nil, 0, err
- }
-
- return users, total, nil
-}
-
-func GetUserById(id int, selectAll bool) (*User, error) {
- if id == 0 {
- return nil, errors.New("id 为空!")
- }
- user := User{Id: id}
- var err error = nil
- if selectAll {
- err = DB.First(&user, "id = ?", id).Error
- } else {
- err = DB.Omit("password").First(&user, "id = ?", id).Error
- }
- return &user, err
-}
-
-func GetUserIdByAffCode(affCode string) (int, error) {
- if affCode == "" {
- return 0, errors.New("affCode 为空!")
- }
- var user User
- err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
- return user.Id, err
-}
-
-func DeleteUserById(id int) (err error) {
- if id == 0 {
- return errors.New("id 为空!")
- }
- user := User{Id: id}
- return user.Delete()
-}
-
-func HardDeleteUserById(id int) error {
- if id == 0 {
- return errors.New("id 为空!")
- }
- err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
- return err
-}
-
-func inviteUser(inviterId int) (err error) {
- user, err := GetUserById(inviterId, true)
- if err != nil {
- return err
- }
- user.AffCount++
- user.AffQuota += common.QuotaForInviter
- user.AffHistoryQuota += common.QuotaForInviter
- return DB.Save(user).Error
-}
-
-func (user *User) TransferAffQuotaToQuota(quota int) error {
- // 检查quota是否小于最小额度
- if float64(quota) < common.QuotaPerUnit {
- return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
- }
-
- // 开始数据库事务
- tx := DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- defer tx.Rollback() // 确保在函数退出时事务能回滚
-
- // 加锁查询用户以确保数据一致性
- err := tx.Set("gorm:query_option", "FOR UPDATE").First(&user, user.Id).Error
- if err != nil {
- return err
- }
-
- // 再次检查用户的AffQuota是否足够
- if user.AffQuota < quota {
- return errors.New("邀请额度不足!")
- }
-
- // 更新用户额度
- user.AffQuota -= quota
- user.Quota += quota
-
- // 保存用户状态
- if err := tx.Save(user).Error; err != nil {
- return err
- }
-
- // 提交事务
- return tx.Commit().Error
-}
-
-func (user *User) Insert(inviterId int) error {
- var err error
- if user.Password != "" {
- user.Password, err = common.Password2Hash(user.Password)
- if err != nil {
- return err
- }
- }
- user.Quota = common.QuotaForNewUser
- //user.SetAccessToken(common.GetUUID())
- user.AffCode = common.GetRandomString(4)
-
- // 初始化用户设置,包括默认的边栏配置
- if user.Setting == "" {
- defaultSetting := dto.UserSetting{}
- // 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
- user.SetSetting(defaultSetting)
- }
-
- result := DB.Create(user)
- if result.Error != nil {
- return result.Error
- }
-
- // 用户创建成功后,根据角色初始化边栏配置
- // 需要重新获取用户以确保有正确的ID和Role
- var createdUser User
- if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
- // 生成基于角色的默认边栏配置
- defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
- if defaultSidebarConfig != "" {
- currentSetting := createdUser.GetSetting()
- currentSetting.SidebarModules = defaultSidebarConfig
- createdUser.SetSetting(currentSetting)
- createdUser.Update(false)
- common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
- }
- }
-
- if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
- }
- if inviterId != 0 {
- if common.QuotaForInvitee > 0 {
- _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
- }
- if common.QuotaForInviter > 0 {
- //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
- _ = inviteUser(inviterId)
- }
- }
- return nil
-}
-
-func (user *User) Update(updatePassword bool) error {
- var err error
- if updatePassword {
- user.Password, err = common.Password2Hash(user.Password)
- if err != nil {
- return err
- }
- }
- newUser := *user
- DB.First(&user, user.Id)
- if err = DB.Model(user).Updates(newUser).Error; err != nil {
- return err
- }
-
- // Update cache
- return updateUserCache(*user)
-}
-
-func (user *User) Edit(updatePassword bool) error {
- var err error
- if updatePassword {
- user.Password, err = common.Password2Hash(user.Password)
- if err != nil {
- return err
- }
- }
-
- newUser := *user
- updates := map[string]interface{}{
- "username": newUser.Username,
- "display_name": newUser.DisplayName,
- "group": newUser.Group,
- "quota": newUser.Quota,
- "remark": newUser.Remark,
- }
- if updatePassword {
- updates["password"] = newUser.Password
- }
-
- DB.First(&user, user.Id)
- if err = DB.Model(user).Updates(updates).Error; err != nil {
- return err
- }
-
- // Update cache
- return updateUserCache(*user)
-}
-
-func (user *User) Delete() error {
- if user.Id == 0 {
- return errors.New("id 为空!")
- }
- if err := DB.Delete(user).Error; err != nil {
- return err
- }
-
- // 清除缓存
- return invalidateUserCache(user.Id)
-}
-
-func (user *User) HardDelete() error {
- if user.Id == 0 {
- return errors.New("id 为空!")
- }
- err := DB.Unscoped().Delete(user).Error
- return err
-}
-
-// ValidateAndFill check password & user status
-func (user *User) ValidateAndFill() (err error) {
- // When querying with struct, GORM will only query with non-zero fields,
- // that means if your field's value is 0, '', false or other zero values,
- // it won't be used to build query conditions
- password := user.Password
- username := strings.TrimSpace(user.Username)
- if username == "" || password == "" {
- return errors.New("用户名或密码为空")
- }
- // find buy username or email
- DB.Where("username = ? OR email = ?", username, username).First(user)
- okay := common.ValidatePasswordAndHash(password, user.Password)
- if !okay || user.Status != common.UserStatusEnabled {
- return errors.New("用户名或密码错误,或用户已被封禁")
- }
- return nil
-}
-
-func (user *User) FillUserById() error {
- if user.Id == 0 {
- return errors.New("id 为空!")
- }
- DB.Where(User{Id: user.Id}).First(user)
- return nil
-}
-
-func (user *User) FillUserByEmail() error {
- if user.Email == "" {
- return errors.New("email 为空!")
- }
- DB.Where(User{Email: user.Email}).First(user)
- return nil
-}
-
-func (user *User) FillUserByGitHubId() error {
- if user.GitHubId == "" {
- return errors.New("GitHub id 为空!")
- }
- DB.Where(User{GitHubId: user.GitHubId}).First(user)
- return nil
-}
-
-func (user *User) FillUserByOidcId() error {
- if user.OidcId == "" {
- return errors.New("oidc id 为空!")
- }
- DB.Where(User{OidcId: user.OidcId}).First(user)
- return nil
-}
-
-func (user *User) FillUserByWeChatId() error {
- if user.WeChatId == "" {
- return errors.New("WeChat id 为空!")
- }
- DB.Where(User{WeChatId: user.WeChatId}).First(user)
- return nil
-}
-
-func (user *User) FillUserByTelegramId() error {
- if user.TelegramId == "" {
- return errors.New("Telegram id 为空!")
- }
- err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return errors.New("该 Telegram 账户未绑定")
- }
- return nil
-}
-
-func IsEmailAlreadyTaken(email string) bool {
- return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
-}
-
-func IsWeChatIdAlreadyTaken(wechatId string) bool {
- return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
-}
-
-func IsGitHubIdAlreadyTaken(githubId string) bool {
- return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
-}
-
-func IsOidcIdAlreadyTaken(oidcId string) bool {
- return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
-}
-
-func IsTelegramIdAlreadyTaken(telegramId string) bool {
- return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
-}
-
-func ResetUserPasswordByEmail(email string, password string) error {
- if email == "" || password == "" {
- return errors.New("邮箱地址或密码为空!")
- }
- hashedPassword, err := common.Password2Hash(password)
- if err != nil {
- return err
- }
- err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
- return err
-}
-
-func IsAdmin(userId int) bool {
- if userId == 0 {
- return false
- }
- var user User
- err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
- if err != nil {
- common.SysLog("no such user " + err.Error())
- return false
- }
- return user.Role >= common.RoleAdminUser
-}
-
-//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
-//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
-// defer func() {
-// // Update Redis cache asynchronously on successful DB read
-// if shouldUpdateRedis(fromDB, err) {
-// gopool.Go(func() {
-// if err := updateUserStatusCache(id, status); err != nil {
-// common.SysError("failed to update user status cache: " + err.Error())
-// }
-// })
-// }
-// }()
-// if !fromDB && common.RedisEnabled {
-// // Try Redis first
-// status, err := getUserStatusCache(id)
-// if err == nil {
-// return status == common.UserStatusEnabled, nil
-// }
-// // Don't return error - fall through to DB
-// }
-// fromDB = true
-// var user User
-// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
-// if err != nil {
-// return false, err
-// }
-//
-// return user.Status == common.UserStatusEnabled, nil
-//}
-
-func ValidateAccessToken(token string) (user *User) {
- if token == "" {
- return nil
- }
- token = strings.Replace(token, "Bearer ", "", 1)
- user = &User{}
- if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
- return user
- }
- return nil
-}
-
-// GetUserQuota gets quota from Redis first, falls back to DB if needed
-func GetUserQuota(id int, fromDB bool) (quota int, err error) {
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) {
- gopool.Go(func() {
- if err := updateUserQuotaCache(id, quota); err != nil {
- common.SysLog("failed to update user quota cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- quota, err := getUserQuotaCache(id)
- if err == nil {
- return quota, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
- if err != nil {
- return 0, err
- }
-
- return quota, nil
-}
-
-func GetUserUsedQuota(id int) (quota int, err error) {
- err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
- return quota, err
-}
-
-func GetUserEmail(id int) (email string, err error) {
- err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
- return email, err
-}
-
-// GetUserGroup gets group from Redis first, falls back to DB if needed
-func GetUserGroup(id int, fromDB bool) (group string, err error) {
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) {
- gopool.Go(func() {
- if err := updateUserGroupCache(id, group); err != nil {
- common.SysLog("failed to update user group cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- group, err := getUserGroupCache(id)
- if err == nil {
- return group, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
- if err != nil {
- return "", err
- }
-
- return group, nil
-}
-
-// GetUserSetting gets setting from Redis first, falls back to DB if needed
-func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
- var setting string
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) {
- gopool.Go(func() {
- if err := updateUserSettingCache(id, setting); err != nil {
- common.SysLog("failed to update user setting cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- setting, err := getUserSettingCache(id)
- if err == nil {
- return setting, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
- if err != nil {
- return settingMap, err
- }
- userBase := &UserBase{
- Setting: setting,
- }
- return userBase.GetSetting(), nil
-}
-
-func IncreaseUserQuota(id int, quota int, db bool) (err error) {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- gopool.Go(func() {
- err := cacheIncrUserQuota(id, int64(quota))
- if err != nil {
- common.SysLog("failed to increase user quota: " + err.Error())
- }
- })
- if !db && common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeUserQuota, id, quota)
- return nil
- }
- return increaseUserQuota(id, quota)
-}
-
-func increaseUserQuota(id int, quota int) (err error) {
- err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
- if err != nil {
- return err
- }
- return err
-}
-
-func DecreaseUserQuota(id int, quota int) (err error) {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- gopool.Go(func() {
- err := cacheDecrUserQuota(id, int64(quota))
- if err != nil {
- common.SysLog("failed to decrease user quota: " + err.Error())
- }
- })
- if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
- return nil
- }
- return decreaseUserQuota(id, quota)
-}
-
-func decreaseUserQuota(id int, quota int) (err error) {
- err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
- if err != nil {
- return err
- }
- return err
-}
-
-func DeltaUpdateUserQuota(id int, delta int) (err error) {
- if delta == 0 {
- return nil
- }
- if delta > 0 {
- return IncreaseUserQuota(id, delta, false)
- } else {
- return DecreaseUserQuota(id, -delta)
- }
-}
-
-//func GetRootUserEmail() (email string) {
-// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
-// return email
-//}
-
-func GetRootUser() (user *User) {
- DB.Where("role = ?", common.RoleRootUser).First(&user)
- return user
-}
-
-func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
- if common.BatchUpdateEnabled {
- addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
- addNewRecord(BatchUpdateTypeRequestCount, id, 1)
- return
- }
- updateUserUsedQuotaAndRequestCount(id, quota, 1)
-}
-
-func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
- err := DB.Model(&User{}).Where("id = ?", id).Updates(
- map[string]interface{}{
- "used_quota": gorm.Expr("used_quota + ?", quota),
- "request_count": gorm.Expr("request_count + ?", count),
- },
- ).Error
- if err != nil {
- common.SysLog("failed to update user used quota and request count: " + err.Error())
- return
- }
-
- //// 更新缓存
- //if err := invalidateUserCache(id); err != nil {
- // common.SysError("failed to invalidate user cache: " + err.Error())
- //}
-}
-
-func updateUserUsedQuota(id int, quota int) {
- err := DB.Model(&User{}).Where("id = ?", id).Updates(
- map[string]interface{}{
- "used_quota": gorm.Expr("used_quota + ?", quota),
- },
- ).Error
- if err != nil {
- common.SysLog("failed to update user used quota: " + err.Error())
- }
-}
-
-func updateUserRequestCount(id int, count int) {
- err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
- if err != nil {
- common.SysLog("failed to update user request count: " + err.Error())
- }
-}
-
-// GetUsernameById gets username from Redis first, falls back to DB if needed
-func GetUsernameById(id int, fromDB bool) (username string, err error) {
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) {
- gopool.Go(func() {
- if err := updateUserNameCache(id, username); err != nil {
- common.SysLog("failed to update user name cache: " + err.Error())
- }
- })
- }
- }()
- if !fromDB && common.RedisEnabled {
- username, err := getUserNameCache(id)
- if err == nil {
- return username, nil
- }
- // Don't return error - fall through to DB
- }
- fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
- if err != nil {
- return "", err
- }
-
- return username, nil
-}
-
-func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
- var user User
- err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
- return !errors.Is(err, gorm.ErrRecordNotFound)
-}
-
-func (user *User) FillUserByLinuxDOId() error {
- if user.LinuxDOId == "" {
- return errors.New("linux do id is empty")
- }
- err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
- return err
-}
-
-func RootUserExists() bool {
- var user User
- err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
- if err != nil {
- return false
- }
- return true
-}
diff --git a/new-api/model/user_cache.go b/new-api/model/user_cache.go
deleted file mode 100644
index d60dbe018427c67bb149241fa2dc5db4d42e5d57..0000000000000000000000000000000000000000
--- a/new-api/model/user_cache.go
+++ /dev/null
@@ -1,218 +0,0 @@
-package model
-
-import (
- "fmt"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "time"
-
- "github.com/gin-gonic/gin"
-
- "github.com/bytedance/gopkg/util/gopool"
-)
-
-// UserBase struct remains the same as it represents the cached data structure
-type UserBase struct {
- Id int `json:"id"`
- Group string `json:"group"`
- Email string `json:"email"`
- Quota int `json:"quota"`
- Status int `json:"status"`
- Username string `json:"username"`
- Setting string `json:"setting"`
-}
-
-func (user *UserBase) WriteContext(c *gin.Context) {
- common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
- common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
- common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
- common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
- common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
- common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
-}
-
-func (user *UserBase) GetSetting() dto.UserSetting {
- setting := dto.UserSetting{}
- if user.Setting != "" {
- err := common.Unmarshal([]byte(user.Setting), &setting)
- if err != nil {
- common.SysLog("failed to unmarshal setting: " + err.Error())
- }
- }
- return setting
-}
-
-// getUserCacheKey returns the key for user cache
-func getUserCacheKey(userId int) string {
- return fmt.Sprintf("user:%d", userId)
-}
-
-// invalidateUserCache clears user cache
-func invalidateUserCache(userId int) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisDelKey(getUserCacheKey(userId))
-}
-
-// updateUserCache updates all user cache fields using hash
-func updateUserCache(user User) error {
- if !common.RedisEnabled {
- return nil
- }
-
- return common.RedisHSetObj(
- getUserCacheKey(user.Id),
- user.ToBaseUser(),
- time.Duration(common.RedisKeyCacheSeconds())*time.Second,
- )
-}
-
-// GetUserCache gets complete user cache from hash
-func GetUserCache(userId int) (userCache *UserBase, err error) {
- var user *User
- var fromDB bool
- defer func() {
- // Update Redis cache asynchronously on successful DB read
- if shouldUpdateRedis(fromDB, err) && user != nil {
- gopool.Go(func() {
- if err := updateUserCache(*user); err != nil {
- common.SysLog("failed to update user status cache: " + err.Error())
- }
- })
- }
- }()
-
- // Try getting from Redis first
- userCache, err = cacheGetUserBase(userId)
- if err == nil {
- return userCache, nil
- }
-
- // If Redis fails, get from DB
- fromDB = true
- user, err = GetUserById(userId, false)
- if err != nil {
- return nil, err // Return nil and error if DB lookup fails
- }
-
- // Create cache object from user data
- userCache = &UserBase{
- Id: user.Id,
- Group: user.Group,
- Quota: user.Quota,
- Status: user.Status,
- Username: user.Username,
- Setting: user.Setting,
- Email: user.Email,
- }
-
- return userCache, nil
-}
-
-func cacheGetUserBase(userId int) (*UserBase, error) {
- if !common.RedisEnabled {
- return nil, fmt.Errorf("redis is not enabled")
- }
- var userCache UserBase
- // Try getting from Redis first
- err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
- if err != nil {
- return nil, err
- }
- return &userCache, nil
-}
-
-// Add atomic quota operations using hash fields
-func cacheIncrUserQuota(userId int, delta int64) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
-}
-
-func cacheDecrUserQuota(userId int, delta int64) error {
- return cacheIncrUserQuota(userId, -delta)
-}
-
-// Helper functions to get individual fields if needed
-func getUserGroupCache(userId int) (string, error) {
- cache, err := GetUserCache(userId)
- if err != nil {
- return "", err
- }
- return cache.Group, nil
-}
-
-func getUserQuotaCache(userId int) (int, error) {
- cache, err := GetUserCache(userId)
- if err != nil {
- return 0, err
- }
- return cache.Quota, nil
-}
-
-func getUserStatusCache(userId int) (int, error) {
- cache, err := GetUserCache(userId)
- if err != nil {
- return 0, err
- }
- return cache.Status, nil
-}
-
-func getUserNameCache(userId int) (string, error) {
- cache, err := GetUserCache(userId)
- if err != nil {
- return "", err
- }
- return cache.Username, nil
-}
-
-func getUserSettingCache(userId int) (dto.UserSetting, error) {
- cache, err := GetUserCache(userId)
- if err != nil {
- return dto.UserSetting{}, err
- }
- return cache.GetSetting(), nil
-}
-
-// New functions for individual field updates
-func updateUserStatusCache(userId int, status bool) error {
- if !common.RedisEnabled {
- return nil
- }
- statusInt := common.UserStatusEnabled
- if !status {
- statusInt = common.UserStatusDisabled
- }
- return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
-}
-
-func updateUserQuotaCache(userId int, quota int) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
-}
-
-func updateUserGroupCache(userId int, group string) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
-}
-
-func updateUserNameCache(userId int, username string) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
-}
-
-func updateUserSettingCache(userId int, setting string) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
-}
diff --git a/new-api/model/utils.go b/new-api/model/utils.go
deleted file mode 100644
index 6471a07dc5b0c792dfad0d94ecc8a2793b288ed1..0000000000000000000000000000000000000000
--- a/new-api/model/utils.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package model
-
-import (
- "errors"
- "one-api/common"
- "sync"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
-)
-
-const (
- BatchUpdateTypeUserQuota = iota
- BatchUpdateTypeTokenQuota
- BatchUpdateTypeUsedQuota
- BatchUpdateTypeChannelUsedQuota
- BatchUpdateTypeRequestCount
- BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
-)
-
-var batchUpdateStores []map[int]int
-var batchUpdateLocks []sync.Mutex
-
-func init() {
- for i := 0; i < BatchUpdateTypeCount; i++ {
- batchUpdateStores = append(batchUpdateStores, make(map[int]int))
- batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
- }
-}
-
-func InitBatchUpdater() {
- gopool.Go(func() {
- for {
- time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
- batchUpdate()
- }
- })
-}
-
-func addNewRecord(type_ int, id int, value int) {
- batchUpdateLocks[type_].Lock()
- defer batchUpdateLocks[type_].Unlock()
- if _, ok := batchUpdateStores[type_][id]; !ok {
- batchUpdateStores[type_][id] = value
- } else {
- batchUpdateStores[type_][id] += value
- }
-}
-
-func batchUpdate() {
- // check if there's any data to update
- hasData := false
- for i := 0; i < BatchUpdateTypeCount; i++ {
- batchUpdateLocks[i].Lock()
- if len(batchUpdateStores[i]) > 0 {
- hasData = true
- batchUpdateLocks[i].Unlock()
- break
- }
- batchUpdateLocks[i].Unlock()
- }
-
- if !hasData {
- return
- }
-
- common.SysLog("batch update started")
- for i := 0; i < BatchUpdateTypeCount; i++ {
- batchUpdateLocks[i].Lock()
- store := batchUpdateStores[i]
- batchUpdateStores[i] = make(map[int]int)
- batchUpdateLocks[i].Unlock()
- // TODO: maybe we can combine updates with same key?
- for key, value := range store {
- switch i {
- case BatchUpdateTypeUserQuota:
- err := increaseUserQuota(key, value)
- if err != nil {
- common.SysLog("failed to batch update user quota: " + err.Error())
- }
- case BatchUpdateTypeTokenQuota:
- err := increaseTokenQuota(key, value)
- if err != nil {
- common.SysLog("failed to batch update token quota: " + err.Error())
- }
- case BatchUpdateTypeUsedQuota:
- updateUserUsedQuota(key, value)
- case BatchUpdateTypeRequestCount:
- updateUserRequestCount(key, value)
- case BatchUpdateTypeChannelUsedQuota:
- updateChannelUsedQuota(key, value)
- }
- }
- }
- common.SysLog("batch update finished")
-}
-
-func RecordExist(err error) (bool, error) {
- if err == nil {
- return true, nil
- }
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return false, nil
- }
- return false, err
-}
-
-func shouldUpdateRedis(fromDB bool, err error) bool {
- return common.RedisEnabled && fromDB && err == nil
-}
diff --git a/new-api/model/vendor_meta.go b/new-api/model/vendor_meta.go
deleted file mode 100644
index 1e80df21a7e5af38f3e3ce1f115c33ab1cbaec29..0000000000000000000000000000000000000000
--- a/new-api/model/vendor_meta.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package model
-
-import (
- "one-api/common"
-
- "gorm.io/gorm"
-)
-
-// Vendor 用于存储供应商信息,供模型引用
-// Name 唯一,用于在模型中关联
-// Icon 采用 @lobehub/icons 的图标名,前端可直接渲染
-// Status 预留字段,1 表示启用
-// 本表同样遵循 3NF 设计范式
-
-type Vendor struct {
- Id int `json:"id"`
- Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"`
- Description string `json:"description,omitempty" gorm:"type:text"`
- Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
- Status int `json:"status" gorm:"default:1"`
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
- DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"`
-}
-
-// Insert 创建新的供应商记录
-func (v *Vendor) Insert() error {
- now := common.GetTimestamp()
- v.CreatedTime = now
- v.UpdatedTime = now
- return DB.Create(v).Error
-}
-
-// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID)
-func IsVendorNameDuplicated(id int, name string) (bool, error) {
- if name == "" {
- return false, nil
- }
- var cnt int64
- err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
- return cnt > 0, err
-}
-
-// Update 更新供应商记录
-func (v *Vendor) Update() error {
- v.UpdatedTime = common.GetTimestamp()
- return DB.Save(v).Error
-}
-
-// Delete 软删除供应商
-func (v *Vendor) Delete() error {
- return DB.Delete(v).Error
-}
-
-// GetVendorByID 根据 ID 获取供应商
-func GetVendorByID(id int) (*Vendor, error) {
- var v Vendor
- err := DB.First(&v, id).Error
- if err != nil {
- return nil, err
- }
- return &v, nil
-}
-
-// GetAllVendors 获取全部供应商(分页)
-func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
- var vendors []*Vendor
- err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
- return vendors, err
-}
-
-// SearchVendors 按关键字搜索供应商
-func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
- db := DB.Model(&Vendor{})
- if keyword != "" {
- like := "%" + keyword + "%"
- db = db.Where("name LIKE ? OR description LIKE ?", like, like)
- }
- var total int64
- if err := db.Count(&total).Error; err != nil {
- return nil, 0, err
- }
- var vendors []*Vendor
- if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
- return nil, 0, err
- }
- return vendors, total, nil
-}
diff --git a/new-api/one-api.service b/new-api/one-api.service
deleted file mode 100644
index 855006a79efd5798b6ad44b24c615523ab33f869..0000000000000000000000000000000000000000
--- a/new-api/one-api.service
+++ /dev/null
@@ -1,18 +0,0 @@
-# File path: /etc/systemd/system/one-api.service
-# sudo systemctl daemon-reload
-# sudo systemctl start one-api
-# sudo systemctl enable one-api
-# sudo systemctl status one-api
-[Unit]
-Description=One API Service
-After=network.target
-
-[Service]
-User=ubuntu # 注意修改用户名
-WorkingDirectory=/path/to/one-api # 注意修改路径
-ExecStart=/path/to/one-api/one-api --port 3000 --log-dir /path/to/one-api/logs # 注意修改路径和端口号
-Restart=always
-RestartSec=5
-
-[Install]
-WantedBy=multi-user.target
diff --git a/new-api/relay/audio_handler.go b/new-api/relay/audio_handler.go
deleted file mode 100644
index de22ed6b363297c9fc6731b73564e01a75a91d82..0000000000000000000000000000000000000000
--- a/new-api/relay/audio_handler.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package relay
-
-import (
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- audioReq, ok := info.Request.(*dto.AudioRequest)
- if !ok {
- return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(audioReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- resp, err := adaptor.DoRequest(c, info, ioReader)
- if err != nil {
- return types.NewError(err, types.ErrorCodeDoRequestFailed)
- }
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
-
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
-
- return nil
-}
diff --git a/new-api/relay/channel/adapter.go b/new-api/relay/channel/adapter.go
deleted file mode 100644
index 8a7ef24d91c1ff10f6acbfc08c74c326fcd086b3..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/adapter.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package channel
-
-import (
- "io"
- "net/http"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor interface {
- // Init IsStream bool
- Init(info *relaycommon.RelayInfo)
- GetRequestURL(info *relaycommon.RelayInfo) (string, error)
- SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
- ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
- ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
- ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
- ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
- ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
- ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
- DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
- DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
- GetModelList() []string
- GetChannelName() string
- ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
- ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
-}
-
-type TaskAdaptor interface {
- Init(info *relaycommon.RelayInfo)
-
- ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
-
- BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
- BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
- BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
-
- DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
- DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
-
- GetModelList() []string
- GetChannelName() string
-
- // FetchTask
- FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
-
- ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
-}
diff --git a/new-api/relay/channel/ai360/constants.go b/new-api/relay/channel/ai360/constants.go
deleted file mode 100644
index f7de961248eb9c5ac82651bb2eb1b647dd917ac4..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ai360/constants.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package ai360
-
-var ModelList = []string{
- "360gpt-turbo",
- "360gpt-turbo-responsibility-8k",
- "360gpt-pro",
- "360gpt2-pro",
- "360GPT_S2_V9",
- "embedding-bert-512-v1",
- "embedding_s1_v1",
- "semantic_similarity_s1_v1",
-}
-
-var ChannelName = "ai360"
diff --git a/new-api/relay/channel/ali/adaptor.go b/new-api/relay/channel/ali/adaptor.go
deleted file mode 100644
index 5e2337a168db15167785934d6c0e9d2039c0d96e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/adaptor.go
+++ /dev/null
@@ -1,180 +0,0 @@
-package ali
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- return req, nil
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- var fullRequestURL string
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
- default:
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
- case constant.RelayModeRerank:
- fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
- case constant.RelayModeImagesGenerations:
- fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
- case constant.RelayModeImagesEdits:
- fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
- case constant.RelayModeCompletions:
- fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
- default:
- fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
- }
- }
-
- return fullRequestURL, nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- if info.IsStream {
- req.Set("X-DashScope-SSE", "enable")
- }
- if c.GetString("plugin") != "" {
- req.Set("X-DashScope-Plugin", c.GetString("plugin"))
- }
- if info.RelayMode == constant.RelayModeImagesGenerations {
- req.Set("X-DashScope-Async", "enable")
- }
- if info.RelayMode == constant.RelayModeImagesEdits {
- req.Set("Content-Type", "application/json")
- }
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- // docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
- // fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
- if strings.Contains(request.Model, "thinking") {
- request.EnableThinking = true
- request.Stream = true
- info.IsStream = true
- }
- // fix: ali parameter.enable_thinking must be set to false for non-streaming calls
- if !info.IsStream {
- request.EnableThinking = false
- }
-
- switch info.RelayMode {
- default:
- aliReq := requestOpenAI2Ali(*request)
- return aliReq, nil
- }
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- if info.RelayMode == constant.RelayModeImagesGenerations {
- aliRequest, err := oaiImage2Ali(request)
- if err != nil {
- return nil, fmt.Errorf("convert image request failed: %w", err)
- }
- return aliRequest, nil
- } else if info.RelayMode == constant.RelayModeImagesEdits {
- // ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
- // 如果用户使用表单,则需要解析表单数据
- if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
- aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request)
- if err != nil {
- return nil, fmt.Errorf("convert image edit form request failed: %w", err)
- }
- return aliRequest, nil
- } else {
- aliRequest, err := oaiImage2Ali(request)
- if err != nil {
- return nil, fmt.Errorf("convert image request failed: %w", err)
- }
- return aliRequest, nil
- }
- }
- return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return ConvertRerankRequest(request), nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- if info.IsStream {
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- } else {
- return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
- }
- default:
- switch info.RelayMode {
- case constant.RelayModeImagesGenerations:
- err, usage = aliImageHandler(c, resp, info)
- case constant.RelayModeImagesEdits:
- err, usage = aliImageEditHandler(c, resp, info)
- case constant.RelayModeRerank:
- err, usage = RerankHandler(c, resp, info)
- default:
- adaptor := openai.Adaptor{}
- usage, err = adaptor.DoResponse(c, resp, info)
- }
- return usage, err
- }
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/ali/constants.go b/new-api/relay/channel/ali/constants.go
deleted file mode 100644
index e9a22a0c5276e6c8f207213153c98c2e93821e39..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/constants.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package ali
-
-var ModelList = []string{
- "qwen-turbo",
- "qwen-plus",
- "qwen-max",
- "qwen-max-longcontext",
- "qwq-32b",
- "qwen3-235b-a22b",
- "text-embedding-v1",
- "gte-rerank-v2",
-}
-
-var ChannelName = "ali"
diff --git a/new-api/relay/channel/ali/dto.go b/new-api/relay/channel/ali/dto.go
deleted file mode 100644
index ff7c34ce976f2f7155079432df5587aabb7ee2aa..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/dto.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package ali
-
-import "one-api/dto"
-
-type AliMessage struct {
- Content any `json:"content"`
- Role string `json:"role"`
-}
-
-type AliMediaContent struct {
- Image string `json:"image,omitempty"`
- Text string `json:"text,omitempty"`
-}
-
-type AliInput struct {
- Prompt string `json:"prompt,omitempty"`
- //History []AliMessage `json:"history,omitempty"`
- Messages []AliMessage `json:"messages"`
-}
-
-type AliParameters struct {
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Seed uint64 `json:"seed,omitempty"`
- EnableSearch bool `json:"enable_search,omitempty"`
- IncrementalOutput bool `json:"incremental_output,omitempty"`
-}
-
-type AliChatRequest struct {
- Model string `json:"model"`
- Input AliInput `json:"input,omitempty"`
- Parameters AliParameters `json:"parameters,omitempty"`
-}
-
-type AliEmbeddingRequest struct {
- Model string `json:"model"`
- Input struct {
- Texts []string `json:"texts"`
- } `json:"input"`
- Parameters *struct {
- TextType string `json:"text_type,omitempty"`
- } `json:"parameters,omitempty"`
-}
-
-type AliEmbedding struct {
- Embedding []float64 `json:"embedding"`
- TextIndex int `json:"text_index"`
-}
-
-type AliEmbeddingResponse struct {
- Output struct {
- Embeddings []AliEmbedding `json:"embeddings"`
- } `json:"output"`
- Usage AliUsage `json:"usage"`
- AliError
-}
-
-type AliError struct {
- Code string `json:"code"`
- Message string `json:"message"`
- RequestId string `json:"request_id"`
-}
-
-type AliUsage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-type TaskResult struct {
- B64Image string `json:"b64_image,omitempty"`
- Url string `json:"url,omitempty"`
- Code string `json:"code,omitempty"`
- Message string `json:"message,omitempty"`
-}
-
-type AliOutput struct {
- TaskId string `json:"task_id,omitempty"`
- TaskStatus string `json:"task_status,omitempty"`
- Text string `json:"text"`
- FinishReason string `json:"finish_reason"`
- Message string `json:"message,omitempty"`
- Code string `json:"code,omitempty"`
- Results []TaskResult `json:"results,omitempty"`
- Choices []map[string]any `json:"choices,omitempty"`
-}
-
-type AliResponse struct {
- Output AliOutput `json:"output"`
- Usage AliUsage `json:"usage"`
- AliError
-}
-
-type AliImageRequest struct {
- Model string `json:"model"`
- Input any `json:"input"`
- Parameters any `json:"parameters,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
-}
-
-type AliImageParameters struct {
- Size string `json:"size,omitempty"`
- N int `json:"n,omitempty"`
- Steps string `json:"steps,omitempty"`
- Scale string `json:"scale,omitempty"`
- Watermark *bool `json:"watermark,omitempty"`
-}
-
-type AliImageInput struct {
- Prompt string `json:"prompt,omitempty"`
- NegativePrompt string `json:"negative_prompt,omitempty"`
- Messages []AliMessage `json:"messages,omitempty"`
-}
-
-type AliRerankParameters struct {
- TopN *int `json:"top_n,omitempty"`
- ReturnDocuments *bool `json:"return_documents,omitempty"`
-}
-
-type AliRerankInput struct {
- Query string `json:"query"`
- Documents []any `json:"documents"`
-}
-
-type AliRerankRequest struct {
- Model string `json:"model"`
- Input AliRerankInput `json:"input"`
- Parameters AliRerankParameters `json:"parameters,omitempty"`
-}
-
-type AliRerankResponse struct {
- Output struct {
- Results []dto.RerankResponseResult `json:"results"`
- } `json:"output"`
- Usage AliUsage `json:"usage"`
- RequestId string `json:"request_id"`
- AliError
-}
diff --git a/new-api/relay/channel/ali/image.go b/new-api/relay/channel/ali/image.go
deleted file mode 100644
index e73e3d27e9bfb2e3f25a1dbcb90316228cd2e68c..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/image.go
+++ /dev/null
@@ -1,339 +0,0 @@
-package ali
-
-import (
- "context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
- var imageRequest AliImageRequest
- imageRequest.Model = request.Model
- imageRequest.ResponseFormat = request.ResponseFormat
- logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
- if request.Extra != nil {
- if val, ok := request.Extra["parameters"]; ok {
- err := common.Unmarshal(val, &imageRequest.Parameters)
- if err != nil {
- return nil, fmt.Errorf("invalid parameters field: %w", err)
- }
- }
- if val, ok := request.Extra["input"]; ok {
- err := common.Unmarshal(val, &imageRequest.Input)
- if err != nil {
- return nil, fmt.Errorf("invalid input field: %w", err)
- }
- }
- }
-
- if imageRequest.Parameters == nil {
- imageRequest.Parameters = AliImageParameters{
- Size: strings.Replace(request.Size, "x", "*", -1),
- N: int(request.N),
- Watermark: request.Watermark,
- }
- }
-
- if imageRequest.Input == nil {
- imageRequest.Input = AliImageInput{
- Prompt: request.Prompt,
- }
- }
-
- return &imageRequest, nil
-}
-
-func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
- var imageRequest AliImageRequest
- imageRequest.Model = request.Model
- imageRequest.ResponseFormat = request.ResponseFormat
-
- mf := c.Request.MultipartForm
- if mf == nil {
- if _, err := c.MultipartForm(); err != nil {
- return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
- }
- mf = c.Request.MultipartForm
- }
-
- var imageFiles []*multipart.FileHeader
- var exists bool
-
- // First check for standard "image" field
- if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
- // If not found, check for "image[]" field
- if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
- // If still not found, iterate through all fields to find any that start with "image["
- foundArrayImages := false
- for fieldName, files := range mf.File {
- if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
- foundArrayImages = true
- imageFiles = append(imageFiles, files...)
- }
- }
-
- // If no image fields found at all
- if !foundArrayImages && (len(imageFiles) == 0) {
- return nil, errors.New("image is required")
- }
- }
- }
-
- if len(imageFiles) == 0 {
- return nil, errors.New("image is required")
- }
-
- if len(imageFiles) > 1 {
- return nil, errors.New("only one image is supported for qwen edit")
- }
-
- // 获取base64编码的图片
- var imageBase64s []string
- for _, file := range imageFiles {
- image, err := file.Open()
- if err != nil {
- return nil, errors.New("failed to open image file")
- }
-
- // 读取文件内容
- imageData, err := io.ReadAll(image)
- if err != nil {
- return nil, errors.New("failed to read image file")
- }
-
- // 获取MIME类型
- mimeType := http.DetectContentType(imageData)
-
- // 编码为base64
- base64Data := base64.StdEncoding.EncodeToString(imageData)
-
- // 构造data URL格式
- dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
- imageBase64s = append(imageBase64s, dataURL)
- image.Close()
- }
-
- //dto.MediaContent{}
- mediaContents := make([]AliMediaContent, len(imageBase64s))
- for i, b64 := range imageBase64s {
- mediaContents[i] = AliMediaContent{
- Image: b64,
- }
- }
- mediaContents = append(mediaContents, AliMediaContent{
- Text: request.Prompt,
- })
- imageRequest.Input = AliImageInput{
- Messages: []AliMessage{
- {
- Role: "user",
- Content: mediaContents,
- },
- },
- }
- imageRequest.Parameters = AliImageParameters{
- Watermark: request.Watermark,
- }
- return &imageRequest, nil
-}
-
-func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
- url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
-
- var aliResponse AliResponse
-
- req, err := http.NewRequest("GET", url, nil)
- if err != nil {
- return &aliResponse, err, nil
- }
-
- req.Header.Set("Authorization", "Bearer "+info.ApiKey)
-
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- common.SysLog("updateTask client.Do err: " + err.Error())
- return &aliResponse, err, nil
- }
- defer resp.Body.Close()
-
- responseBody, err := io.ReadAll(resp.Body)
-
- var response AliResponse
- err = common.Unmarshal(responseBody, &response)
- if err != nil {
- common.SysLog("updateTask NewDecoder err: " + err.Error())
- return &aliResponse, err, nil
- }
-
- return &response, nil, responseBody
-}
-
-func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
- waitSeconds := 10
- step := 0
- maxStep := 20
-
- var taskResponse AliResponse
- var responseBody []byte
-
- for {
- logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
- step++
- rsp, err, body := updateTask(info, taskID)
- responseBody = body
- if err != nil {
- logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
- time.Sleep(time.Duration(waitSeconds) * time.Second)
- continue
- }
-
- if rsp.Output.TaskStatus == "" {
- return &taskResponse, responseBody, nil
- }
-
- switch rsp.Output.TaskStatus {
- case "FAILED":
- fallthrough
- case "CANCELED":
- fallthrough
- case "SUCCEEDED":
- fallthrough
- case "UNKNOWN":
- return rsp, responseBody, nil
- }
- if step >= maxStep {
- break
- }
- time.Sleep(time.Duration(waitSeconds) * time.Second)
- }
-
- return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
-}
-
-func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
- imageResponse := dto.ImageResponse{
- Created: info.StartTime.Unix(),
- }
-
- for _, data := range response.Output.Results {
- var b64Json string
- if responseFormat == "b64_json" {
- _, b64, err := service.GetImageFromUrl(data.Url)
- if err != nil {
- logger.LogError(c, "get_image_data_failed: "+err.Error())
- continue
- }
- b64Json = b64
- } else {
- b64Json = data.B64Image
- }
-
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{
- Url: data.Url,
- B64Json: b64Json,
- RevisedPrompt: "",
- })
- }
- var mapResponse map[string]any
- _ = common.Unmarshal(originBody, &mapResponse)
- imageResponse.Extra = mapResponse
- return &imageResponse
-}
-
-func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
- responseFormat := c.GetString("response_format")
-
- var aliTaskResponse AliResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- service.CloseResponseBodyGracefully(resp)
- err = common.Unmarshal(responseBody, &aliTaskResponse)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
- }
-
- if aliTaskResponse.Message != "" {
- logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
- return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
- }
-
- aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponse), nil
- }
-
- if aliResponse.Output.TaskStatus != "SUCCEEDED" {
- return types.WithOpenAIError(types.OpenAIError{
- Message: aliResponse.Output.Message,
- Type: "ali_error",
- Param: "",
- Code: aliResponse.Output.Code,
- }, resp.StatusCode), nil
- }
-
- fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return nil, &dto.Usage{}
-}
-
-func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
- var aliResponse AliResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
-
- service.CloseResponseBodyGracefully(resp)
- err = common.Unmarshal(responseBody, &aliResponse)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
- }
-
- if aliResponse.Message != "" {
- logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
- return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
- }
- var fullTextResponse dto.ImageResponse
- if len(aliResponse.Output.Choices) > 0 {
- fullTextResponse = dto.ImageResponse{
- Created: info.StartTime.Unix(),
- Data: []dto.ImageData{
- {
- Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
- B64Json: "",
- },
- },
- }
- }
-
- var mapResponse map[string]any
- _ = common.Unmarshal(responseBody, &mapResponse)
- fullTextResponse.Extra = mapResponse
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return nil, &dto.Usage{}
-}
diff --git a/new-api/relay/channel/ali/rerank.go b/new-api/relay/channel/ali/rerank.go
deleted file mode 100644
index 084c2f70e015f0304959f214c3a3e0cbcd7b8e54..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/rerank.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package ali
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
- returnDocuments := request.ReturnDocuments
- if returnDocuments == nil {
- t := true
- returnDocuments = &t
- }
- return &AliRerankRequest{
- Model: request.Model,
- Input: AliRerankInput{
- Query: request.Query,
- Documents: request.Documents,
- },
- Parameters: AliRerankParameters{
- TopN: &request.TopN,
- ReturnDocuments: returnDocuments,
- },
- }
-}
-
-func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- service.CloseResponseBodyGracefully(resp)
-
- var aliResponse AliRerankResponse
- err = json.Unmarshal(responseBody, &aliResponse)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
- }
-
- if aliResponse.Code != "" {
- return types.WithOpenAIError(types.OpenAIError{
- Message: aliResponse.Message,
- Type: aliResponse.Code,
- Param: aliResponse.RequestId,
- Code: aliResponse.Code,
- }, resp.StatusCode), nil
- }
-
- usage := dto.Usage{
- PromptTokens: aliResponse.Usage.TotalTokens,
- CompletionTokens: 0,
- TotalTokens: aliResponse.Usage.TotalTokens,
- }
- rerankResponse := dto.RerankResponse{
- Results: aliResponse.Output.Results,
- Usage: usage,
- }
-
- jsonResponse, err := json.Marshal(rerankResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- c.Writer.Write(jsonResponse)
- return nil, &usage
-}
diff --git a/new-api/relay/channel/ali/text.go b/new-api/relay/channel/ali/text.go
deleted file mode 100644
index 43a925b5694cbc9b1856aa3d3ca667925de52b1f..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ali/text.go
+++ /dev/null
@@ -1,207 +0,0 @@
-package ali
-
-import (
- "bufio"
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/helper"
- "one-api/service"
- "strings"
-
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
-
-const EnableSearchModelSuffix = "-internet"
-
-func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
- if request.TopP >= 1 {
- request.TopP = 0.999
- } else if request.TopP <= 0 {
- request.TopP = 0.001
- }
- return &request
-}
-
-func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
- return &AliEmbeddingRequest{
- Model: request.Model,
- Input: struct {
- Texts []string `json:"texts"`
- }{
- Texts: request.ParseInput(),
- },
- }
-}
-
-func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var fullTextResponse dto.FlexibleEmbeddingResponse
- err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
- }
-
- service.CloseResponseBodyGracefully(resp)
-
- model := c.GetString("model")
- if model == "" {
- model = "text-embedding-v4"
- }
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
-}
-
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
- openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
- Model: model,
- Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
- }
-
- for _, item := range response.Output.Embeddings {
- openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
- Object: `embedding`,
- Index: item.TextIndex,
- Embedding: item.Embedding,
- })
- }
- return &openAIEmbeddingResponse
-}
-
-func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: response.Output.Text,
- },
- FinishReason: response.Output.FinishReason,
- }
- fullTextResponse := dto.OpenAITextResponse{
- Id: response.RequestId,
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []dto.OpenAITextResponseChoice{choice},
- Usage: dto.Usage{
- PromptTokens: response.Usage.InputTokens,
- CompletionTokens: response.Usage.OutputTokens,
- TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
- },
- }
- return &fullTextResponse
-}
-
-func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString(aliResponse.Output.Text)
- if aliResponse.Output.FinishReason != "null" {
- finishReason := aliResponse.Output.FinishReason
- choice.FinishReason = &finishReason
- }
- response := dto.ChatCompletionsStreamResponse{
- Id: aliResponse.RequestId,
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "ernie-bot",
- Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
- }
- return &response
-}
-
-func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var usage dto.Usage
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 { // ignore blank line or wrong format
- continue
- }
- if data[:5] != "data:" {
- continue
- }
- data = data[5:]
- dataChan <- data
- }
- stopChan <- true
- }()
- helper.SetEventStreamHeaders(c)
- lastResponseText := ""
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var aliResponse AliResponse
- err := json.Unmarshal([]byte(data), &aliResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
- if aliResponse.Usage.OutputTokens != 0 {
- usage.PromptTokens = aliResponse.Usage.InputTokens
- usage.CompletionTokens = aliResponse.Usage.OutputTokens
- usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
- }
- response := streamResponseAli2OpenAI(&aliResponse)
- response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
- lastResponseText = aliResponse.Output.Text
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- service.CloseResponseBodyGracefully(resp)
- return nil, &usage
-}
-
-func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var aliResponse AliResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &aliResponse)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
- }
- if aliResponse.Code != "" {
- return types.WithOpenAIError(types.OpenAIError{
- Message: aliResponse.Message,
- Type: "ali_error",
- Param: aliResponse.RequestId,
- Code: aliResponse.Code,
- }, resp.StatusCode), nil
- }
- fullTextResponse := responseAli2OpenAI(&aliResponse)
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
-}
diff --git a/new-api/relay/channel/api_request.go b/new-api/relay/channel/api_request.go
deleted file mode 100644
index f69f1793419345853d24db319d4866ba09b84086..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/api_request.go
+++ /dev/null
@@ -1,302 +0,0 @@
-package channel
-
-import (
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- common2 "one-api/common"
- "one-api/logger"
- "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/operation_setting"
- "one-api/types"
- "sync"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
- if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
- // multipart/form-data
- } else if info.RelayMode == constant.RelayModeRealtime {
- // websocket
- } else {
- req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Set("Accept", c.Request.Header.Get("Accept"))
- if info.IsStream && c.Request.Header.Get("Accept") == "" {
- req.Set("Accept", "text/event-stream")
- }
- }
-}
-
-func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- fullRequestURL, err := a.GetRequestURL(info)
- if err != nil {
- return nil, fmt.Errorf("get request url failed: %w", err)
- }
- if common2.DebugEnabled {
- println("fullRequestURL:", fullRequestURL)
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- headers := req.Header
- headerOverride := make(map[string]string)
- for k, v := range info.HeadersOverride {
- if str, ok := v.(string); ok {
- headerOverride[k] = str
- } else {
- return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
- }
- }
- for key, value := range headerOverride {
- headers.Set(key, value)
- }
- err = a.SetupRequestHeader(c, &headers, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := doRequest(c, req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
-}
-
-func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- fullRequestURL, err := a.GetRequestURL(info)
- if err != nil {
- return nil, fmt.Errorf("get request url failed: %w", err)
- }
- if common2.DebugEnabled {
- println("fullRequestURL:", fullRequestURL)
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- // set form data
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- headers := req.Header
- headerOverride := make(map[string]string)
- for k, v := range info.HeadersOverride {
- if str, ok := v.(string); ok {
- headerOverride[k] = str
- } else {
- return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
- }
- }
- for key, value := range headerOverride {
- headers.Set(key, value)
- }
- err = a.SetupRequestHeader(c, &headers, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := doRequest(c, req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
-}
-
-func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
- fullRequestURL, err := a.GetRequestURL(info)
- if err != nil {
- return nil, fmt.Errorf("get request url failed: %w", err)
- }
- targetHeader := http.Header{}
- err = a.SetupRequestHeader(c, &targetHeader, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
- if err != nil {
- return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
- }
- // send request body
- //all, err := io.ReadAll(requestBody)
- //err = service.WssString(c, targetConn, string(all))
- return targetConn, nil
-}
-
-func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
- pingerCtx, stopPinger := context.WithCancel(context.Background())
-
- gopool.Go(func() {
- defer func() {
- // 增加panic恢复处理
- if r := recover(); r != nil {
- if common2.DebugEnabled {
- println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
- }
- }
- if common2.DebugEnabled {
- println("SSE ping goroutine stopped.")
- }
- }()
-
- if pingInterval <= 0 {
- pingInterval = helper.DefaultPingInterval
- }
-
- ticker := time.NewTicker(pingInterval)
- // 确保在任何情况下都清理ticker
- defer func() {
- ticker.Stop()
- if common2.DebugEnabled {
- println("SSE ping ticker stopped")
- }
- }()
-
- var pingMutex sync.Mutex
- if common2.DebugEnabled {
- println("SSE ping goroutine started")
- }
-
- // 增加超时控制,防止goroutine长时间运行
- maxPingDuration := 120 * time.Minute // 最大ping持续时间
- pingTimeout := time.NewTimer(maxPingDuration)
- defer pingTimeout.Stop()
-
- for {
- select {
- // 发送 ping 数据
- case <-ticker.C:
- if err := sendPingData(c, &pingMutex); err != nil {
- if common2.DebugEnabled {
- println("SSE ping error, stopping goroutine:", err.Error())
- }
- return
- }
- // 收到退出信号
- case <-pingerCtx.Done():
- return
- // request 结束
- case <-c.Request.Context().Done():
- return
- // 超时保护,防止goroutine无限运行
- case <-pingTimeout.C:
- if common2.DebugEnabled {
- println("SSE ping goroutine timeout, stopping")
- }
- return
- }
- }
- })
-
- return stopPinger
-}
-
-func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
- // 增加超时控制,防止锁死等待
- done := make(chan error, 1)
- go func() {
- mutex.Lock()
- defer mutex.Unlock()
-
- err := helper.PingData(c)
- if err != nil {
- logger.LogError(c, "SSE ping error: "+err.Error())
- done <- err
- return
- }
-
- if common2.DebugEnabled {
- println("SSE ping data sent.")
- }
- done <- nil
- }()
-
- // 设置发送ping数据的超时时间
- select {
- case err := <-done:
- return err
- case <-time.After(10 * time.Second):
- return errors.New("SSE ping data send timeout")
- case <-c.Request.Context().Done():
- return errors.New("request context cancelled during ping")
- }
-}
-
-func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
- return doRequest(c, req, info)
-}
-func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
- var client *http.Client
- var err error
- if info.ChannelSetting.Proxy != "" {
- client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
- if err != nil {
- return nil, fmt.Errorf("new proxy http client failed: %w", err)
- }
- } else {
- client = service.GetHttpClient()
- }
-
- var stopPinger context.CancelFunc
- if info.IsStream {
- helper.SetEventStreamHeaders(c)
- // 处理流式请求的 ping 保活
- generalSettings := operation_setting.GetGeneralSetting()
- if generalSettings.PingIntervalEnabled && !info.DisablePing {
- pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
- stopPinger = startPingKeepAlive(c, pingInterval)
- // 使用defer确保在任何情况下都能停止ping goroutine
- defer func() {
- if stopPinger != nil {
- stopPinger()
- if common2.DebugEnabled {
- println("SSE ping goroutine stopped by defer")
- }
- }
- }()
- }
- }
-
- resp, err := client.Do(req)
- if err != nil {
- logger.LogError(c, "do request failed: "+err.Error())
- return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
- }
- if resp == nil {
- return nil, errors.New("resp is nil")
- }
-
- _ = req.Body.Close()
- _ = c.Request.Body.Close()
- return resp, nil
-}
-
-func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- fullRequestURL, err := a.BuildRequestURL(info)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- req.GetBody = func() (io.ReadCloser, error) {
- return io.NopCloser(requestBody), nil
- }
-
- err = a.BuildRequestHeader(c, req, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := doRequest(c, req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
-}
diff --git a/new-api/relay/channel/aws/adaptor.go b/new-api/relay/channel/aws/adaptor.go
deleted file mode 100644
index 2d265d24249b002390cec1757e25e00b3fc64e46..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/aws/adaptor.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package aws
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel/claude"
- relaycommon "one-api/relay/common"
- "one-api/setting/model_setting"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- RequestModeCompletion = 1
- RequestModeMessage = 2
-)
-
-type Adaptor struct {
- RequestMode int
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- c.Set("request_model", request.Model)
- c.Set("converted_request", request)
- return request, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- a.RequestMode = RequestModeMessage
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return "", nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- anthropicBeta := c.Request.Header.Get("anthropic-beta")
- if anthropicBeta != "" {
- req.Set("anthropic-beta", anthropicBeta)
- }
- model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- // 检查是否为Nova模型
- if isNovaModel(request.Model) {
- novaReq := convertToNovaRequest(request)
- c.Set("request_model", request.Model)
- c.Set("converted_request", novaReq)
- c.Set("is_nova_model", true)
- return novaReq, nil
- }
-
- // 原有的Claude模型处理逻辑
- var claudeReq *dto.ClaudeRequest
- var err error
- claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
- if err != nil {
- return nil, err
- }
- c.Set("request_model", claudeReq.Model)
- c.Set("converted_request", claudeReq)
- c.Set("is_nova_model", false)
- return claudeReq, err
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
- } else {
- err, usage = awsHandler(c, info, a.RequestMode)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() (models []string) {
- for n := range awsModelIDMap {
- models = append(models, n)
- }
-
- return
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/aws/constants.go b/new-api/relay/channel/aws/constants.go
deleted file mode 100644
index 01092243f063744bc188ed87fbc79457c4e4195c..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/aws/constants.go
+++ /dev/null
@@ -1,128 +0,0 @@
-package aws
-
-import "strings"
-
-var awsModelIDMap = map[string]string{
- "claude-instant-1.2": "anthropic.claude-instant-v1",
- "claude-2.0": "anthropic.claude-v2",
- "claude-2.1": "anthropic.claude-v2:1",
- "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
- "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
- "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
- "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
- "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
- "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
- "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
- "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
- "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
- "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
- "claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
- // Nova models
- "nova-micro-v1:0": "amazon.nova-micro-v1:0",
- "nova-lite-v1:0": "amazon.nova-lite-v1:0",
- "nova-pro-v1:0": "amazon.nova-pro-v1:0",
- "nova-premier-v1:0": "amazon.nova-premier-v1:0",
- "nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
- "nova-reel-v1:0": "amazon.nova-reel-v1:0",
- "nova-reel-v1:1": "amazon.nova-reel-v1:1",
- "nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
-}
-
-var awsModelCanCrossRegionMap = map[string]map[string]bool{
- "anthropic.claude-3-sonnet-20240229-v1:0": {
- "us": true,
- "eu": true,
- "ap": true,
- },
- "anthropic.claude-3-opus-20240229-v1:0": {
- "us": true,
- },
- "anthropic.claude-3-haiku-20240307-v1:0": {
- "us": true,
- "eu": true,
- "ap": true,
- },
- "anthropic.claude-3-5-sonnet-20240620-v1:0": {
- "us": true,
- "eu": true,
- "ap": true,
- },
- "anthropic.claude-3-5-sonnet-20241022-v2:0": {
- "us": true,
- "ap": true,
- },
- "anthropic.claude-3-5-haiku-20241022-v1:0": {
- "us": true,
- },
- "anthropic.claude-3-7-sonnet-20250219-v1:0": {
- "us": true,
- "ap": true,
- "eu": true,
- },
- "anthropic.claude-sonnet-4-20250514-v1:0": {
- "us": true,
- "ap": true,
- "eu": true,
- },
- "anthropic.claude-opus-4-20250514-v1:0": {
- "us": true,
- },
- "anthropic.claude-opus-4-1-20250805-v1:0": {
- "us": true,
- },
- "anthropic.claude-sonnet-4-5-20250929-v1:0": {
- "us": true,
- "ap": true,
- "eu": true,
- },
- // Nova models - all support three major regions
- "amazon.nova-micro-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
- "amazon.nova-lite-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
- "amazon.nova-pro-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
- "amazon.nova-premier-v1:0": {
- "us": true,
- },
- "amazon.nova-canvas-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
- "amazon.nova-reel-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
- "amazon.nova-reel-v1:1": {
- "us": true,
- },
- "amazon.nova-sonic-v1:0": {
- "us": true,
- "eu": true,
- "apac": true,
- },
-}
-
-var awsRegionCrossModelPrefixMap = map[string]string{
- "us": "us",
- "eu": "eu",
- "ap": "apac",
-}
-
-var ChannelName = "aws"
-
-// 判断是否为Nova模型
-func isNovaModel(modelId string) bool {
- return strings.HasPrefix(modelId, "nova-")
-}
diff --git a/new-api/relay/channel/aws/dto.go b/new-api/relay/channel/aws/dto.go
deleted file mode 100644
index 0f506713df0c1de2626abc136f766689193d29cf..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/aws/dto.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package aws
-
-import (
- "one-api/dto"
-)
-
-type AwsClaudeRequest struct {
- // AnthropicVersion should be "bedrock-2023-05-31"
- AnthropicVersion string `json:"anthropic_version"`
- System any `json:"system,omitempty"`
- Messages []dto.ClaudeMessage `json:"messages"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- StopSequences []string `json:"stop_sequences,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- Thinking *dto.Thinking `json:"thinking,omitempty"`
-}
-
-func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
- return &AwsClaudeRequest{
- AnthropicVersion: "bedrock-2023-05-31",
- System: req.System,
- Messages: req.Messages,
- MaxTokens: req.MaxTokens,
- Temperature: req.Temperature,
- TopP: req.TopP,
- TopK: req.TopK,
- StopSequences: req.StopSequences,
- Tools: req.Tools,
- ToolChoice: req.ToolChoice,
- Thinking: req.Thinking,
- }
-}
-
-// NovaMessage Nova模型使用messages-v1格式
-type NovaMessage struct {
- Role string `json:"role"`
- Content []NovaContent `json:"content"`
-}
-
-type NovaContent struct {
- Text string `json:"text"`
-}
-
-type NovaRequest struct {
- SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
- Messages []NovaMessage `json:"messages"` // 对话消息列表
- InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
-}
-
-type NovaInferenceConfig struct {
- MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
- Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
- TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
- TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
- StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
-}
-
-// 转换OpenAI请求为Nova格式
-func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
- novaMessages := make([]NovaMessage, len(req.Messages))
- for i, msg := range req.Messages {
- novaMessages[i] = NovaMessage{
- Role: msg.Role,
- Content: []NovaContent{{Text: msg.StringContent()}},
- }
- }
-
- novaReq := &NovaRequest{
- SchemaVersion: "messages-v1",
- Messages: novaMessages,
- }
-
- // 设置推理配置
- if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
- novaReq.InferenceConfig = &NovaInferenceConfig{}
- if req.MaxTokens != 0 {
- novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
- }
- if req.Temperature != nil && *req.Temperature != 0 {
- novaReq.InferenceConfig.Temperature = *req.Temperature
- }
- if req.TopP != 0 {
- novaReq.InferenceConfig.TopP = req.TopP
- }
- if req.TopK != 0 {
- novaReq.InferenceConfig.TopK = req.TopK
- }
- if req.Stop != nil {
- if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
- novaReq.InferenceConfig.StopSequences = stopSequences
- }
- }
- }
-
- return novaReq
-}
-
-// parseStopSequences 解析停止序列,支持字符串或字符串数组
-func parseStopSequences(stop any) []string {
- if stop == nil {
- return nil
- }
-
- switch v := stop.(type) {
- case string:
- if v != "" {
- return []string{v}
- }
- case []string:
- return v
- case []interface{}:
- var sequences []string
- for _, item := range v {
- if str, ok := item.(string); ok && str != "" {
- sequences = append(sequences, str)
- }
- }
- return sequences
- }
- return nil
-}
diff --git a/new-api/relay/channel/aws/relay-aws.go b/new-api/relay/channel/aws/relay-aws.go
deleted file mode 100644
index 95d1725fce4bb57ab810db8f4d73b15f51fb45e5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/aws/relay-aws.go
+++ /dev/null
@@ -1,295 +0,0 @@
-package aws
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel/claude"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/credentials"
- "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
- bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
- "github.com/aws/smithy-go/auth/bearer"
-)
-
-func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
- awsSecret := strings.Split(info.ApiKey, "|")
- var client *bedrockruntime.Client
- switch len(awsSecret) {
- case 2:
- apiKey := awsSecret[0]
- region := awsSecret[1]
- client = bedrockruntime.New(bedrockruntime.Options{
- Region: region,
- BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
- })
- case 3:
- ak := awsSecret[0]
- sk := awsSecret[1]
- region := awsSecret[2]
- client = bedrockruntime.New(bedrockruntime.Options{
- Region: region,
- Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
- })
- default:
- return nil, errors.New("invalid aws secret key")
- }
-
- return client, nil
-}
-
-func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
- return &dto.OpenAIErrorWithStatusCode{
- StatusCode: http.StatusInternalServerError,
- Error: dto.OpenAIError{
- Message: fmt.Sprintf("%s", err.Error()),
- },
- }
-}
-
-func awsRegionPrefix(awsRegionId string) string {
- parts := strings.Split(awsRegionId, "-")
- regionPrefix := ""
- if len(parts) > 0 {
- regionPrefix = parts[0]
- }
- return regionPrefix
-}
-
-func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
- regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
- return exists && regionSet[awsRegionPrefix]
-}
-
-func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
- modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
- if !find {
- return awsModelId
- }
- return modelPrefix + "." + awsModelId
-}
-
-func awsModelID(requestModel string) string {
- if awsModelID, ok := awsModelIDMap[requestModel]; ok {
- return awsModelID
- }
-
- return requestModel
-}
-
-func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
- awsCli, err := newAwsClient(c, info)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
- }
-
- awsModelId := awsModelID(c.GetString("request_model"))
- // 检查是否为Nova模型
- isNova, _ := c.Get("is_nova_model")
- if isNova == true {
- // Nova模型也支持跨区域
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
- if canCrossRegion {
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
- }
- return handleNovaRequest(c, awsCli, info, awsModelId)
- }
-
- // 原有的Claude处理逻辑
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
- if canCrossRegion {
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
- }
-
- awsReq := &bedrockruntime.InvokeModelInput{
- ModelId: aws.String(awsModelId),
- Accept: aws.String("application/json"),
- ContentType: aws.String("application/json"),
- }
-
- claudeReq_, ok := c.Get("converted_request")
- if !ok {
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
- }
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
- awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = common.Marshal(awsClaudeReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
- }
-
- awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
- if err != nil {
- return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
- }
-
- claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
-
- // 复制上游 Content-Type 到客户端响应头
- if awsResp.ContentType != nil && *awsResp.ContentType != "" {
- c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
- }
-
- handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
- if handlerErr != nil {
- return handlerErr, nil
- }
- return nil, claudeInfo.Usage
-}
-
-func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
- awsCli, err := newAwsClient(c, info)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
- }
-
- awsModelId := awsModelID(c.GetString("request_model"))
-
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
- if canCrossRegion {
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
- }
-
- awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
- ModelId: aws.String(awsModelId),
- Accept: aws.String("application/json"),
- ContentType: aws.String("application/json"),
- }
-
- claudeReq_, ok := c.Get("converted_request")
- if !ok {
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
- }
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
-
- awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = common.Marshal(awsClaudeReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
- }
-
- awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
- if err != nil {
- return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
- }
- stream := awsResp.GetStream()
- defer stream.Close()
-
- claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
-
- for event := range stream.Events() {
- switch v := event.(type) {
- case *bedrockruntimeTypes.ResponseStreamMemberChunk:
- info.SetFirstResponseTime()
- respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
- if respErr != nil {
- return respErr, nil
- }
- case *bedrockruntimeTypes.UnknownUnionMember:
- fmt.Println("unknown tag:", v.Tag)
- return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
- default:
- fmt.Println("union is nil or unknown type")
- return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
- }
- }
-
- claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
- return nil, claudeInfo.Usage
-}
-
-// Nova模型处理函数
-func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
- novaReq_, ok := c.Get("converted_request")
- if !ok {
- return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
- }
- novaReq := novaReq_.(*NovaRequest)
-
- // 使用InvokeModel API,但使用Nova格式的请求体
- awsReq := &bedrockruntime.InvokeModelInput{
- ModelId: aws.String(awsModelId),
- Accept: aws.String("application/json"),
- ContentType: aws.String("application/json"),
- }
-
- reqBody, err := json.Marshal(novaReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
- }
- awsReq.Body = reqBody
-
- awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
- }
-
- // 解析Nova响应
- var novaResp struct {
- Output struct {
- Message struct {
- Content []struct {
- Text string `json:"text"`
- } `json:"content"`
- } `json:"message"`
- } `json:"output"`
- Usage struct {
- InputTokens int `json:"inputTokens"`
- OutputTokens int `json:"outputTokens"`
- TotalTokens int `json:"totalTokens"`
- } `json:"usage"`
- }
-
- if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
- return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
- }
-
- // 构造OpenAI格式响应
- response := dto.OpenAITextResponse{
- Id: helper.GetResponseID(c),
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- Choices: []dto.OpenAITextResponseChoice{{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: novaResp.Output.Message.Content[0].Text,
- },
- FinishReason: "stop",
- }},
- Usage: dto.Usage{
- PromptTokens: novaResp.Usage.InputTokens,
- CompletionTokens: novaResp.Usage.OutputTokens,
- TotalTokens: novaResp.Usage.TotalTokens,
- },
- }
-
- c.JSON(http.StatusOK, response)
- return nil, &response.Usage
-}
diff --git a/new-api/relay/channel/baidu/adaptor.go b/new-api/relay/channel/baidu/adaptor.go
deleted file mode 100644
index c4211d20e001fc5e84a0abd59a497cf302768936..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu/adaptor.go
+++ /dev/null
@@ -1,169 +0,0 @@
-package baidu
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
- suffix := "chat/"
- if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
- suffix = "embeddings/"
- }
- if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
- suffix = "embeddings/"
- }
- if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
- suffix = "embeddings/"
- }
- switch info.UpstreamModelName {
- case "ERNIE-4.0":
- suffix += "completions_pro"
- case "ERNIE-Bot-4":
- suffix += "completions_pro"
- case "ERNIE-Bot":
- suffix += "completions"
- case "ERNIE-Bot-turbo":
- suffix += "eb-instant"
- case "ERNIE-Speed":
- suffix += "ernie_speed"
- case "ERNIE-4.0-8K":
- suffix += "completions_pro"
- case "ERNIE-3.5-8K":
- suffix += "completions"
- case "ERNIE-3.5-8K-0205":
- suffix += "ernie-3.5-8k-0205"
- case "ERNIE-3.5-8K-1222":
- suffix += "ernie-3.5-8k-1222"
- case "ERNIE-Bot-8K":
- suffix += "ernie_bot_8k"
- case "ERNIE-3.5-4K-0205":
- suffix += "ernie-3.5-4k-0205"
- case "ERNIE-Speed-8K":
- suffix += "ernie_speed"
- case "ERNIE-Speed-128K":
- suffix += "ernie-speed-128k"
- case "ERNIE-Lite-8K-0922":
- suffix += "eb-instant"
- case "ERNIE-Lite-8K-0308":
- suffix += "ernie-lite-8k"
- case "ERNIE-Tiny-8K":
- suffix += "ernie-tiny-8k"
- case "BLOOMZ-7B":
- suffix += "bloomz_7b1"
- case "Embedding-V1":
- suffix += "embedding-v1"
- case "bge-large-zh":
- suffix += "bge_large_zh"
- case "bge-large-en":
- suffix += "bge_large_en"
- case "tao-8k":
- suffix += "tao_8k"
- default:
- suffix += strings.ToLower(info.UpstreamModelName)
- }
- fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
- var accessToken string
- var err error
- if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
- return "", err
- }
- fullRequestURL += "?access_token=" + accessToken
- return fullRequestURL, nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- switch info.RelayMode {
- default:
- baiduRequest := requestOpenAI2Baidu(*request)
- return baiduRequest, nil
- }
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
- return baiduEmbeddingRequest, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- err, usage = baiduStreamHandler(c, info, resp)
- } else {
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- err, usage = baiduEmbeddingHandler(c, info, resp)
- default:
- err, usage = baiduHandler(c, info, resp)
- }
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/baidu/constants.go b/new-api/relay/channel/baidu/constants.go
deleted file mode 100644
index 847663ff2d5b99f08e724b0c37a81e40e2bbc368..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu/constants.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package baidu
-
-var ModelList = []string{
- "ERNIE-4.0-8K",
- "ERNIE-3.5-8K",
- "ERNIE-3.5-8K-0205",
- "ERNIE-3.5-8K-1222",
- "ERNIE-Bot-8K",
- "ERNIE-3.5-4K-0205",
- "ERNIE-Speed-8K",
- "ERNIE-Speed-128K",
- "ERNIE-Lite-8K-0922",
- "ERNIE-Lite-8K-0308",
- "ERNIE-Tiny-8K",
- "BLOOMZ-7B",
- "Embedding-V1",
- "bge-large-zh",
- "bge-large-en",
- "tao-8k",
-}
-
-var ChannelName = "baidu"
diff --git a/new-api/relay/channel/baidu/dto.go b/new-api/relay/channel/baidu/dto.go
deleted file mode 100644
index cd035cbf9cceb9707e6a3dcc4282274a32ecdd06..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu/dto.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package baidu
-
-import (
- "one-api/dto"
- "time"
-)
-
-type BaiduMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type BaiduChatRequest struct {
- Messages []BaiduMessage `json:"messages"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- PenaltyScore float64 `json:"penalty_score,omitempty"`
- Stream bool `json:"stream,omitempty"`
- System string `json:"system,omitempty"`
- DisableSearch bool `json:"disable_search,omitempty"`
- EnableCitation bool `json:"enable_citation,omitempty"`
- MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
- UserId string `json:"user_id,omitempty"`
-}
-
-type Error struct {
- ErrorCode int `json:"error_code"`
- ErrorMsg string `json:"error_msg"`
-}
-
-type BaiduChatResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Result string `json:"result"`
- IsTruncated bool `json:"is_truncated"`
- NeedClearHistory bool `json:"need_clear_history"`
- Usage dto.Usage `json:"usage"`
- Error
-}
-
-type BaiduChatStreamResponse struct {
- BaiduChatResponse
- SentenceId int `json:"sentence_id"`
- IsEnd bool `json:"is_end"`
-}
-
-type BaiduEmbeddingRequest struct {
- Input []string `json:"input"`
-}
-
-type BaiduEmbeddingData struct {
- Object string `json:"object"`
- Embedding []float64 `json:"embedding"`
- Index int `json:"index"`
-}
-
-type BaiduEmbeddingResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Data []BaiduEmbeddingData `json:"data"`
- Usage dto.Usage `json:"usage"`
- Error
-}
-
-type BaiduAccessToken struct {
- AccessToken string `json:"access_token"`
- Error string `json:"error,omitempty"`
- ErrorDescription string `json:"error_description,omitempty"`
- ExpiresIn int64 `json:"expires_in,omitempty"`
- ExpiresAt time.Time `json:"-"`
-}
-
-type BaiduTokenResponse struct {
- ExpiresIn int `json:"expires_in"`
- AccessToken string `json:"access_token"`
-}
diff --git a/new-api/relay/channel/baidu/relay-baidu.go b/new-api/relay/channel/baidu/relay-baidu.go
deleted file mode 100644
index 3431754345ae378fea571823c9ba6da5761dc3a5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu/relay-baidu.go
+++ /dev/null
@@ -1,245 +0,0 @@
-package baidu
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "sync"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
-
-var baiduTokenStore sync.Map
-
-func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
- baiduRequest := BaiduChatRequest{
- Temperature: request.Temperature,
- TopP: request.TopP,
- PenaltyScore: request.FrequencyPenalty,
- Stream: request.Stream,
- DisableSearch: false,
- EnableCitation: false,
- UserId: request.User,
- }
- if request.GetMaxTokens() != 0 {
- maxTokens := int(request.GetMaxTokens())
- if request.GetMaxTokens() == 1 {
- maxTokens = 2
- }
- baiduRequest.MaxOutputTokens = &maxTokens
- }
- for _, message := range request.Messages {
- if message.Role == "system" {
- baiduRequest.System = message.StringContent()
- } else {
- baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
- Role: message.Role,
- Content: message.StringContent(),
- })
- }
- }
- return &baiduRequest
-}
-
-func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: response.Result,
- },
- FinishReason: "stop",
- }
- fullTextResponse := dto.OpenAITextResponse{
- Id: response.Id,
- Object: "chat.completion",
- Created: response.Created,
- Choices: []dto.OpenAITextResponseChoice{choice},
- Usage: response.Usage,
- }
- return &fullTextResponse
-}
-
-func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString(baiduResponse.Result)
- if baiduResponse.IsEnd {
- choice.FinishReason = &constant.FinishReasonStop
- }
- response := dto.ChatCompletionsStreamResponse{
- Id: baiduResponse.Id,
- Object: "chat.completion.chunk",
- Created: baiduResponse.Created,
- Model: "ernie-bot",
- Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
- }
- return &response
-}
-
-func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
- return &BaiduEmbeddingRequest{
- Input: request.ParseInput(),
- }
-}
-
-func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
- openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
- Model: "baidu-embedding",
- Usage: response.Usage,
- }
- for _, item := range response.Data {
- openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
- Object: item.Object,
- Index: item.Index,
- Embedding: item.Embedding,
- })
- }
- return &openAIEmbeddingResponse
-}
-
-func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- usage := &dto.Usage{}
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var baiduResponse BaiduChatStreamResponse
- err := common.Unmarshal([]byte(data), &baiduResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
- if baiduResponse.Usage.TotalTokens != 0 {
- usage.TotalTokens = baiduResponse.Usage.TotalTokens
- usage.PromptTokens = baiduResponse.Usage.PromptTokens
- usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
- }
- response := streamResponseBaidu2OpenAI(&baiduResponse)
- err = helper.ObjectData(c, response)
- if err != nil {
- common.SysLog("error sending stream response: " + err.Error())
- }
- return true
- })
- service.CloseResponseBodyGracefully(resp)
- return nil, usage
-}
-
-func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var baiduResponse BaiduChatResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &baiduResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- if baiduResponse.ErrorMsg != "" {
- return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
- }
- fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
-}
-
-func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var baiduResponse BaiduEmbeddingResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &baiduResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- if baiduResponse.ErrorMsg != "" {
- return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
- }
- fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
-}
-
-func getBaiduAccessToken(apiKey string) (string, error) {
- if val, ok := baiduTokenStore.Load(apiKey); ok {
- var accessToken BaiduAccessToken
- if accessToken, ok = val.(BaiduAccessToken); ok {
- // soon this will expire
- if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
- go func() {
- _, _ = getBaiduAccessTokenHelper(apiKey)
- }()
- }
- return accessToken.AccessToken, nil
- }
- }
- accessToken, err := getBaiduAccessTokenHelper(apiKey)
- if err != nil {
- return "", err
- }
- if accessToken == nil {
- return "", errors.New("getBaiduAccessToken return a nil token")
- }
- return (*accessToken).AccessToken, nil
-}
-
-func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
- parts := strings.Split(apiKey, "|")
- if len(parts) != 2 {
- return nil, errors.New("invalid baidu apikey")
- }
- req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
- parts[0], parts[1]), nil)
- if err != nil {
- return nil, err
- }
- req.Header.Add("Content-Type", "application/json")
- req.Header.Add("Accept", "application/json")
- res, err := service.GetHttpClient().Do(req)
- if err != nil {
- return nil, err
- }
- defer res.Body.Close()
-
- var accessToken BaiduAccessToken
- err = json.NewDecoder(res.Body).Decode(&accessToken)
- if err != nil {
- return nil, err
- }
- if accessToken.Error != "" {
- return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
- }
- if accessToken.AccessToken == "" {
- return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
- }
- accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
- baiduTokenStore.Store(apiKey, accessToken)
- return &accessToken, nil
-}
diff --git a/new-api/relay/channel/baidu_v2/adaptor.go b/new-api/relay/channel/baidu_v2/adaptor.go
deleted file mode 100644
index 3a2fddbe10bb1220adc69378c87779757d501619..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu_v2/adaptor.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package baidu_v2
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
- return adaptor.ConvertClaudeRequest(c, info, req)
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
- return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
- case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
- case constant.RelayModeImagesGenerations:
- return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
- case constant.RelayModeImagesEdits:
- return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
- case constant.RelayModeRerank:
- return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
- default:
- }
- return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- keyParts := strings.Split(info.ApiKey, "|")
- if len(keyParts) == 0 || keyParts[0] == "" {
- return errors.New("invalid API key: authorization token is required")
- }
- if len(keyParts) > 1 {
- if keyParts[1] != "" {
- req.Set("appid", keyParts[1])
- }
- }
- req.Set("Authorization", "Bearer "+keyParts[0])
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if strings.HasSuffix(info.UpstreamModelName, "-search") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
- request.Model = info.UpstreamModelName
- if len(request.WebSearch) == 0 {
- toMap := request.ToMap()
- toMap["web_search"] = map[string]any{
- "enable": true,
- "enable_citation": true,
- "enable_trace": true,
- "enable_status": false,
- }
- return toMap, nil
- }
- return request, nil
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- adaptor := openai.Adaptor{}
- usage, err = adaptor.DoResponse(c, resp, info)
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/baidu_v2/constants.go b/new-api/relay/channel/baidu_v2/constants.go
deleted file mode 100644
index 1bd94e37d6efe84825a08fef6f18358b64b2252e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/baidu_v2/constants.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package baidu_v2
-
-var ModelList = []string{
- "ernie-4.0-8k-latest",
- "ernie-4.0-8k-preview",
- "ernie-4.0-8k",
- "ernie-4.0-turbo-8k-latest",
- "ernie-4.0-turbo-8k-preview",
- "ernie-4.0-turbo-8k",
- "ernie-4.0-turbo-128k",
- "ernie-3.5-8k-preview",
- "ernie-3.5-8k",
- "ernie-3.5-128k",
- "ernie-speed-8k",
- "ernie-speed-128k",
- "ernie-speed-pro-128k",
- "ernie-lite-8k",
- "ernie-lite-pro-128k",
- "ernie-tiny-8k",
- "ernie-char-8k",
- "ernie-char-fiction-8k",
- "ernie-novel-8k",
- "deepseek-v3",
- "deepseek-r1",
- "deepseek-r1-distill-qwen-32b",
- "deepseek-r1-distill-qwen-14b",
-}
-
-var ChannelName = "volcengine"
diff --git a/new-api/relay/channel/claude/adaptor.go b/new-api/relay/channel/claude/adaptor.go
deleted file mode 100644
index 1bd112dec20ca22b0803032c4729776ec3fdd14d..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/claude/adaptor.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package claude
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- RequestModeCompletion = 1
- RequestModeMessage = 2
-)
-
-type Adaptor struct {
- RequestMode int
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
- a.RequestMode = RequestModeCompletion
- } else {
- a.RequestMode = RequestModeMessage
- }
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- baseURL := ""
- if a.RequestMode == RequestModeMessage {
- baseURL = fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
- } else {
- baseURL = fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl)
- }
- if info.IsClaudeBetaQuery {
- baseURL = baseURL + "?beta=true"
- }
- return baseURL, nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("x-api-key", info.ApiKey)
- anthropicVersion := c.Request.Header.Get("anthropic-version")
- if anthropicVersion == "" {
- anthropicVersion = "2023-06-01"
- }
- req.Set("anthropic-version", anthropicVersion)
- anthropicBeta := c.Request.Header.Get("anthropic-beta")
- if anthropicBeta != "" {
- req.Set("anthropic-beta", anthropicBeta)
- }
- model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if a.RequestMode == RequestModeCompletion {
- return RequestOpenAI2ClaudeComplete(*request), nil
- } else {
- return RequestOpenAI2ClaudeMessage(c, *request)
- }
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- return ClaudeStreamHandler(c, resp, info, a.RequestMode)
- } else {
- return ClaudeHandler(c, resp, info, a.RequestMode)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/claude/constants.go b/new-api/relay/channel/claude/constants.go
deleted file mode 100644
index 991ffba9c0efb3e6b0f93f59c594f748c2aeaaa0..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/claude/constants.go
+++ /dev/null
@@ -1,26 +0,0 @@
-package claude
-
-var ModelList = []string{
- "claude-instant-1.2",
- "claude-2",
- "claude-2.0",
- "claude-2.1",
- "claude-3-sonnet-20240229",
- "claude-3-opus-20240229",
- "claude-3-haiku-20240307",
- "claude-3-5-haiku-20241022",
- "claude-3-5-sonnet-20240620",
- "claude-3-5-sonnet-20241022",
- "claude-3-7-sonnet-20250219",
- "claude-3-7-sonnet-20250219-thinking",
- "claude-sonnet-4-20250514",
- "claude-sonnet-4-20250514-thinking",
- "claude-opus-4-20250514",
- "claude-opus-4-20250514-thinking",
- "claude-opus-4-1-20250805",
- "claude-opus-4-1-20250805-thinking",
- "claude-sonnet-4-5-20250929",
- "claude-sonnet-4-5-20250929-thinking",
-}
-
-var ChannelName = "claude"
diff --git a/new-api/relay/channel/claude/dto.go b/new-api/relay/channel/claude/dto.go
deleted file mode 100644
index 00391d303f19cf5338ba80b292adcdc57b761086..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/claude/dto.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package claude
-
-//
-//type ClaudeMetadata struct {
-// UserId string `json:"user_id"`
-//}
-//
-//type ClaudeMediaMessage struct {
-// Type string `json:"type"`
-// Text string `json:"text,omitempty"`
-// Source *ClaudeMessageSource `json:"source,omitempty"`
-// Usage *ClaudeUsage `json:"usage,omitempty"`
-// StopReason *string `json:"stop_reason,omitempty"`
-// PartialJson string `json:"partial_json,omitempty"`
-// Thinking string `json:"thinking,omitempty"`
-// Signature string `json:"signature,omitempty"`
-// Delta string `json:"delta,omitempty"`
-// // tool_calls
-// Id string `json:"id,omitempty"`
-// Name string `json:"name,omitempty"`
-// Input any `json:"input,omitempty"`
-// Content string `json:"content,omitempty"`
-// ToolUseId string `json:"tool_use_id,omitempty"`
-//}
-//
-//type ClaudeMessageSource struct {
-// Type string `json:"type"`
-// MediaType string `json:"media_type"`
-// Data string `json:"data"`
-//}
-//
-//type ClaudeMessage struct {
-// Role string `json:"role"`
-// Content any `json:"content"`
-//}
-//
-//type Tool struct {
-// Name string `json:"name"`
-// Description string `json:"description,omitempty"`
-// InputSchema map[string]interface{} `json:"input_schema"`
-//}
-//
-//type InputSchema struct {
-// Type string `json:"type"`
-// Properties any `json:"properties,omitempty"`
-// Required any `json:"required,omitempty"`
-//}
-//
-//type ClaudeRequest struct {
-// Model string `json:"model"`
-// Prompt string `json:"prompt,omitempty"`
-// System string `json:"system,omitempty"`
-// Messages []ClaudeMessage `json:"messages,omitempty"`
-// MaxTokens uint `json:"max_tokens,omitempty"`
-// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
-// StopSequences []string `json:"stop_sequences,omitempty"`
-// Temperature *float64 `json:"temperature,omitempty"`
-// TopP float64 `json:"top_p,omitempty"`
-// TopK int `json:"top_k,omitempty"`
-// //ClaudeMetadata `json:"metadata,omitempty"`
-// Stream bool `json:"stream,omitempty"`
-// Tools any `json:"tools,omitempty"`
-// ToolChoice any `json:"tool_choice,omitempty"`
-// Thinking *Thinking `json:"thinking,omitempty"`
-//}
-//
-//type Thinking struct {
-// Type string `json:"type"`
-// BudgetTokens int `json:"budget_tokens"`
-//}
-//
-//type ClaudeError struct {
-// Type string `json:"type"`
-// Message string `json:"message"`
-//}
-//
-//type ClaudeResponse struct {
-// Id string `json:"id"`
-// Type string `json:"type"`
-// Content []ClaudeMediaMessage `json:"content"`
-// Completion string `json:"completion"`
-// StopReason string `json:"stop_reason"`
-// Model string `json:"model"`
-// Error ClaudeError `json:"error"`
-// Usage ClaudeUsage `json:"usage"`
-// Index int `json:"index"` // stream only
-// ContentBlock *ClaudeMediaMessage `json:"content_block"`
-// Delta *ClaudeMediaMessage `json:"delta"` // stream only
-// Message *ClaudeResponse `json:"message"` // stream only: message_start
-//}
-//
-//type ClaudeUsage struct {
-// InputTokens int `json:"input_tokens"`
-// OutputTokens int `json:"output_tokens"`
-//}
diff --git a/new-api/relay/channel/claude/relay-claude.go b/new-api/relay/channel/claude/relay-claude.go
deleted file mode 100644
index b43c13f0e2965d1571ddb704f64d2699768c04fa..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/claude/relay-claude.go
+++ /dev/null
@@ -1,831 +0,0 @@
-package claude
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "one-api/relay/channel/openrouter"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- WebSearchMaxUsesLow = 1
- WebSearchMaxUsesMedium = 5
- WebSearchMaxUsesHigh = 10
-)
-
-func stopReasonClaude2OpenAI(reason string) string {
- switch reason {
- case "stop_sequence":
- return "stop"
- case "end_turn":
- return "stop"
- case "max_tokens":
- return "length"
- case "tool_use":
- return "tool_calls"
- default:
- return reason
- }
-}
-
-func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
-
- claudeRequest := dto.ClaudeRequest{
- Model: textRequest.Model,
- Prompt: "",
- StopSequences: nil,
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- TopK: textRequest.TopK,
- Stream: textRequest.Stream,
- }
- if claudeRequest.MaxTokensToSample == 0 {
- claudeRequest.MaxTokensToSample = 4096
- }
- prompt := ""
- for _, message := range textRequest.Messages {
- if message.Role == "user" {
- prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
- } else if message.Role == "assistant" {
- prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
- } else if message.Role == "system" {
- if prompt == "" {
- prompt = message.StringContent()
- }
- }
- }
- prompt += "\n\nAssistant:"
- claudeRequest.Prompt = prompt
- return &claudeRequest
-}
-
-func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
- claudeTools := make([]any, 0, len(textRequest.Tools))
-
- for _, tool := range textRequest.Tools {
- if params, ok := tool.Function.Parameters.(map[string]any); ok {
- claudeTool := dto.Tool{
- Name: tool.Function.Name,
- Description: tool.Function.Description,
- }
- claudeTool.InputSchema = make(map[string]interface{})
- if params["type"] != nil {
- claudeTool.InputSchema["type"] = params["type"].(string)
- }
- claudeTool.InputSchema["properties"] = params["properties"]
- claudeTool.InputSchema["required"] = params["required"]
- for s, a := range params {
- if s == "type" || s == "properties" || s == "required" {
- continue
- }
- claudeTool.InputSchema[s] = a
- }
- claudeTools = append(claudeTools, &claudeTool)
- }
- }
-
- // Web search tool
- // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
- if textRequest.WebSearchOptions != nil {
- webSearchTool := dto.ClaudeWebSearchTool{
- Type: "web_search_20250305",
- Name: "web_search",
- }
-
- // 处理 user_location
- if textRequest.WebSearchOptions.UserLocation != nil {
- anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
- Type: "approximate", // 固定为 "approximate"
- }
-
- // 解析 UserLocation JSON
- var userLocationMap map[string]interface{}
- if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
- // 检查是否有 approximate 字段
- if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
- if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
- anthropicUserLocation.Timezone = timezone
- }
- if country, ok := approximateData["country"].(string); ok && country != "" {
- anthropicUserLocation.Country = country
- }
- if region, ok := approximateData["region"].(string); ok && region != "" {
- anthropicUserLocation.Region = region
- }
- if city, ok := approximateData["city"].(string); ok && city != "" {
- anthropicUserLocation.City = city
- }
- }
- }
-
- webSearchTool.UserLocation = anthropicUserLocation
- }
-
- // 处理 search_context_size 转换为 max_uses
- if textRequest.WebSearchOptions.SearchContextSize != "" {
- switch textRequest.WebSearchOptions.SearchContextSize {
- case "low":
- webSearchTool.MaxUses = WebSearchMaxUsesLow
- case "medium":
- webSearchTool.MaxUses = WebSearchMaxUsesMedium
- case "high":
- webSearchTool.MaxUses = WebSearchMaxUsesHigh
- }
- }
-
- claudeTools = append(claudeTools, &webSearchTool)
- }
-
- claudeRequest := dto.ClaudeRequest{
- Model: textRequest.Model,
- MaxTokens: textRequest.GetMaxTokens(),
- StopSequences: nil,
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- TopK: textRequest.TopK,
- Stream: textRequest.Stream,
- Tools: claudeTools,
- }
-
- // 处理 tool_choice 和 parallel_tool_calls
- if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
- claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
- if claudeToolChoice != nil {
- claudeRequest.ToolChoice = claudeToolChoice
- }
- }
-
- if claudeRequest.MaxTokens == 0 {
- claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
- }
-
- if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
- strings.HasSuffix(textRequest.Model, "-thinking") {
-
- // 因为BudgetTokens 必须大于1024
- if claudeRequest.MaxTokens < 1280 {
- claudeRequest.MaxTokens = 1280
- }
-
- // BudgetTokens 为 max_tokens 的 80%
- claudeRequest.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
- }
- // TODO: 临时处理
- // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
- claudeRequest.TopP = 0
- claudeRequest.Temperature = common.GetPointer[float64](1.0)
- claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
- }
-
- if textRequest.ReasoningEffort != "" {
- switch textRequest.ReasoningEffort {
- case "low":
- claudeRequest.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: common.GetPointer[int](1280),
- }
- case "medium":
- claudeRequest.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: common.GetPointer[int](2048),
- }
- case "high":
- claudeRequest.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: common.GetPointer[int](4096),
- }
- }
- }
-
- // 指定了 reasoning 参数,覆盖 budgetTokens
- if textRequest.Reasoning != nil {
- var reasoning openrouter.RequestReasoning
- if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
- return nil, err
- }
-
- budgetTokens := reasoning.MaxTokens
- if budgetTokens > 0 {
- claudeRequest.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: &budgetTokens,
- }
- }
- }
-
- if textRequest.Stop != nil {
- // stop maybe string/array string, convert to array string
- switch textRequest.Stop.(type) {
- case string:
- claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
- case []interface{}:
- stopSequences := make([]string, 0)
- for _, stop := range textRequest.Stop.([]interface{}) {
- stopSequences = append(stopSequences, stop.(string))
- }
- claudeRequest.StopSequences = stopSequences
- }
- }
- formatMessages := make([]dto.Message, 0)
- lastMessage := dto.Message{
- Role: "tool",
- }
- for i, message := range textRequest.Messages {
- if message.Role == "" {
- textRequest.Messages[i].Role = "user"
- }
- fmtMessage := dto.Message{
- Role: message.Role,
- Content: message.Content,
- }
- if message.Role == "tool" {
- fmtMessage.ToolCallId = message.ToolCallId
- }
- if message.Role == "assistant" && message.ToolCalls != nil {
- fmtMessage.ToolCalls = message.ToolCalls
- }
- if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
- if lastMessage.IsStringContent() && message.IsStringContent() {
- fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
- // delete last message
- formatMessages = formatMessages[:len(formatMessages)-1]
- }
- }
- if fmtMessage.Content == nil {
- fmtMessage.SetStringContent("...")
- }
- formatMessages = append(formatMessages, fmtMessage)
- lastMessage = fmtMessage
- }
-
- claudeMessages := make([]dto.ClaudeMessage, 0)
- isFirstMessage := true
- // 初始化system消息数组,用于累积多个system消息
- var systemMessages []dto.ClaudeMediaMessage
-
- for _, message := range formatMessages {
- if message.Role == "system" {
- // 根据Claude API规范,system字段使用数组格式更有通用性
- if message.IsStringContent() {
- systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
- Type: "text",
- Text: common.GetPointer[string](message.StringContent()),
- })
- } else {
- // 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
- for _, ctx := range message.ParseContent() {
- if ctx.Type == "text" {
- systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
- Type: "text",
- Text: common.GetPointer[string](ctx.Text),
- })
- }
- // 未来可以在这里扩展对图片等其他类型的支持
- }
- }
- } else {
- if isFirstMessage {
- isFirstMessage = false
- if message.Role != "user" {
- // fix: first message is assistant, add user message
- claudeMessage := dto.ClaudeMessage{
- Role: "user",
- Content: []dto.ClaudeMediaMessage{
- {
- Type: "text",
- Text: common.GetPointer[string]("..."),
- },
- },
- }
- claudeMessages = append(claudeMessages, claudeMessage)
- }
- }
- claudeMessage := dto.ClaudeMessage{
- Role: message.Role,
- }
- if message.Role == "tool" {
- if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
- lastMessage := claudeMessages[len(claudeMessages)-1]
- if content, ok := lastMessage.Content.(string); ok {
- lastMessage.Content = []dto.ClaudeMediaMessage{
- {
- Type: "text",
- Text: common.GetPointer[string](content),
- },
- }
- }
- lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
- Type: "tool_result",
- ToolUseId: message.ToolCallId,
- Content: message.Content,
- })
- claudeMessages[len(claudeMessages)-1] = lastMessage
- continue
- } else {
- claudeMessage.Role = "user"
- claudeMessage.Content = []dto.ClaudeMediaMessage{
- {
- Type: "tool_result",
- ToolUseId: message.ToolCallId,
- Content: message.Content,
- },
- }
- }
- } else if message.IsStringContent() && message.ToolCalls == nil {
- claudeMessage.Content = message.StringContent()
- } else {
- claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
- for _, mediaMessage := range message.ParseContent() {
- claudeMediaMessage := dto.ClaudeMediaMessage{
- Type: mediaMessage.Type,
- }
- if mediaMessage.Type == "text" {
- claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
- } else {
- imageUrl := mediaMessage.GetImageMedia()
- claudeMediaMessage.Type = "image"
- claudeMediaMessage.Source = &dto.ClaudeMessageSource{
- Type: "base64",
- }
- // 判断是否是url
- if strings.HasPrefix(imageUrl.Url, "http") {
- // 是url,获取图片的类型和base64编码的数据
- fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
- if err != nil {
- return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
- }
- claudeMediaMessage.Source.MediaType = fileData.MimeType
- claudeMediaMessage.Source.Data = fileData.Base64Data
- } else {
- _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
- if err != nil {
- return nil, err
- }
- claudeMediaMessage.Source.MediaType = "image/" + format
- claudeMediaMessage.Source.Data = base64String
- }
- }
- claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
- }
- if message.ToolCalls != nil {
- for _, toolCall := range message.ParseToolCalls() {
- inputObj := make(map[string]any)
- if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
- common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
- continue
- }
- claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
- Type: "tool_use",
- Id: toolCall.ID,
- Name: toolCall.Function.Name,
- Input: inputObj,
- })
- }
- }
- claudeMessage.Content = claudeMediaMessages
- }
- claudeMessages = append(claudeMessages, claudeMessage)
- }
- }
-
- // 设置累积的system消息
- if len(systemMessages) > 0 {
- claudeRequest.System = systemMessages
- }
-
- claudeRequest.Prompt = ""
- claudeRequest.Messages = claudeMessages
- return &claudeRequest, nil
-}
-
-func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
- var response dto.ChatCompletionsStreamResponse
- response.Object = "chat.completion.chunk"
- response.Model = claudeResponse.Model
- response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
- tools := make([]dto.ToolCallResponse, 0)
- fcIdx := 0
- if claudeResponse.Index != nil {
- fcIdx = *claudeResponse.Index - 1
- if fcIdx < 0 {
- fcIdx = 0
- }
- }
- var choice dto.ChatCompletionsStreamResponseChoice
- if reqMode == RequestModeCompletion {
- choice.Delta.SetContentString(claudeResponse.Completion)
- finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
- if finishReason != "null" {
- choice.FinishReason = &finishReason
- }
- } else {
- if claudeResponse.Type == "message_start" {
- response.Id = claudeResponse.Message.Id
- response.Model = claudeResponse.Message.Model
- //claudeUsage = &claudeResponse.Message.Usage
- choice.Delta.SetContentString("")
- choice.Delta.Role = "assistant"
- } else if claudeResponse.Type == "content_block_start" {
- if claudeResponse.ContentBlock != nil {
- // 如果是文本块,尽可能发送首段文本(若存在)
- if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
- choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
- }
- if claudeResponse.ContentBlock.Type == "tool_use" {
- tools = append(tools, dto.ToolCallResponse{
- Index: common.GetPointer(fcIdx),
- ID: claudeResponse.ContentBlock.Id,
- Type: "function",
- Function: dto.FunctionResponse{
- Name: claudeResponse.ContentBlock.Name,
- Arguments: "",
- },
- })
- }
- } else {
- return nil
- }
- } else if claudeResponse.Type == "content_block_delta" {
- if claudeResponse.Delta != nil {
- choice.Delta.Content = claudeResponse.Delta.Text
- switch claudeResponse.Delta.Type {
- case "input_json_delta":
- tools = append(tools, dto.ToolCallResponse{
- Type: "function",
- Index: common.GetPointer(fcIdx),
- Function: dto.FunctionResponse{
- Arguments: *claudeResponse.Delta.PartialJson,
- },
- })
- case "signature_delta":
- // 加密的不处理
- signatureContent := "\n"
- choice.Delta.ReasoningContent = &signatureContent
- case "thinking_delta":
- thinkingContent := claudeResponse.Delta.Thinking
- choice.Delta.ReasoningContent = &thinkingContent
- }
- }
- } else if claudeResponse.Type == "message_delta" {
- finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
- if finishReason != "null" {
- choice.FinishReason = &finishReason
- }
- //claudeUsage = &claudeResponse.Usage
- } else if claudeResponse.Type == "message_stop" {
- return nil
- } else {
- return nil
- }
- }
- if len(tools) > 0 {
- choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
- choice.Delta.ToolCalls = tools
- }
- response.Choices = append(response.Choices, choice)
-
- return &response
-}
-
-func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
- choices := make([]dto.OpenAITextResponseChoice, 0)
- fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- }
- var responseText string
- var responseThinking string
- if len(claudeResponse.Content) > 0 {
- responseText = claudeResponse.Content[0].GetText()
- responseThinking = claudeResponse.Content[0].Thinking
- }
- tools := make([]dto.ToolCallResponse, 0)
- thinkingContent := ""
-
- if reqMode == RequestModeCompletion {
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: strings.TrimPrefix(claudeResponse.Completion, " "),
- Name: nil,
- },
- FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
- }
- choices = append(choices, choice)
- } else {
- fullTextResponse.Id = claudeResponse.Id
- for _, message := range claudeResponse.Content {
- switch message.Type {
- case "tool_use":
- args, _ := json.Marshal(message.Input)
- tools = append(tools, dto.ToolCallResponse{
- ID: message.Id,
- Type: "function", // compatible with other OpenAI derivative applications
- Function: dto.FunctionResponse{
- Name: message.Name,
- Arguments: string(args),
- },
- })
- case "thinking":
- // 加密的不管, 只输出明文的推理过程
- thinkingContent = message.Thinking
- case "text":
- responseText = message.GetText()
- }
- }
- }
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- },
- FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
- }
- choice.SetStringContent(responseText)
- if len(responseThinking) > 0 {
- choice.ReasoningContent = responseThinking
- }
- if len(tools) > 0 {
- choice.Message.SetToolCalls(tools)
- }
- choice.Message.ReasoningContent = thinkingContent
- fullTextResponse.Model = claudeResponse.Model
- choices = append(choices, choice)
- fullTextResponse.Choices = choices
- return &fullTextResponse
-}
-
-type ClaudeResponseInfo struct {
- ResponseId string
- Created int64
- Model string
- ResponseText strings.Builder
- Usage *dto.Usage
- Done bool
-}
-
-func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
- if requestMode == RequestModeCompletion {
- claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
- } else {
- if claudeResponse.Type == "message_start" {
- claudeInfo.ResponseId = claudeResponse.Message.Id
- claudeInfo.Model = claudeResponse.Message.Model
-
- // message_start, 获取usage
- claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
- claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
- } else if claudeResponse.Type == "content_block_delta" {
- if claudeResponse.Delta.Text != nil {
- claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
- }
- if claudeResponse.Delta.Thinking != "" {
- claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
- }
- } else if claudeResponse.Type == "message_delta" {
- // 最终的usage获取
- if claudeResponse.Usage.InputTokens > 0 {
- // 不叠加,只取最新的
- claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
- }
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
-
- // 判断是否完整
- claudeInfo.Done = true
- } else if claudeResponse.Type == "content_block_start" {
- } else {
- return false
- }
- }
- if oaiResponse != nil {
- oaiResponse.Id = claudeInfo.ResponseId
- oaiResponse.Created = claudeInfo.Created
- oaiResponse.Model = claudeInfo.Model
- }
- return true
-}
-
-func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
- var claudeResponse dto.ClaudeResponse
- err := common.UnmarshalJsonStr(data, &claudeResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
- return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
- }
- if info.RelayFormat == types.RelayFormatClaude {
- FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
-
- if requestMode == RequestModeCompletion {
- } else {
- if claudeResponse.Type == "message_start" {
- // message_start, 获取usage
- info.UpstreamModelName = claudeResponse.Message.Model
- } else if claudeResponse.Type == "content_block_delta" {
- } else if claudeResponse.Type == "message_delta" {
- }
- }
- helper.ClaudeChunkData(c, claudeResponse, data)
- } else if info.RelayFormat == types.RelayFormatOpenAI {
- response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
-
- if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
- return nil
- }
-
- err = helper.ObjectData(c, response)
- if err != nil {
- logger.LogError(c, "send_stream_response_failed: "+err.Error())
- }
- }
- return nil
-}
-
-func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
-
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- if claudeInfo.Usage.PromptTokens == 0 {
- //上游出错
- }
- if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
- if common.DebugEnabled {
- common.SysLog("claude response usage is not complete, maybe upstream error")
- }
- claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
- }
-
- if info.RelayFormat == types.RelayFormatClaude {
- //
- } else if info.RelayFormat == types.RelayFormatOpenAI {
- if info.ShouldIncludeUsage {
- response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
- err := helper.ObjectData(c, response)
- if err != nil {
- common.SysLog("send final response failed: " + err.Error())
- }
- }
- helper.Done(c)
- }
-}
-
-func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
- claudeInfo := &ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
- var err *types.NewAPIError
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
- if err != nil {
- return false
- }
- return true
- })
- if err != nil {
- return nil, err
- }
-
- HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
- return claudeInfo.Usage, nil
-}
-
-func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
- var claudeResponse dto.ClaudeResponse
- err := common.Unmarshal(data, &claudeResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
- return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
- }
- if requestMode == RequestModeCompletion {
- completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
- claudeInfo.Usage.PromptTokens = info.PromptTokens
- claudeInfo.Usage.CompletionTokens = completionTokens
- claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
- } else {
- claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
- }
- var responseData []byte
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
- openaiResponse.Usage = *claudeInfo.Usage
- responseData, err = json.Marshal(openaiResponse)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- case types.RelayFormatClaude:
- responseData = data
- }
-
- if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
- c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
- }
-
- service.IOCopyBytesGracefully(c, httpResp, responseData)
- return nil
-}
-
-func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- claudeInfo := &ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- if common.DebugEnabled {
- println("responseBody: ", string(responseBody))
- }
- handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
- if handleErr != nil {
- return nil, handleErr
- }
- return claudeInfo.Usage, nil
-}
-
-func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
- var claudeToolChoice *dto.ClaudeToolChoice
-
- // 处理 tool_choice 字符串值
- if toolChoiceStr, ok := toolChoice.(string); ok {
- switch toolChoiceStr {
- case "auto":
- claudeToolChoice = &dto.ClaudeToolChoice{
- Type: "auto",
- }
- case "required":
- claudeToolChoice = &dto.ClaudeToolChoice{
- Type: "any",
- }
- case "none":
- claudeToolChoice = &dto.ClaudeToolChoice{
- Type: "none",
- }
- }
- } else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
- // 处理 tool_choice 对象值
- if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
- if toolName, ok := function["name"].(string); ok {
- claudeToolChoice = &dto.ClaudeToolChoice{
- Type: "tool",
- Name: toolName,
- }
- }
- }
- }
-
- // 处理 parallel_tool_calls
- if parallelToolCalls != nil {
- if claudeToolChoice == nil {
- // 如果没有 tool_choice,但有 parallel_tool_calls,创建默认的 auto 类型
- claudeToolChoice = &dto.ClaudeToolChoice{
- Type: "auto",
- }
- }
-
- // 设置 disable_parallel_tool_use
- // 如果 parallel_tool_calls 为 true,则 disable_parallel_tool_use 为 false
- claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
- }
-
- return claudeToolChoice
-}
diff --git a/new-api/relay/channel/cloudflare/adaptor.go b/new-api/relay/channel/cloudflare/adaptor.go
deleted file mode 100644
index 01bdc38236189bfaf9976db620443f7916432300..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cloudflare/adaptor.go
+++ /dev/null
@@ -1,135 +0,0 @@
-package cloudflare
-
-import (
- "bytes"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
- case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
- case constant.RelayModeResponses:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
- default:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- switch info.RelayMode {
- case constant.RelayModeCompletions:
- return convertCf2CompletionsRequest(*request), nil
- default:
- return request, nil
- }
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- // 添加文件字段
- file, _, err := c.Request.FormFile("file")
- if err != nil {
- return nil, errors.New("file is required")
- }
- defer file.Close()
- // 打开临时文件用于保存上传的文件内容
- requestBody := &bytes.Buffer{}
-
- // 将上传的文件内容复制到临时文件
- if _, err := io.Copy(requestBody, file); err != nil {
- return nil, err
- }
- return requestBody, nil
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- fallthrough
- case constant.RelayModeChatCompletions:
- if info.IsStream {
- err, usage = cfStreamHandler(c, info, resp)
- } else {
- err, usage = cfHandler(c, info, resp)
- }
- case constant.RelayModeResponses:
- if info.IsStream {
- usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OaiResponsesHandler(c, info, resp)
- }
- case constant.RelayModeAudioTranslation:
- fallthrough
- case constant.RelayModeAudioTranscription:
- err, usage = cfSTTHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/cloudflare/constant.go b/new-api/relay/channel/cloudflare/constant.go
deleted file mode 100644
index 3d6bf33df13f43ff1ff74ff34dda12523b206072..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cloudflare/constant.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package cloudflare
-
-var ModelList = []string{
- "@cf/meta/llama-3.1-8b-instruct",
- "@cf/meta/llama-2-7b-chat-fp16",
- "@cf/meta/llama-2-7b-chat-int8",
- "@cf/mistral/mistral-7b-instruct-v0.1",
- "@hf/thebloke/deepseek-coder-6.7b-base-awq",
- "@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
- "@cf/deepseek-ai/deepseek-math-7b-base",
- "@cf/deepseek-ai/deepseek-math-7b-instruct",
- "@cf/thebloke/discolm-german-7b-v1-awq",
- "@cf/tiiuae/falcon-7b-instruct",
- "@cf/google/gemma-2b-it-lora",
- "@hf/google/gemma-7b-it",
- "@cf/google/gemma-7b-it-lora",
- "@hf/nousresearch/hermes-2-pro-mistral-7b",
- "@hf/thebloke/llama-2-13b-chat-awq",
- "@cf/meta-llama/llama-2-7b-chat-hf-lora",
- "@cf/meta/llama-3-8b-instruct",
- "@hf/thebloke/llamaguard-7b-awq",
- "@hf/thebloke/mistral-7b-instruct-v0.1-awq",
- "@hf/mistralai/mistral-7b-instruct-v0.2",
- "@cf/mistral/mistral-7b-instruct-v0.2-lora",
- "@hf/thebloke/neural-chat-7b-v3-1-awq",
- "@cf/openchat/openchat-3.5-0106",
- "@hf/thebloke/openhermes-2.5-mistral-7b-awq",
- "@cf/microsoft/phi-2",
- "@cf/qwen/qwen1.5-0.5b-chat",
- "@cf/qwen/qwen1.5-1.8b-chat",
- "@cf/qwen/qwen1.5-14b-chat-awq",
- "@cf/qwen/qwen1.5-7b-chat-awq",
- "@cf/defog/sqlcoder-7b-2",
- "@hf/nexusflow/starling-lm-7b-beta",
- "@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
- "@hf/thebloke/zephyr-7b-beta-awq",
-}
-
-var ChannelName = "cloudflare"
diff --git a/new-api/relay/channel/cloudflare/dto.go b/new-api/relay/channel/cloudflare/dto.go
deleted file mode 100644
index 096681233b2ed877de451943947cfb3861b9b35d..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cloudflare/dto.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package cloudflare
-
-import "one-api/dto"
-
-type CfRequest struct {
- Messages []dto.Message `json:"messages,omitempty"`
- Lora string `json:"lora,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- Prompt string `json:"prompt,omitempty"`
- Raw bool `json:"raw,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
-}
-
-type CfAudioResponse struct {
- Result CfSTTResult `json:"result"`
-}
-
-type CfSTTResult struct {
- Text string `json:"text"`
-}
diff --git a/new-api/relay/channel/cloudflare/relay_cloudflare.go b/new-api/relay/channel/cloudflare/relay_cloudflare.go
deleted file mode 100644
index e8ab94c1f16787cc6739d9343a7b45841fcffce9..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cloudflare/relay_cloudflare.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package cloudflare
-
-import (
- "bufio"
- "encoding/json"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
- p, _ := textRequest.Prompt.(string)
- return &CfRequest{
- Prompt: p,
- MaxTokens: textRequest.GetMaxTokens(),
- Stream: textRequest.Stream,
- Temperature: textRequest.Temperature,
- }
-}
-
-func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
-
- helper.SetEventStreamHeaders(c)
- id := helper.GetResponseID(c)
- var responseText string
- isFirst := true
-
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < len("data: ") {
- continue
- }
- data = strings.TrimPrefix(data, "data: ")
- data = strings.TrimSuffix(data, "\r")
-
- if data == "[DONE]" {
- break
- }
-
- var response dto.ChatCompletionsStreamResponse
- err := json.Unmarshal([]byte(data), &response)
- if err != nil {
- logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
- continue
- }
- for _, choice := range response.Choices {
- choice.Delta.Role = "assistant"
- responseText += choice.Delta.GetContentString()
- }
- response.Id = id
- response.Model = info.UpstreamModelName
- err = helper.ObjectData(c, response)
- if isFirst {
- isFirst = false
- info.FirstResponseTime = time.Now()
- }
- if err != nil {
- logger.LogError(c, "error_rendering_stream_response: "+err.Error())
- }
- }
-
- if err := scanner.Err(); err != nil {
- logger.LogError(c, "error_scanning_stream_response: "+err.Error())
- }
- usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- if info.ShouldIncludeUsage {
- response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
- err := helper.ObjectData(c, response)
- if err != nil {
- logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
- }
- }
- helper.Done(c)
-
- service.CloseResponseBodyGracefully(resp)
-
- return nil, usage
-}
-
-func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.CloseResponseBodyGracefully(resp)
- var response dto.TextResponse
- err = json.Unmarshal(responseBody, &response)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- response.Model = info.UpstreamModelName
- var responseText string
- for _, choice := range response.Choices {
- responseText += choice.Message.StringContent()
- }
- usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- response.Usage = *usage
- response.Id = helper.GetResponseID(c)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
- return nil, usage
-}
-
-func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
- var cfResp CfAudioResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &cfResp)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
-
- audioResp := &dto.AudioResponse{
- Text: cfResp.Result.Text,
- }
-
- jsonResponse, err := json.Marshal(audioResp)
- if err != nil {
- return types.NewError(err, types.ErrorCodeBadResponseBody), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
-
- usage := &dto.Usage{}
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-
- return nil, usage
-}
diff --git a/new-api/relay/channel/cohere/adaptor.go b/new-api/relay/channel/cohere/adaptor.go
deleted file mode 100644
index ed9edd89dabf72b521b8f51c6c1a28c7114ce783..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cohere/adaptor.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package cohere
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
- } else {
- return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- return requestOpenAI2Cohere(*request), nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return requestConvertRerank2Cohere(request), nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.RelayMode == constant.RelayModeRerank {
- usage, err = cohereRerankHandler(c, resp, info)
- } else {
- if info.IsStream {
- usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
- } else {
- usage, err = cohereHandler(c, info, resp)
- }
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/cohere/constant.go b/new-api/relay/channel/cohere/constant.go
deleted file mode 100644
index e1255bda60bc838e2b0e8523d6afe252db653fde..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cohere/constant.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package cohere
-
-var ModelList = []string{
- "command-a-03-2025",
- "command-r", "command-r-plus",
- "command-r-08-2024", "command-r-plus-08-2024",
- "c4ai-aya-23-35b", "c4ai-aya-23-8b",
- "command-light", "command-light-nightly", "command", "command-nightly",
- "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
-}
-
-var ChannelName = "cohere"
diff --git a/new-api/relay/channel/cohere/dto.go b/new-api/relay/channel/cohere/dto.go
deleted file mode 100644
index 94546ab1af74e06bbe2b0eb809f83abb89146426..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cohere/dto.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package cohere
-
-import "one-api/dto"
-
-type CohereRequest struct {
- Model string `json:"model"`
- ChatHistory []ChatHistory `json:"chat_history"`
- Message string `json:"message"`
- Stream bool `json:"stream"`
- MaxTokens uint `json:"max_tokens"`
- SafetyMode string `json:"safety_mode,omitempty"`
-}
-
-type ChatHistory struct {
- Role string `json:"role"`
- Message string `json:"message"`
-}
-
-type CohereResponse struct {
- IsFinished bool `json:"is_finished"`
- EventType string `json:"event_type"`
- Text string `json:"text,omitempty"`
- FinishReason string `json:"finish_reason,omitempty"`
- Response *CohereResponseResult `json:"response"`
-}
-
-type CohereResponseResult struct {
- ResponseId string `json:"response_id"`
- FinishReason string `json:"finish_reason,omitempty"`
- Text string `json:"text"`
- Meta CohereMeta `json:"meta"`
-}
-
-type CohereRerankRequest struct {
- Documents []any `json:"documents"`
- Query string `json:"query"`
- Model string `json:"model"`
- TopN int `json:"top_n"`
- ReturnDocuments bool `json:"return_documents"`
-}
-
-type CohereRerankResponseResult struct {
- Results []dto.RerankResponseResult `json:"results"`
- Meta CohereMeta `json:"meta"`
-}
-
-type CohereMeta struct {
- //Tokens CohereTokens `json:"tokens"`
- BilledUnits CohereBilledUnits `json:"billed_units"`
-}
-
-type CohereBilledUnits struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
-}
-
-type CohereTokens struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
-}
diff --git a/new-api/relay/channel/cohere/relay-cohere.go b/new-api/relay/channel/cohere/relay-cohere.go
deleted file mode 100644
index 33a66c24033b99dc6b44dbdb7b187916456e760e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/cohere/relay-cohere.go
+++ /dev/null
@@ -1,248 +0,0 @@
-package cohere
-
-import (
- "bufio"
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
- cohereReq := CohereRequest{
- Model: textRequest.Model,
- ChatHistory: []ChatHistory{},
- Message: "",
- Stream: textRequest.Stream,
- MaxTokens: textRequest.GetMaxTokens(),
- }
- if common.CohereSafetySetting != "NONE" {
- cohereReq.SafetyMode = common.CohereSafetySetting
- }
- if cohereReq.MaxTokens == 0 {
- cohereReq.MaxTokens = 4000
- }
- for _, msg := range textRequest.Messages {
- if msg.Role == "user" {
- cohereReq.Message = msg.StringContent()
- } else {
- var role string
- if msg.Role == "assistant" {
- role = "CHATBOT"
- } else if msg.Role == "system" {
- role = "SYSTEM"
- } else {
- role = "USER"
- }
- cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
- Role: role,
- Message: msg.StringContent(),
- })
- }
- }
-
- return &cohereReq
-}
-
-func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
- if rerankRequest.TopN == 0 {
- rerankRequest.TopN = 1
- }
- cohereReq := CohereRerankRequest{
- Query: rerankRequest.Query,
- Documents: rerankRequest.Documents,
- Model: rerankRequest.Model,
- TopN: rerankRequest.TopN,
- ReturnDocuments: true,
- }
- return &cohereReq
-}
-
-func stopReasonCohere2OpenAI(reason string) string {
- switch reason {
- case "COMPLETE":
- return "stop"
- case "MAX_TOKENS":
- return "max_tokens"
- default:
- return reason
- }
-}
-
-func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseId := helper.GetResponseID(c)
- createdTime := common.GetTimestamp()
- usage := &dto.Usage{}
- responseText := ""
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- dataChan <- data
- }
- stopChan <- true
- }()
- helper.SetEventStreamHeaders(c)
- isFirst := true
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- if isFirst {
- isFirst = false
- info.FirstResponseTime = time.Now()
- }
- data = strings.TrimSuffix(data, "\r")
- var cohereResp CohereResponse
- err := json.Unmarshal([]byte(data), &cohereResp)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
- var openaiResp dto.ChatCompletionsStreamResponse
- openaiResp.Id = responseId
- openaiResp.Created = createdTime
- openaiResp.Object = "chat.completion.chunk"
- openaiResp.Model = info.UpstreamModelName
- if cohereResp.IsFinished {
- finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
- openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
- {
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
- Index: 0,
- FinishReason: &finishReason,
- },
- }
- if cohereResp.Response != nil {
- usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
- usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
- }
- } else {
- openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
- {
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
- Role: "assistant",
- Content: &cohereResp.Text,
- },
- Index: 0,
- },
- }
- responseText += cohereResp.Text
- }
- jsonStr, err := json.Marshal(openaiResp)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- if usage.PromptTokens == 0 {
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- }
- return usage, nil
-}
-
-func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- createdTime := common.GetTimestamp()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- service.CloseResponseBodyGracefully(resp)
- var cohereResp CohereResponseResult
- err = json.Unmarshal(responseBody, &cohereResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- usage := dto.Usage{}
- usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
- usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
- usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
-
- var openaiResp dto.TextResponse
- openaiResp.Id = cohereResp.ResponseId
- openaiResp.Created = createdTime
- openaiResp.Object = "chat.completion"
- openaiResp.Model = info.UpstreamModelName
- openaiResp.Usage = usage
-
- openaiResp.Choices = []dto.OpenAITextResponseChoice{
- {
- Index: 0,
- Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
- FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
- },
- }
-
- jsonResponse, err := json.Marshal(openaiResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
- return &usage, nil
-}
-
-func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- service.CloseResponseBodyGracefully(resp)
- var cohereResp CohereRerankResponseResult
- err = json.Unmarshal(responseBody, &cohereResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- usage := dto.Usage{}
- if cohereResp.Meta.BilledUnits.InputTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens = 0
- usage.TotalTokens = info.PromptTokens
- } else {
- usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
- usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
- usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
- }
-
- var rerankResp dto.RerankResponse
- rerankResp.Results = cohereResp.Results
- rerankResp.Usage = usage
-
- jsonResponse, err := json.Marshal(rerankResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return &usage, nil
-}
diff --git a/new-api/relay/channel/coze/adaptor.go b/new-api/relay/channel/coze/adaptor.go
deleted file mode 100644
index a5c2fa6bb9ff8caa5601b3ab863c0331fdafa438..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/coze/adaptor.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package coze
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/common"
- "one-api/types"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-// ConvertAudioRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- return nil, errors.New("not implemented")
-}
-
-// ConvertClaudeRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-// ConvertEmbeddingRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-// ConvertImageRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-// ConvertOpenAIRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return convertCozeChatRequest(c, *request), nil
-}
-
-// ConvertOpenAIResponsesRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-// ConvertRerankRequest implements channel.Adaptor.
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-// DoRequest implements channel.Adaptor.
-func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
- if info.IsStream {
- return channel.DoApiRequest(a, c, info, requestBody)
- }
- // 首先发送创建消息请求,成功后再发送获取消息请求
- // 发送创建消息请求
- resp, err := channel.DoApiRequest(a, c, info, requestBody)
- if err != nil {
- return nil, err
- }
- // 解析 resp
- var cozeResponse CozeChatResponse
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, err
- }
- err = json.Unmarshal(respBody, &cozeResponse)
- if cozeResponse.Code != 0 {
- return nil, errors.New(cozeResponse.Msg)
- }
- c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
- c.Set("coze_chat_id", cozeResponse.Data.Id)
- // 轮询检查消息是否完成
- for {
- err, isComplete := checkIfChatComplete(a, c, info)
- if err != nil {
- return nil, err
- } else {
- if isComplete {
- break
- }
- }
- time.Sleep(time.Second * 1)
- }
- // 发送获取消息请求
- return getChatDetail(a, c, info)
-}
-
-// DoResponse implements channel.Adaptor.
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = cozeChatStreamHandler(c, info, resp)
- } else {
- usage, err = cozeChatHandler(c, info, resp)
- }
- return
-}
-
-// GetChannelName implements channel.Adaptor.
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
-
-// GetModelList implements channel.Adaptor.
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-// GetRequestURL implements channel.Adaptor.
-func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
-}
-
-// Init implements channel.Adaptor.
-func (a *Adaptor) Init(info *common.RelayInfo) {
-
-}
-
-// SetupRequestHeader implements channel.Adaptor.
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
diff --git a/new-api/relay/channel/coze/constants.go b/new-api/relay/channel/coze/constants.go
deleted file mode 100644
index c79af56d9e31f75daf4e7809afaf7558bab36ce5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/coze/constants.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package coze
-
-var ModelList = []string{
- "moonshot-v1-8k",
- "moonshot-v1-32k",
- "moonshot-v1-128k",
- "Baichuan4",
- "abab6.5s-chat-pro",
- "glm-4-0520",
- "qwen-max",
- "deepseek-r1",
- "deepseek-v3",
- "deepseek-r1-distill-qwen-32b",
- "deepseek-r1-distill-qwen-7b",
- "step-1v-8k",
- "step-1.5v-mini",
- "Doubao-pro-32k",
- "Doubao-pro-256k",
- "Doubao-lite-128k",
- "Doubao-lite-32k",
- "Doubao-vision-lite-32k",
- "Doubao-vision-pro-32k",
- "Doubao-1.5-pro-vision-32k",
- "Doubao-1.5-lite-32k",
- "Doubao-1.5-pro-32k",
- "Doubao-1.5-thinking-pro",
- "Doubao-1.5-pro-256k",
-}
-
-var ChannelName = "coze"
diff --git a/new-api/relay/channel/coze/dto.go b/new-api/relay/channel/coze/dto.go
deleted file mode 100644
index 38c0101ee6bcaf08f0b932fbdcfb49104b411903..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/coze/dto.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package coze
-
-import "encoding/json"
-
-type CozeError struct {
- Code int `json:"code"`
- Message string `json:"message"`
-}
-
-type CozeEnterMessage struct {
- Role string `json:"role"`
- Type string `json:"type,omitempty"`
- Content any `json:"content,omitempty"`
- MetaData json.RawMessage `json:"meta_data,omitempty"`
- ContentType string `json:"content_type,omitempty"`
-}
-
-type CozeChatRequest struct {
- BotId string `json:"bot_id"`
- UserId string `json:"user_id"`
- AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
- Stream bool `json:"stream,omitempty"`
- CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
- AutoSaveHistory bool `json:"auto_save_history,omitempty"`
- MetaData json.RawMessage `json:"meta_data,omitempty"`
- ExtraParams json.RawMessage `json:"extra_params,omitempty"`
- ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
- Parameters json.RawMessage `json:"parameters,omitempty"`
-}
-
-type CozeChatResponse struct {
- Code int `json:"code"`
- Msg string `json:"msg"`
- Data CozeChatResponseData `json:"data"`
-}
-
-type CozeChatResponseData struct {
- Id string `json:"id"`
- ConversationId string `json:"conversation_id"`
- BotId string `json:"bot_id"`
- CreatedAt int64 `json:"created_at"`
- LastError CozeError `json:"last_error"`
- Status string `json:"status"`
- Usage CozeChatUsage `json:"usage"`
-}
-
-type CozeChatUsage struct {
- TokenCount int `json:"token_count"`
- OutputCount int `json:"output_count"`
- InputCount int `json:"input_count"`
-}
-
-type CozeChatDetailResponse struct {
- Data []CozeChatV3MessageDetail `json:"data"`
- Code int `json:"code"`
- Msg string `json:"msg"`
- Detail CozeResponseDetail `json:"detail"`
-}
-
-type CozeChatV3MessageDetail struct {
- Id string `json:"id"`
- Role string `json:"role"`
- Type string `json:"type"`
- BotId string `json:"bot_id"`
- ChatId string `json:"chat_id"`
- Content json.RawMessage `json:"content"`
- MetaData json.RawMessage `json:"meta_data"`
- CreatedAt int64 `json:"created_at"`
- SectionId string `json:"section_id"`
- UpdatedAt int64 `json:"updated_at"`
- ContentType string `json:"content_type"`
- ConversationId string `json:"conversation_id"`
- ReasoningContent string `json:"reasoning_content"`
-}
-
-type CozeResponseDetail struct {
- Logid string `json:"logid"`
-}
diff --git a/new-api/relay/channel/coze/relay-coze.go b/new-api/relay/channel/coze/relay-coze.go
deleted file mode 100644
index 2fe435cf01356abfba8c93e076de27acb369d98e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/coze/relay-coze.go
+++ /dev/null
@@ -1,296 +0,0 @@
-package coze
-
-import (
- "bufio"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
- var messages []CozeEnterMessage
- // 将 request的messages的role为user的content转换为CozeMessage
- for _, message := range request.Messages {
- if message.Role == "user" {
- messages = append(messages, CozeEnterMessage{
- Role: "user",
- Content: message.Content,
- // TODO: support more content type
- ContentType: "text",
- })
- }
- }
- user := request.User
- if user == "" {
- user = helper.GetResponseID(c)
- }
- cozeRequest := &CozeChatRequest{
- BotId: c.GetString("bot_id"),
- UserId: user,
- AdditionalMessages: messages,
- Stream: request.Stream,
- }
- return cozeRequest
-}
-
-func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- service.CloseResponseBodyGracefully(resp)
- // convert coze response to openai response
- var response dto.TextResponse
- var cozeResponse CozeChatDetailResponse
- response.Model = info.UpstreamModelName
- err = json.Unmarshal(responseBody, &cozeResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- if cozeResponse.Code != 0 {
- return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
- }
- // 从上下文获取 usage
- var usage dto.Usage
- usage.PromptTokens = c.GetInt("coze_input_count")
- usage.CompletionTokens = c.GetInt("coze_output_count")
- usage.TotalTokens = c.GetInt("coze_token_count")
- response.Usage = usage
- response.Id = helper.GetResponseID(c)
-
- var responseContent json.RawMessage
- for _, data := range cozeResponse.Data {
- if data.Type == "answer" {
- responseContent = data.Content
- response.Created = data.CreatedAt
- }
- }
- // 添加 response.Choices
- response.Choices = []dto.OpenAITextResponseChoice{
- {
- Index: 0,
- Message: dto.Message{Role: "assistant", Content: responseContent},
- FinishReason: "stop",
- },
- }
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
-
- return &usage, nil
-}
-
-func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
- helper.SetEventStreamHeaders(c)
- id := helper.GetResponseID(c)
- var responseText string
-
- var currentEvent string
- var currentData string
- var usage = &dto.Usage{}
-
- for scanner.Scan() {
- line := scanner.Text()
-
- if line == "" {
- if currentEvent != "" && currentData != "" {
- // handle last event
- handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
- currentEvent = ""
- currentData = ""
- }
- continue
- }
-
- if strings.HasPrefix(line, "event:") {
- currentEvent = strings.TrimSpace(line[6:])
- continue
- }
-
- if strings.HasPrefix(line, "data:") {
- currentData = strings.TrimSpace(line[5:])
- continue
- }
- }
-
- // Last event
- if currentEvent != "" && currentData != "" {
- handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
- }
-
- if err := scanner.Err(); err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- helper.Done(c)
-
- if usage.TotalTokens == 0 {
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
- }
-
- return usage, nil
-}
-
-func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
- switch event {
- case "conversation.chat.completed":
- // 将 data 解析为 CozeChatResponseData
- var chatData CozeChatResponseData
- err := json.Unmarshal([]byte(data), &chatData)
- if err != nil {
- common.SysLog("error_unmarshalling_stream_response: " + err.Error())
- return
- }
-
- usage.PromptTokens = chatData.Usage.InputCount
- usage.CompletionTokens = chatData.Usage.OutputCount
- usage.TotalTokens = chatData.Usage.TokenCount
-
- finishReason := "stop"
- stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
- helper.ObjectData(c, stopResponse)
-
- case "conversation.message.delta":
- // 将 data 解析为 CozeChatV3MessageDetail
- var messageData CozeChatV3MessageDetail
- err := json.Unmarshal([]byte(data), &messageData)
- if err != nil {
- common.SysLog("error_unmarshalling_stream_response: " + err.Error())
- return
- }
-
- var content string
- err = json.Unmarshal(messageData.Content, &content)
- if err != nil {
- common.SysLog("error_unmarshalling_stream_response: " + err.Error())
- return
- }
-
- *responseText += content
-
- openaiResponse := dto.ChatCompletionsStreamResponse{
- Id: id,
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- }
-
- choice := dto.ChatCompletionsStreamResponseChoice{
- Index: 0,
- }
- choice.Delta.SetContentString(content)
- openaiResponse.Choices = append(openaiResponse.Choices, choice)
-
- helper.ObjectData(c, openaiResponse)
-
- case "error":
- var errorData CozeError
- err := json.Unmarshal([]byte(data), &errorData)
- if err != nil {
- common.SysLog("error_unmarshalling_stream_response: " + err.Error())
- return
- }
-
- common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
- }
-}
-
-func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
- requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
-
- requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
- // 将 conversationId和chatId作为参数发送get请求
- req, err := http.NewRequest("GET", requestURL, nil)
- if err != nil {
- return err, false
- }
- err = a.SetupRequestHeader(c, &req.Header, info)
- if err != nil {
- return err, false
- }
-
- resp, err := doRequest(req, info) // 调用 doRequest
- if err != nil {
- return err, false
- }
- if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
- return fmt.Errorf("resp is nil"), false
- }
- defer resp.Body.Close() // 确保响应体被关闭
-
- // 解析 resp 到 CozeChatResponse
- var cozeResponse CozeChatResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("read response body failed: %w", err), false
- }
- err = json.Unmarshal(responseBody, &cozeResponse)
- if err != nil {
- return fmt.Errorf("unmarshal response body failed: %w", err), false
- }
- if cozeResponse.Data.Status == "completed" {
- // 在上下文设置 usage
- c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
- c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
- c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
- return nil, true
- } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
- return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
- } else {
- return nil, false
- }
-}
-
-func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
- requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
-
- requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
- req, err := http.NewRequest("GET", requestURL, nil)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- err = a.SetupRequestHeader(c, &req.Header, info)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := doRequest(req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
-}
-
-func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
- var client *http.Client
- var err error // 声明 err 变量
- if info.ChannelSetting.Proxy != "" {
- client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
- if err != nil {
- return nil, fmt.Errorf("new proxy http client failed: %w", err)
- }
- } else {
- client = service.GetHttpClient()
- }
- resp, err := client.Do(req)
- if err != nil { // 增加对 client.Do(req) 返回错误的检查
- return nil, fmt.Errorf("client.Do failed: %w", err)
- }
- // _ = resp.Body.Close()
- return resp, nil
-}
diff --git a/new-api/relay/channel/deepseek/adaptor.go b/new-api/relay/channel/deepseek/adaptor.go
deleted file mode 100644
index c11bc1a6125212e0e2d780f126b33db2fda77932..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/deepseek/adaptor.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package deepseek
-
-import (
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "strings"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := claude.Adaptor{}
- return adaptor.ConvertClaudeRequest(c, info, req)
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- fimBaseUrl := info.ChannelBaseUrl
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
- default:
- if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
- fimBaseUrl += "/beta"
- }
- switch info.RelayMode {
- case constant.RelayModeCompletions:
- return fmt.Sprintf("%s/completions", fimBaseUrl), nil
- default:
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
- }
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- if info.IsStream {
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- } else {
- return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
- }
- default:
- adaptor := openai.Adaptor{}
- return adaptor.DoResponse(c, resp, info)
- }
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/deepseek/constants.go b/new-api/relay/channel/deepseek/constants.go
deleted file mode 100644
index e6b705b010e79de737b87a3644e2afcc65f73216..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/deepseek/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package deepseek
-
-var ModelList = []string{
- "deepseek-chat", "deepseek-reasoner",
-}
-
-var ChannelName = "deepseek"
diff --git a/new-api/relay/channel/dify/adaptor.go b/new-api/relay/channel/dify/adaptor.go
deleted file mode 100644
index 954c24dcfa79af3e1aa5aca19d02afdf6897baad..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/dify/adaptor.go
+++ /dev/null
@@ -1,120 +0,0 @@
-package dify
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- BotTypeChatFlow = 1 // chatflow default
- BotTypeAgent = 2
- BotTypeWorkFlow = 3
- BotTypeCompletion = 4
-)
-
-type Adaptor struct {
- BotType int
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- //if strings.HasPrefix(info.UpstreamModelName, "agent") {
- // a.BotType = BotTypeAgent
- //} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
- // a.BotType = BotTypeWorkFlow
- //} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
- // a.BotType = BotTypeCompletion
- //} else {
- //}
- a.BotType = BotTypeChatFlow
-
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch a.BotType {
- case BotTypeWorkFlow:
- return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
- case BotTypeCompletion:
- return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
- case BotTypeAgent:
- fallthrough
- default:
- return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return requestOpenAI2Dify(c, info, *request), nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- return difyStreamHandler(c, info, resp)
- } else {
- return difyHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/dify/constants.go b/new-api/relay/channel/dify/constants.go
deleted file mode 100644
index 5b3f64cc891b45d0134e5bbd261c5f1d4e27b7af..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/dify/constants.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package dify
-
-var ModelList []string
-
-var ChannelName = "dify"
diff --git a/new-api/relay/channel/dify/dto.go b/new-api/relay/channel/dify/dto.go
deleted file mode 100644
index fb59ff0d0071341cb7da00644dc54860038f3106..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/dify/dto.go
+++ /dev/null
@@ -1,45 +0,0 @@
-package dify
-
-import "one-api/dto"
-
-type DifyChatRequest struct {
- Inputs map[string]interface{} `json:"inputs"`
- Query string `json:"query"`
- ResponseMode string `json:"response_mode"`
- User string `json:"user"`
- AutoGenerateName bool `json:"auto_generate_name"`
- Files []DifyFile `json:"files"`
-}
-
-type DifyFile struct {
- Type string `json:"type"`
- TransferMode string `json:"transfer_mode"`
- URL string `json:"url,omitempty"`
- UploadFileId string `json:"upload_file_id,omitempty"`
-}
-
-type DifyMetaData struct {
- Usage dto.Usage `json:"usage"`
-}
-
-type DifyData struct {
- WorkflowId string `json:"workflow_id"`
- NodeId string `json:"node_id"`
- NodeType string `json:"node_type"`
- Status string `json:"status"`
-}
-
-type DifyChatCompletionResponse struct {
- ConversationId string `json:"conversation_id"`
- Answer string `json:"answer"`
- CreateAt int64 `json:"create_at"`
- MetaData DifyMetaData `json:"metadata"`
-}
-
-type DifyChunkChatCompletionResponse struct {
- Event string `json:"event"`
- ConversationId string `json:"conversation_id"`
- Answer string `json:"answer"`
- Data DifyData `json:"data"`
- MetaData DifyMetaData `json:"metadata"`
-}
diff --git a/new-api/relay/channel/dify/relay-dify.go b/new-api/relay/channel/dify/relay-dify.go
deleted file mode 100644
index 1850aed3fe81c7fbee44d608a11ebb8e1bcd5764..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/dify/relay-dify.go
+++ /dev/null
@@ -1,289 +0,0 @@
-package dify
-
-import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "os"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
- uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
- switch media.Type {
- case dto.ContentTypeImageURL:
- // Decode base64 data
- imageMedia := media.GetImageMedia()
- base64Data := imageMedia.Url
- // Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
- if idx := strings.Index(base64Data, ","); idx != -1 {
- base64Data = base64Data[idx+1:]
- }
-
- // Decode base64 string
- decodedData, err := base64.StdEncoding.DecodeString(base64Data)
- if err != nil {
- common.SysLog("failed to decode base64: " + err.Error())
- return nil
- }
-
- // Create temporary file
- tempFile, err := os.CreateTemp("", "dify-upload-*")
- if err != nil {
- common.SysLog("failed to create temp file: " + err.Error())
- return nil
- }
- defer tempFile.Close()
- defer os.Remove(tempFile.Name())
-
- // Write decoded data to temp file
- if _, err := tempFile.Write(decodedData); err != nil {
- common.SysLog("failed to write to temp file: " + err.Error())
- return nil
- }
-
- // Create multipart form
- body := &bytes.Buffer{}
- writer := multipart.NewWriter(body)
-
- // Add user field
- if err := writer.WriteField("user", user); err != nil {
- common.SysLog("failed to add user field: " + err.Error())
- return nil
- }
-
- // Create form file with proper mime type
- mimeType := imageMedia.MimeType
- if mimeType == "" {
- mimeType = "image/jpeg" // default mime type
- }
-
- // Create form file
- part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
- if err != nil {
- common.SysLog("failed to create form file: " + err.Error())
- return nil
- }
-
- // Copy file content to form
- if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
- common.SysLog("failed to copy file content: " + err.Error())
- return nil
- }
- writer.Close()
-
- // Create HTTP request
- req, err := http.NewRequest("POST", uploadUrl, body)
- if err != nil {
- common.SysLog("failed to create request: " + err.Error())
- return nil
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
-
- // Send request
- client := service.GetHttpClient()
- resp, err := client.Do(req)
- if err != nil {
- common.SysLog("failed to send request: " + err.Error())
- return nil
- }
- defer resp.Body.Close()
-
- // Parse response
- var result struct {
- Id string `json:"id"`
- }
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- common.SysLog("failed to decode response: " + err.Error())
- return nil
- }
-
- return &DifyFile{
- UploadFileId: result.Id,
- Type: "image",
- TransferMode: "local_file",
- }
- }
- return nil
-}
-
-func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
- difyReq := DifyChatRequest{
- Inputs: make(map[string]interface{}),
- AutoGenerateName: false,
- }
-
- user := request.User
- if user == "" {
- user = helper.GetResponseID(c)
- }
- difyReq.User = user
-
- files := make([]DifyFile, 0)
- var content strings.Builder
- for _, message := range request.Messages {
- if message.Role == "system" {
- content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
- } else if message.Role == "assistant" {
- content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
- } else {
- parseContent := message.ParseContent()
- for _, mediaContent := range parseContent {
- switch mediaContent.Type {
- case dto.ContentTypeText:
- content.WriteString("USER: \n" + mediaContent.Text + "\n")
- case dto.ContentTypeImageURL:
- media := mediaContent.GetImageMedia()
- var file *DifyFile
- if media.IsRemoteImage() {
- file.Type = media.MimeType
- file.TransferMode = "remote_url"
- file.URL = media.Url
- } else {
- file = uploadDifyFile(c, info, difyReq.User, mediaContent)
- }
- if file != nil {
- files = append(files, *file)
- }
- }
- }
- }
- }
- difyReq.Query = content.String()
- difyReq.Files = files
- mode := "blocking"
- if request.Stream {
- mode = "streaming"
- }
- difyReq.ResponseMode = mode
- return &difyReq
-}
-
-func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
- response := dto.ChatCompletionsStreamResponse{
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "dify",
- }
- var choice dto.ChatCompletionsStreamResponseChoice
- if strings.HasPrefix(difyResponse.Event, "workflow_") {
- if constant.DifyDebug {
- text := "Workflow: " + difyResponse.Data.WorkflowId
- if difyResponse.Event == "workflow_finished" {
- text += " " + difyResponse.Data.Status
- }
- choice.Delta.SetReasoningContent(text + "\n")
- }
- } else if strings.HasPrefix(difyResponse.Event, "node_") {
- if constant.DifyDebug {
- text := "Node: " + difyResponse.Data.NodeType
- if difyResponse.Event == "node_finished" {
- text += " " + difyResponse.Data.Status
- }
- choice.Delta.SetReasoningContent(text + "\n")
- }
- } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
- if difyResponse.Answer == " Thinking...
\n" {
- difyResponse.Answer = ""
- } else if difyResponse.Answer == " " {
- difyResponse.Answer = ""
- }
-
- choice.Delta.SetContentString(difyResponse.Answer)
- }
- response.Choices = append(response.Choices, choice)
- return &response
-}
-
-func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var responseText string
- usage := &dto.Usage{}
- var nodeToken int
- helper.SetEventStreamHeaders(c)
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var difyResponse DifyChunkChatCompletionResponse
- err := json.Unmarshal([]byte(data), &difyResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
- var openaiResponse dto.ChatCompletionsStreamResponse
- if difyResponse.Event == "message_end" {
- usage = &difyResponse.MetaData.Usage
- return false
- } else if difyResponse.Event == "error" {
- return false
- } else {
- openaiResponse = *streamResponseDify2OpenAI(difyResponse)
- if len(openaiResponse.Choices) != 0 {
- responseText += openaiResponse.Choices[0].Delta.GetContentString()
- if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
- nodeToken += 1
- }
- }
- }
- err = helper.ObjectData(c, openaiResponse)
- if err != nil {
- common.SysLog(err.Error())
- }
- return true
- })
- helper.Done(c)
- if usage.TotalTokens == 0 {
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- }
- usage.CompletionTokens += nodeToken
- return usage, nil
-}
-
-func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var difyResponse DifyChatCompletionResponse
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &difyResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- fullTextResponse := dto.OpenAITextResponse{
- Id: difyResponse.ConversationId,
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Usage: difyResponse.MetaData.Usage,
- }
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: difyResponse.Answer,
- },
- FinishReason: "stop",
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- c.Writer.Write(jsonResponse)
- return &difyResponse.MetaData.Usage, nil
-}
diff --git a/new-api/relay/channel/gemini/adaptor.go b/new-api/relay/channel/gemini/adaptor.go
deleted file mode 100644
index 0f40bf74e80b960a03fbc865f3cb8567ffd48493..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/gemini/adaptor.go
+++ /dev/null
@@ -1,254 +0,0 @@
-package gemini
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
- if len(request.Contents) > 0 {
- for i, content := range request.Contents {
- if i == 0 {
- if request.Contents[0].Role == "" {
- request.Contents[0].Role = "user"
- }
- }
- for _, part := range content.Parts {
- if part.FileData != nil {
- if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
- part.FileData.MimeType = "video/webm"
- }
- }
- }
- }
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
- oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
- if err != nil {
- return nil, err
- }
- return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return nil, errors.New("not supported model for image generation")
- }
-
- // convert size to aspect ratio but allow user to specify aspect ratio
- aspectRatio := "1:1" // default aspect ratio
- size := strings.TrimSpace(request.Size)
- if size != "" {
- if strings.Contains(size, ":") {
- aspectRatio = size
- } else {
- switch size {
- case "1024x1024":
- aspectRatio = "1:1"
- case "1024x1792":
- aspectRatio = "9:16"
- case "1792x1024":
- aspectRatio = "16:9"
- }
- }
- }
-
- // build gemini imagen request
- geminiRequest := dto.GeminiImageRequest{
- Instances: []dto.GeminiImageInstance{
- {
- Prompt: request.Prompt,
- },
- },
- Parameters: dto.GeminiImageParameters{
- SampleCount: int(request.N),
- AspectRatio: aspectRatio,
- PersonGeneration: "allow_adult", // default allow adult
- },
- }
-
- return geminiRequest, nil
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-
- if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // 新增逻辑:处理 -thinking- 格式
- if strings.Contains(info.UpstreamModelName, "-thinking-") {
- parts := strings.Split(info.UpstreamModelName, "-thinking-")
- info.UpstreamModelName = parts[0]
- } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
- }
- }
-
- version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
-
- if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
- }
-
- if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
- strings.HasPrefix(info.UpstreamModelName, "embedding") ||
- strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
- action := "embedContent"
- if info.IsGeminiBatchEmbedding {
- action = "batchEmbedContents"
- }
- return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
- }
-
- action := "generateContent"
- if info.IsStream {
- action = "streamGenerateContent?alt=sse"
- if info.RelayMode == constant.RelayModeGemini {
- info.DisablePing = true
- }
- }
- return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("x-goog-api-key", info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
-
- geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
- if err != nil {
- return nil, err
- }
-
- return geminiRequest, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- if request.Input == nil {
- return nil, errors.New("input is required")
- }
-
- inputs := request.ParseInput()
- if len(inputs) == 0 {
- return nil, errors.New("input is empty")
- }
- // We always build a batch-style payload with `requests`, so ensure we call the
- // batch endpoint upstream to avoid payload/endpoint mismatches.
- info.IsGeminiBatchEmbedding = true
- // process all inputs
- geminiRequests := make([]map[string]interface{}, 0, len(inputs))
- for _, input := range inputs {
- geminiRequest := map[string]interface{}{
- "model": fmt.Sprintf("models/%s", info.UpstreamModelName),
- "content": dto.GeminiChatContent{
- Parts: []dto.GeminiPart{
- {
- Text: input,
- },
- },
- },
- }
-
- // set specific parameters for different models
- // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
- switch info.UpstreamModelName {
- case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
- // Only newer models introduced after 2024 support OutputDimensionality
- if request.Dimensions > 0 {
- geminiRequest["outputDimensionality"] = request.Dimensions
- }
- }
- geminiRequests = append(geminiRequests, geminiRequest)
- }
-
- return map[string]interface{}{
- "requests": geminiRequests,
- }, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.RelayMode == constant.RelayModeGemini {
- if strings.Contains(info.RequestURLPath, ":embedContent") ||
- strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
- return NativeGeminiEmbeddingHandler(c, resp, info)
- }
- if info.IsStream {
- return GeminiTextGenerationStreamHandler(c, info, resp)
- } else {
- return GeminiTextGenerationHandler(c, info, resp)
- }
- }
-
- if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return GeminiImageHandler(c, info, resp)
- }
-
- // check if the model is an embedding model
- if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
- strings.HasPrefix(info.UpstreamModelName, "embedding") ||
- strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
- return GeminiEmbeddingHandler(c, info, resp)
- }
-
- if info.IsStream {
- return GeminiChatStreamHandler(c, info, resp)
- } else {
- return GeminiChatHandler(c, info, resp)
- }
-
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/gemini/constant.go b/new-api/relay/channel/gemini/constant.go
deleted file mode 100644
index d509344d0ad6e272440f02858da4c2cc44cfcb65..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/gemini/constant.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package gemini
-
-var ModelList = []string{
- // stable version
- "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
- "gemini-2.0-flash",
- // latest version
- "gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
- // preview version
- "gemini-2.0-flash-lite-preview",
- // gemini exp
- "gemini-exp-1206",
- // flash exp
- "gemini-2.0-flash-exp",
- // pro exp
- "gemini-2.0-pro-exp",
- // thinking exp
- "gemini-2.0-flash-thinking-exp",
- "gemini-2.5-pro-exp-03-25",
- "gemini-2.5-pro-preview-03-25",
- // imagen models
- "imagen-3.0-generate-002",
- // embedding models
- "gemini-embedding-exp-03-07",
- "text-embedding-004",
- "embedding-001",
-}
-
-var SafetySettingList = []string{
- "HARM_CATEGORY_HARASSMENT",
- "HARM_CATEGORY_HATE_SPEECH",
- "HARM_CATEGORY_SEXUALLY_EXPLICIT",
- "HARM_CATEGORY_DANGEROUS_CONTENT",
- "HARM_CATEGORY_CIVIC_INTEGRITY",
-}
-
-var ChannelName = "google gemini"
diff --git a/new-api/relay/channel/gemini/relay-gemini-native.go b/new-api/relay/channel/gemini/relay-gemini-native.go
deleted file mode 100644
index 296a1e8c7b713c9c1ce6cb008d55f1edccce9ce5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/gemini/relay-gemini-native.go
+++ /dev/null
@@ -1,175 +0,0 @@
-package gemini
-
-import (
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/pkg/errors"
-
- "github.com/gin-gonic/gin"
-)
-
-func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- // 读取响应体
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- if common.DebugEnabled {
- println(string(responseBody))
- }
-
- // 解析为 Gemini 原生响应格式
- var geminiResponse dto.GeminiChatResponse
- err = common.Unmarshal(responseBody, &geminiResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- // 计算使用量(基于 UsageMetadata)
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
- }
-
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
-
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- return &usage, nil
-}
-
-func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- if common.DebugEnabled {
- println(string(responseBody))
- }
-
- usage := &dto.Usage{
- PromptTokens: info.PromptTokens,
- TotalTokens: info.PromptTokens,
- }
-
- if info.IsGeminiBatchEmbedding {
- var geminiResponse dto.GeminiBatchEmbeddingResponse
- err = common.Unmarshal(responseBody, &geminiResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- } else {
- var geminiResponse dto.GeminiEmbeddingResponse
- err = common.Unmarshal(responseBody, &geminiResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- }
-
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- return usage, nil
-}
-
-func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var usage = &dto.Usage{}
- var imageCount int
-
- helper.SetEventStreamHeaders(c)
-
- responseText := strings.Builder{}
-
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var geminiResponse dto.GeminiChatResponse
- err := common.UnmarshalJsonStr(data, &geminiResponse)
- if err != nil {
- logger.LogError(c, "error unmarshalling stream response: "+err.Error())
- return false
- }
-
- // 统计图片数量
- for _, candidate := range geminiResponse.Candidates {
- for _, part := range candidate.Content.Parts {
- if part.InlineData != nil && part.InlineData.MimeType != "" {
- imageCount++
- }
- if part.Text != "" {
- responseText.WriteString(part.Text)
- }
- }
- }
-
- // 更新使用量统计
- if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
- usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
- usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
- }
-
- // 直接发送 GeminiChatResponse 响应
- err = helper.StringData(c, data)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- info.SendResponseCount++
- return true
- })
-
- if info.SendResponseCount == 0 {
- return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
- }
-
- if imageCount != 0 {
- if usage.CompletionTokens == 0 {
- usage.CompletionTokens = imageCount * 258
- }
- }
-
- // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
- if usage.CompletionTokens == 0 {
- str := responseText.String()
- if len(str) > 0 {
- usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- // 空补全,不需要使用量
- usage = &dto.Usage{}
- }
- }
-
- // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
- //helper.Done(c)
-
- return usage, nil
-}
diff --git a/new-api/relay/channel/gemini/relay-gemini.go b/new-api/relay/channel/gemini/relay-gemini.go
deleted file mode 100644
index 5a1214781512b3e04feb3f7f41ed9980fd527e34..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/gemini/relay-gemini.go
+++ /dev/null
@@ -1,1193 +0,0 @@
-package gemini
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strconv"
- "strings"
- "unicode/utf8"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
-var geminiSupportedMimeTypes = map[string]bool{
- "application/pdf": true,
- "audio/mpeg": true,
- "audio/mp3": true,
- "audio/wav": true,
- "image/png": true,
- "image/jpeg": true,
- "image/webp": true,
- "text/plain": true,
- "video/mov": true,
- "video/mpeg": true,
- "video/mp4": true,
- "video/mpg": true,
- "video/avi": true,
- "video/wmv": true,
- "video/mpegps": true,
- "video/flv": true,
-}
-
-// Gemini 允许的思考预算范围
-const (
- pro25MinBudget = 128
- pro25MaxBudget = 32768
- flash25MaxBudget = 24576
- flash25LiteMinBudget = 512
- flash25LiteMaxBudget = 24576
-)
-
-func isNew25ProModel(modelName string) bool {
- return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
- !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
- !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
-}
-
-func is25FlashLiteModel(modelName string) bool {
- return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
-}
-
-// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
-func clampThinkingBudget(modelName string, budget int) int {
- isNew25Pro := isNew25ProModel(modelName)
- is25FlashLite := is25FlashLiteModel(modelName)
-
- if is25FlashLite {
- if budget < flash25LiteMinBudget {
- return flash25LiteMinBudget
- }
- if budget > flash25LiteMaxBudget {
- return flash25LiteMaxBudget
- }
- } else if isNew25Pro {
- if budget < pro25MinBudget {
- return pro25MinBudget
- }
- if budget > pro25MaxBudget {
- return pro25MaxBudget
- }
- } else { // 其他模型
- if budget < 0 {
- return 0
- }
- if budget > flash25MaxBudget {
- return flash25MaxBudget
- }
- }
- return budget
-}
-
-// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
-// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
-// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
-func clampThinkingBudgetByEffort(modelName string, effort string) int {
- isNew25Pro := isNew25ProModel(modelName)
- is25FlashLite := is25FlashLiteModel(modelName)
-
- maxBudget := 0
- if is25FlashLite {
- maxBudget = flash25LiteMaxBudget
- }
- if isNew25Pro {
- maxBudget = pro25MaxBudget
- } else {
- maxBudget = flash25MaxBudget
- }
- switch effort {
- case "high":
- maxBudget = maxBudget * 80 / 100
- case "medium":
- maxBudget = maxBudget * 50 / 100
- case "low":
- maxBudget = maxBudget * 20 / 100
- }
- return clampThinkingBudget(modelName, maxBudget)
-}
-
-func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
- if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- modelName := info.UpstreamModelName
- isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
- !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
- !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
-
- if strings.Contains(modelName, "-thinking-") {
- parts := strings.SplitN(modelName, "-thinking-", 2)
- if len(parts) == 2 && parts[1] != "" {
- if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
- clampedBudget := clampThinkingBudget(modelName, budgetTokens)
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(clampedBudget),
- IncludeThoughts: true,
- }
- }
- }
- } else if strings.HasSuffix(modelName, "-thinking") {
- unsupportedModels := []string{
- "gemini-2.5-pro-preview-05-06",
- "gemini-2.5-pro-preview-03-25",
- }
- isUnsupported := false
- for _, unsupportedModel := range unsupportedModels {
- if strings.HasPrefix(modelName, unsupportedModel) {
- isUnsupported = true
- break
- }
- }
-
- if isUnsupported {
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- IncludeThoughts: true,
- }
- } else {
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- IncludeThoughts: true,
- }
- if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
- budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
- clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
- geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
- } else {
- if len(oaiRequest) > 0 {
- // 如果有reasoningEffort参数,则根据其值设置思考预算
- geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
- }
- }
- }
- } else if strings.HasSuffix(modelName, "-nothinking") {
- if !isNew25Pro {
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(0),
- }
- }
- }
- }
-}
-
-// Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
-
- geminiRequest := dto.GeminiChatRequest{
- Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
- GenerationConfig: dto.GeminiChatGenerationConfig{
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- MaxOutputTokens: textRequest.GetMaxTokens(),
- Seed: int64(textRequest.Seed),
- },
- }
-
- if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
- geminiRequest.GenerationConfig.ResponseModalities = []string{
- "TEXT",
- "IMAGE",
- }
- }
-
- adaptorWithExtraBody := false
-
- if len(textRequest.ExtraBody) > 0 {
- if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
- var extraBody map[string]interface{}
- if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
- return nil, fmt.Errorf("invalid extra body: %w", err)
- }
- // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
- if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
- adaptorWithExtraBody = true
- if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
- if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
- budgetInt := int(budget)
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(budgetInt),
- IncludeThoughts: true,
- }
- } else {
- geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
- IncludeThoughts: true,
- }
- }
- }
- }
- }
- }
-
- if !adaptorWithExtraBody {
- ThinkingAdaptor(&geminiRequest, info, textRequest)
- }
-
- safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
- for _, category := range SafetySettingList {
- safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
- Category: category,
- Threshold: model_setting.GetGeminiSafetySetting(category),
- })
- }
- geminiRequest.SafetySettings = safetySettings
-
- // openaiContent.FuncToToolCalls()
- if textRequest.Tools != nil {
- functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
- googleSearch := false
- codeExecution := false
- urlContext := false
- for _, tool := range textRequest.Tools {
- if tool.Function.Name == "googleSearch" {
- googleSearch = true
- continue
- }
- if tool.Function.Name == "codeExecution" {
- codeExecution = true
- continue
- }
- if tool.Function.Name == "urlContext" {
- urlContext = true
- continue
- }
- if tool.Function.Parameters != nil {
-
- params, ok := tool.Function.Parameters.(map[string]interface{})
- if ok {
- if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
- if len(props) == 0 {
- tool.Function.Parameters = nil
- }
- }
- }
- }
- // Clean the parameters before appending
- cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
- tool.Function.Parameters = cleanedParams
- functions = append(functions, tool.Function)
- }
- geminiTools := geminiRequest.GetTools()
- if codeExecution {
- geminiTools = append(geminiTools, dto.GeminiChatTool{
- CodeExecution: make(map[string]string),
- })
- }
- if googleSearch {
- geminiTools = append(geminiTools, dto.GeminiChatTool{
- GoogleSearch: make(map[string]string),
- })
- }
- if urlContext {
- geminiTools = append(geminiTools, dto.GeminiChatTool{
- URLContext: make(map[string]string),
- })
- }
- if len(functions) > 0 {
- geminiTools = append(geminiTools, dto.GeminiChatTool{
- FunctionDeclarations: functions,
- })
- }
- geminiRequest.SetTools(geminiTools)
- }
-
- if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
- geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
-
- if len(textRequest.ResponseFormat.JsonSchema) > 0 {
- // 先将json.RawMessage解析
- var jsonSchema dto.FormatJsonSchema
- if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
- cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
- geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
- }
- }
- }
- tool_call_ids := make(map[string]string)
- var system_content []string
- //shouldAddDummyModelMessage := false
- for _, message := range textRequest.Messages {
- if message.Role == "system" {
- system_content = append(system_content, message.StringContent())
- continue
- } else if message.Role == "tool" || message.Role == "function" {
- if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
- geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
- Role: "user",
- })
- }
- var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
- name := ""
- if message.Name != nil {
- name = *message.Name
- } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
- name = val
- }
- var contentMap map[string]interface{}
- contentStr := message.StringContent()
-
- // 1. 尝试解析为 JSON 对象
- if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
- // 2. 如果失败,尝试解析为 JSON 数组
- var contentSlice []interface{}
- if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
- // 如果是数组,包装成对象
- contentMap = map[string]interface{}{"result": contentSlice}
- } else {
- // 3. 如果再次失败,作为纯文本处理
- contentMap = map[string]interface{}{"content": contentStr}
- }
- }
-
- functionResp := &dto.GeminiFunctionResponse{
- Name: name,
- Response: contentMap,
- }
-
- *parts = append(*parts, dto.GeminiPart{
- FunctionResponse: functionResp,
- })
- continue
- }
- var parts []dto.GeminiPart
- content := dto.GeminiChatContent{
- Role: message.Role,
- }
- // isToolCall := false
- if message.ToolCalls != nil {
- // message.Role = "model"
- // isToolCall = true
- for _, call := range message.ParseToolCalls() {
- args := map[string]interface{}{}
- if call.Function.Arguments != "" {
- if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
- return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
- }
- }
- toolCall := dto.GeminiPart{
- FunctionCall: &dto.FunctionCall{
- FunctionName: call.Function.Name,
- Arguments: args,
- },
- }
- parts = append(parts, toolCall)
- tool_call_ids[call.ID] = call.Function.Name
- }
- }
-
- openaiContent := message.ParseContent()
- imageNum := 0
- for _, part := range openaiContent {
- if part.Type == dto.ContentTypeText {
- if part.Text == "" {
- continue
- }
- parts = append(parts, dto.GeminiPart{
- Text: part.Text,
- })
- } else if part.Type == dto.ContentTypeImageURL {
- imageNum += 1
-
- if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
- return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
- }
- // 判断是否是url
- if strings.HasPrefix(part.GetImageMedia().Url, "http") {
- // 是url,获取文件的类型和base64编码的数据
- fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
- if err != nil {
- return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
- }
-
- // 校验 MimeType 是否在 Gemini 支持的白名单中
- if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
- url := part.GetImageMedia().Url
- return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
- }
-
- parts = append(parts, dto.GeminiPart{
- InlineData: &dto.GeminiInlineData{
- MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
- Data: fileData.Base64Data,
- },
- })
- } else {
- format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
- if err != nil {
- return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
- }
- parts = append(parts, dto.GeminiPart{
- InlineData: &dto.GeminiInlineData{
- MimeType: format,
- Data: base64String,
- },
- })
- }
- } else if part.Type == dto.ContentTypeFile {
- if part.GetFile().FileId != "" {
- return nil, fmt.Errorf("only base64 file is supported in gemini")
- }
- format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
- if err != nil {
- return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
- }
- parts = append(parts, dto.GeminiPart{
- InlineData: &dto.GeminiInlineData{
- MimeType: format,
- Data: base64String,
- },
- })
- } else if part.Type == dto.ContentTypeInputAudio {
- if part.GetInputAudio().Data == "" {
- return nil, fmt.Errorf("only base64 audio is supported in gemini")
- }
- base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
- if err != nil {
- return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
- }
- parts = append(parts, dto.GeminiPart{
- InlineData: &dto.GeminiInlineData{
- MimeType: "audio/" + part.GetInputAudio().Format,
- Data: base64String,
- },
- })
- }
- }
-
- content.Parts = parts
-
- // there's no assistant role in gemini and API shall vomit if Role is not user or model
- if content.Role == "assistant" {
- content.Role = "model"
- }
- if len(content.Parts) > 0 {
- geminiRequest.Contents = append(geminiRequest.Contents, content)
- }
- }
-
- if len(system_content) > 0 {
- geminiRequest.SystemInstructions = &dto.GeminiChatContent{
- Parts: []dto.GeminiPart{
- {
- Text: strings.Join(system_content, "\n"),
- },
- },
- }
- }
-
- return &geminiRequest, nil
-}
-
-// Helper function to get a list of supported MIME types for error messages
-func getSupportedMimeTypesList() []string {
- keys := make([]string, 0, len(geminiSupportedMimeTypes))
- for k := range geminiSupportedMimeTypes {
- keys = append(keys, k)
- }
- return keys
-}
-
-// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
-func cleanFunctionParameters(params interface{}) interface{} {
- if params == nil {
- return nil
- }
-
- switch v := params.(type) {
- case map[string]interface{}:
- // Create a copy to avoid modifying the original
- cleanedMap := make(map[string]interface{})
- for k, val := range v {
- cleanedMap[k] = val
- }
-
- // Remove unsupported root-level fields
- delete(cleanedMap, "default")
- delete(cleanedMap, "exclusiveMaximum")
- delete(cleanedMap, "exclusiveMinimum")
- delete(cleanedMap, "$schema")
- delete(cleanedMap, "additionalProperties")
-
- // Check and clean 'format' for string types
- if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
- if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
- if formatValue != "enum" && formatValue != "date-time" {
- delete(cleanedMap, "format")
- }
- }
- }
-
- // Clean properties
- if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
- cleanedProps := make(map[string]interface{})
- for propName, propValue := range props {
- cleanedProps[propName] = cleanFunctionParameters(propValue)
- }
- cleanedMap["properties"] = cleanedProps
- }
-
- // Recursively clean items in arrays
- if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
- cleanedMap["items"] = cleanFunctionParameters(items)
- }
- // Also handle items if it's an array of schemas
- if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
- cleanedItemsArray := make([]interface{}, len(itemsArray))
- for i, item := range itemsArray {
- cleanedItemsArray[i] = cleanFunctionParameters(item)
- }
- cleanedMap["items"] = cleanedItemsArray
- }
-
- // Recursively clean other schema composition keywords
- for _, field := range []string{"allOf", "anyOf", "oneOf"} {
- if nested, ok := cleanedMap[field].([]interface{}); ok {
- cleanedNested := make([]interface{}, len(nested))
- for i, item := range nested {
- cleanedNested[i] = cleanFunctionParameters(item)
- }
- cleanedMap[field] = cleanedNested
- }
- }
-
- // Recursively clean patternProperties
- if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
- cleanedPatternProps := make(map[string]interface{})
- for pattern, schema := range patternProps {
- cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
- }
- cleanedMap["patternProperties"] = cleanedPatternProps
- }
-
- // Recursively clean definitions
- if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
- cleanedDefinitions := make(map[string]interface{})
- for defName, defSchema := range definitions {
- cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
- }
- cleanedMap["definitions"] = cleanedDefinitions
- }
-
- // Recursively clean $defs (newer JSON Schema draft)
- if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
- cleanedDefs := make(map[string]interface{})
- for defName, defSchema := range defs {
- cleanedDefs[defName] = cleanFunctionParameters(defSchema)
- }
- cleanedMap["$defs"] = cleanedDefs
- }
-
- // Clean conditional keywords
- for _, field := range []string{"if", "then", "else", "not"} {
- if nested, ok := cleanedMap[field]; ok {
- cleanedMap[field] = cleanFunctionParameters(nested)
- }
- }
-
- return cleanedMap
-
- case []interface{}:
- // Handle arrays of schemas
- cleanedArray := make([]interface{}, len(v))
- for i, item := range v {
- cleanedArray[i] = cleanFunctionParameters(item)
- }
- return cleanedArray
-
- default:
- // Not a map or array, return as is (e.g., could be a primitive)
- return params
- }
-}
-
-func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
- if depth >= 5 {
- return schema
- }
-
- v, ok := schema.(map[string]interface{})
- if !ok || len(v) == 0 {
- return schema
- }
- // 删除所有的title字段
- delete(v, "title")
- delete(v, "$schema")
- // 如果type不为object和array,则直接返回
- if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
- return schema
- }
- switch v["type"] {
- case "object":
- delete(v, "additionalProperties")
- // 处理 properties
- if properties, ok := v["properties"].(map[string]interface{}); ok {
- for key, value := range properties {
- properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
- }
- }
- for _, field := range []string{"allOf", "anyOf", "oneOf"} {
- if nested, ok := v[field].([]interface{}); ok {
- for i, item := range nested {
- nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
- }
- }
- }
- case "array":
- if items, ok := v["items"].(map[string]interface{}); ok {
- v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
- }
- }
-
- return v
-}
-
-func unescapeString(s string) (string, error) {
- var result []rune
- escaped := false
- i := 0
-
- for i < len(s) {
- r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
- if r == utf8.RuneError {
- return "", fmt.Errorf("invalid UTF-8 encoding")
- }
-
- if escaped {
- // 如果是转义符后的字符,检查其类型
- switch r {
- case '"':
- result = append(result, '"')
- case '\\':
- result = append(result, '\\')
- case '/':
- result = append(result, '/')
- case 'b':
- result = append(result, '\b')
- case 'f':
- result = append(result, '\f')
- case 'n':
- result = append(result, '\n')
- case 'r':
- result = append(result, '\r')
- case 't':
- result = append(result, '\t')
- case '\'':
- result = append(result, '\'')
- default:
- // 如果遇到一个非法的转义字符,直接按原样输出
- result = append(result, '\\', r)
- }
- escaped = false
- } else {
- if r == '\\' {
- escaped = true // 记录反斜杠作为转义符
- } else {
- result = append(result, r)
- }
- }
- i += size // 移动到下一个字符
- }
-
- return string(result), nil
-}
-func unescapeMapOrSlice(data interface{}) interface{} {
- switch v := data.(type) {
- case map[string]interface{}:
- for k, val := range v {
- v[k] = unescapeMapOrSlice(val)
- }
- case []interface{}:
- for i, val := range v {
- v[i] = unescapeMapOrSlice(val)
- }
- case string:
- if unescaped, err := unescapeString(v); err != nil {
- return v
- } else {
- return unescaped
- }
- }
- return data
-}
-
-func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
- var argsBytes []byte
- var err error
- if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
- argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
- } else {
- argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
- }
-
- if err != nil {
- return nil
- }
- return &dto.ToolCallResponse{
- ID: fmt.Sprintf("call_%s", common.GetUUID()),
- Type: "function",
- Function: dto.FunctionResponse{
- Arguments: string(argsBytes),
- Name: item.FunctionCall.FunctionName,
- },
- }
-}
-
-func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
- fullTextResponse := dto.OpenAITextResponse{
- Id: helper.GetResponseID(c),
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
- }
- isToolCall := false
- for _, candidate := range response.Candidates {
- choice := dto.OpenAITextResponseChoice{
- Index: int(candidate.Index),
- Message: dto.Message{
- Role: "assistant",
- Content: "",
- },
- FinishReason: constant.FinishReasonStop,
- }
- if len(candidate.Content.Parts) > 0 {
- var texts []string
- var toolCalls []dto.ToolCallResponse
- for _, part := range candidate.Content.Parts {
- if part.InlineData != nil {
- // 媒体内容
- if strings.HasPrefix(part.InlineData.MimeType, "image") {
- imgText := ""
- texts = append(texts, imgText)
- } else {
- // 其他媒体类型,直接显示链接
- texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
- }
- } else if part.FunctionCall != nil {
- choice.FinishReason = constant.FinishReasonToolCalls
- if call := getResponseToolCall(&part); call != nil {
- toolCalls = append(toolCalls, *call)
- }
- } else if part.Thought {
- choice.Message.ReasoningContent = part.Text
- } else {
- if part.ExecutableCode != nil {
- texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
- } else if part.CodeExecutionResult != nil {
- texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
- } else {
- // 过滤掉空行
- if part.Text != "\n" {
- texts = append(texts, part.Text)
- }
- }
- }
- }
- if len(toolCalls) > 0 {
- choice.Message.SetToolCalls(toolCalls)
- isToolCall = true
- }
- choice.Message.SetStringContent(strings.Join(texts, "\n"))
-
- }
- if candidate.FinishReason != nil {
- switch *candidate.FinishReason {
- case "STOP":
- choice.FinishReason = constant.FinishReasonStop
- case "MAX_TOKENS":
- choice.FinishReason = constant.FinishReasonLength
- default:
- choice.FinishReason = constant.FinishReasonContentFilter
- }
- }
- if isToolCall {
- choice.FinishReason = constant.FinishReasonToolCalls
- }
-
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- }
- return &fullTextResponse
-}
-
-func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
- choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
- isStop := false
- for _, candidate := range geminiResponse.Candidates {
- if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
- isStop = true
- candidate.FinishReason = nil
- }
- choice := dto.ChatCompletionsStreamResponseChoice{
- Index: int(candidate.Index),
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
- //Role: "assistant",
- },
- }
- var texts []string
- isTools := false
- isThought := false
- if candidate.FinishReason != nil {
- // p := GeminiConvertFinishReason(*candidate.FinishReason)
- switch *candidate.FinishReason {
- case "STOP":
- choice.FinishReason = &constant.FinishReasonStop
- case "MAX_TOKENS":
- choice.FinishReason = &constant.FinishReasonLength
- default:
- choice.FinishReason = &constant.FinishReasonContentFilter
- }
- }
- for _, part := range candidate.Content.Parts {
- if part.InlineData != nil {
- if strings.HasPrefix(part.InlineData.MimeType, "image") {
- imgText := ""
- texts = append(texts, imgText)
- }
- } else if part.FunctionCall != nil {
- isTools = true
- if call := getResponseToolCall(&part); call != nil {
- call.SetIndex(len(choice.Delta.ToolCalls))
- choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
- }
-
- } else if part.Thought {
- isThought = true
- texts = append(texts, part.Text)
- } else {
- if part.ExecutableCode != nil {
- texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
- } else if part.CodeExecutionResult != nil {
- texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
- } else {
- if part.Text != "\n" {
- texts = append(texts, part.Text)
- }
- }
- }
- }
- if isThought {
- choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
- } else {
- choice.Delta.SetContentString(strings.Join(texts, "\n"))
- }
- if isTools {
- choice.FinishReason = &constant.FinishReasonToolCalls
- }
- choices = append(choices, choice)
- }
-
- var response dto.ChatCompletionsStreamResponse
- response.Object = "chat.completion.chunk"
- response.Choices = choices
- return &response, isStop
-}
-
-func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
- streamData, err := common.Marshal(resp)
- if err != nil {
- return fmt.Errorf("failed to marshal stream response: %w", err)
- }
- err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
- if err != nil {
- return fmt.Errorf("failed to handle stream format: %w", err)
- }
- return nil
-}
-
-func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
- streamData, err := common.Marshal(resp)
- if err != nil {
- return fmt.Errorf("failed to marshal stream response: %w", err)
- }
- openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
- return nil
-}
-
-func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- // responseText := ""
- id := helper.GetResponseID(c)
- createAt := common.GetTimestamp()
- responseText := strings.Builder{}
- var usage = &dto.Usage{}
- var imageCount int
- finishReason := constant.FinishReasonStop
-
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var geminiResponse dto.GeminiChatResponse
- err := common.UnmarshalJsonStr(data, &geminiResponse)
- if err != nil {
- logger.LogError(c, "error unmarshalling stream response: "+err.Error())
- return false
- }
-
- for _, candidate := range geminiResponse.Candidates {
- for _, part := range candidate.Content.Parts {
- if part.InlineData != nil && part.InlineData.MimeType != "" {
- imageCount++
- }
- if part.Text != "" {
- responseText.WriteString(part.Text)
- }
- }
- }
-
- response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
-
- response.Id = id
- response.Created = createAt
- response.Model = info.UpstreamModelName
- if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
- usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
- usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
- }
- logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
- if info.SendResponseCount == 0 {
- // send first response
- emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
- if response.IsToolCall() {
- emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
- emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
- emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
- finishReason = constant.FinishReasonToolCalls
- err = handleStream(c, info, emptyResponse)
- if err != nil {
- logger.LogError(c, err.Error())
- }
-
- response.ClearToolCalls()
- if response.IsFinished() {
- response.Choices[0].FinishReason = nil
- }
- } else {
- err = handleStream(c, info, emptyResponse)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- }
- }
-
- err = handleStream(c, info, response)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- if isStop {
- _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
- }
- return true
- })
-
- if info.SendResponseCount == 0 {
- // 空补全,报错不计费
- // empty response, throw an error
- return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
- }
-
- if imageCount != 0 {
- if usage.CompletionTokens == 0 {
- usage.CompletionTokens = imageCount * 258
- }
- }
-
- usage.PromptTokensDetails.TextTokens = usage.PromptTokens
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
- if usage.CompletionTokens == 0 {
- str := responseText.String()
- if len(str) > 0 {
- usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- // 空补全,不需要使用量
- usage = &dto.Usage{}
- }
- }
-
- response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
- err := handleFinalStream(c, info, response)
- if err != nil {
- common.SysLog("send final response failed: " + err.Error())
- }
- //if info.RelayFormat == relaycommon.RelayFormatOpenAI {
- // helper.Done(c)
- //}
- //resp.Body.Close()
- return usage, nil
-}
-
-func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- if common.DebugEnabled {
- println(string(responseBody))
- }
- var geminiResponse dto.GeminiChatResponse
- err = common.Unmarshal(responseBody, &geminiResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if len(geminiResponse.Candidates) == 0 {
- return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
- fullTextResponse.Model = info.UpstreamModelName
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
- }
-
- usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
- for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
- if detail.Modality == "AUDIO" {
- usage.PromptTokensDetails.AudioTokens = detail.TokenCount
- } else if detail.Modality == "TEXT" {
- usage.PromptTokensDetails.TextTokens = detail.TokenCount
- }
- }
-
- fullTextResponse.Usage = usage
-
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- responseBody, err = common.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- case types.RelayFormatClaude:
- claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
- claudeRespStr, err := common.Marshal(claudeResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- responseBody = claudeRespStr
- case types.RelayFormatGemini:
- break
- }
-
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- return &usage, nil
-}
-
-func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- responseBody, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- var geminiResponse dto.GeminiBatchEmbeddingResponse
- if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
- return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- // convert to openai format response
- openAIResponse := dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
- Model: info.UpstreamModelName,
- }
-
- for i, embedding := range geminiResponse.Embeddings {
- openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
- Object: "embedding",
- Embedding: embedding.Values,
- Index: i,
- })
- }
-
- // calculate usage
- // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
- // Google has not yet clarified how embedding models will be billed
- // refer to openai billing method to use input tokens billing
- // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
- usage := &dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: info.PromptTokens,
- }
- openAIResponse.Usage = *usage
-
- jsonResponse, jsonErr := common.Marshal(openAIResponse)
- if jsonErr != nil {
- return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return usage, nil
-}
-
-func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- _ = resp.Body.Close()
-
- var geminiResponse dto.GeminiImageResponse
- if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
- return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- if len(geminiResponse.Predictions) == 0 {
- return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- // convert to openai format response
- openAIResponse := dto.ImageResponse{
- Created: common.GetTimestamp(),
- Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
- }
-
- for _, prediction := range geminiResponse.Predictions {
- if prediction.RaiFilteredReason != "" {
- continue // skip filtered image
- }
- openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
- B64Json: prediction.BytesBase64Encoded,
- })
- }
-
- jsonResponse, jsonErr := json.Marshal(openAIResponse)
- if jsonErr != nil {
- return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
-
- // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
- // each image has fixed 258 tokens
- const imageTokens = 258
- generatedImages := len(openAIResponse.Data)
-
- usage := &dto.Usage{
- PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
- CompletionTokens: 0, // image generation does not calculate completion tokens
- TotalTokens: imageTokens * generatedImages,
- }
-
- return usage, nil
-}
diff --git a/new-api/relay/channel/jimeng/adaptor.go b/new-api/relay/channel/jimeng/adaptor.go
deleted file mode 100644
index 4b48f4d98c6af4e79fc273c5756e6270366730e7..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jimeng/adaptor.go
+++ /dev/null
@@ -1,142 +0,0 @@
-package jimeng
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
- return errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return request, nil
-}
-
-type LogoInfo struct {
- AddLogo bool `json:"add_logo,omitempty"`
- Position int `json:"position,omitempty"`
- Language int `json:"language,omitempty"`
- Opacity float64 `json:"opacity,omitempty"`
- LogoTextContent string `json:"logo_text_content,omitempty"`
-}
-
-type imageRequestPayload struct {
- ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
- Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
- Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
- Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
- Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
- UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
- UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
- ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
- LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
- ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
- BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- payload := imageRequestPayload{
- ReqKey: request.Model,
- Prompt: request.Prompt,
- }
- if request.ResponseFormat == "" || request.ResponseFormat == "url" {
- payload.ReturnURL = true // Default to returning image URLs
- }
-
- if len(request.ExtraFields) > 0 {
- if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
- return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
- }
- }
-
- return payload, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- fullRequestURL, err := a.GetRequestURL(info)
- if err != nil {
- return nil, fmt.Errorf("get request url failed: %w", err)
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return nil, fmt.Errorf("new request failed: %w", err)
- }
- err = Sign(c, req, info.ApiKey)
- if err != nil {
- return nil, fmt.Errorf("setup request header failed: %w", err)
- }
- resp, err := channel.DoRequest(c, req, info)
- if err != nil {
- return nil, fmt.Errorf("do request failed: %w", err)
- }
- return resp, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.RelayMode == relayconstant.RelayModeImagesGenerations {
- usage, err = jimengImageHandler(c, resp, info)
- } else if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/jimeng/constants.go b/new-api/relay/channel/jimeng/constants.go
deleted file mode 100644
index 74fad49cc515f913a085eadde7c30b9f196cfcd8..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jimeng/constants.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package jimeng
-
-const (
- ChannelName = "jimeng"
-)
-
-var ModelList = []string{
- "jimeng_high_aes_general_v21_L",
-}
diff --git a/new-api/relay/channel/jimeng/image.go b/new-api/relay/channel/jimeng/image.go
deleted file mode 100644
index f061a63b5eb4e3b59f6defec6ccb247b581735cf..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jimeng/image.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package jimeng
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type ImageResponse struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data struct {
- BinaryDataBase64 []string `json:"binary_data_base64"`
- ImageUrls []string `json:"image_urls"`
- RephraseResult string `json:"rephraser_result"`
- RequestID string `json:"request_id"`
- // Other fields are omitted for brevity
- } `json:"data"`
- RequestID string `json:"request_id"`
- Status int `json:"status"`
- TimeElapsed string `json:"time_elapsed"`
-}
-
-func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
- imageResponse := dto.ImageResponse{
- Created: info.StartTime.Unix(),
- }
-
- for _, base64Data := range response.Data.BinaryDataBase64 {
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{
- B64Json: base64Data,
- })
- }
- for _, imageUrl := range response.Data.ImageUrls {
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{
- Url: imageUrl,
- })
- }
-
- return &imageResponse
-}
-
-// jimengImageHandler handles the Jimeng image generation response
-func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
- var jimengResponse ImageResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
-
- err = json.Unmarshal(responseBody, &jimengResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- // Check if the response indicates an error
- if jimengResponse.Code != 10000 {
- return nil, types.WithOpenAIError(types.OpenAIError{
- Message: jimengResponse.Message,
- Type: "jimeng_error",
- Param: "",
- Code: fmt.Sprintf("%d", jimengResponse.Code),
- }, resp.StatusCode)
- }
-
- // Convert Jimeng response to OpenAI format
- fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
-
- return &dto.Usage{}, nil
-}
diff --git a/new-api/relay/channel/jimeng/sign.go b/new-api/relay/channel/jimeng/sign.go
deleted file mode 100644
index 8d18f45d69c876385ecb32336dfdadf006a54867..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jimeng/sign.go
+++ /dev/null
@@ -1,176 +0,0 @@
-package jimeng
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "io"
- "net/http"
- "net/url"
- "one-api/logger"
- "sort"
- "strings"
- "time"
-)
-
-// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
-//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
-// var bodyBytes []byte
-// var err error
-//
-// if req.Body != nil {
-// bodyBytes, err = io.ReadAll(req.Body)
-// if err != nil {
-// return fmt.Errorf("read request body failed: %w", err)
-// }
-// _ = req.Body.Close()
-// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
-// } else {
-// bodyBytes = []byte{}
-// }
-//
-// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
-//}
-
-const HexPayloadHashKey = "HexPayloadHash"
-
-func SetPayloadHash(c *gin.Context, req any) error {
- body, err := json.Marshal(req)
- if err != nil {
- return err
- }
- logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
- payloadHash := sha256.Sum256(body)
- hexPayloadHash := hex.EncodeToString(payloadHash[:])
- c.Set(HexPayloadHashKey, hexPayloadHash)
- return nil
-}
-func getPayloadHash(c *gin.Context) string {
- return c.GetString(HexPayloadHashKey)
-}
-
-func Sign(c *gin.Context, req *http.Request, apiKey string) error {
- header := req.Header
-
- var bodyBytes []byte
- var err error
-
- if req.Body != nil {
- bodyBytes, err = io.ReadAll(req.Body)
- if err != nil {
- return err
- }
- _ = req.Body.Close()
- req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
- }
-
- payloadHash := sha256.Sum256(bodyBytes)
- hexPayloadHash := hex.EncodeToString(payloadHash[:])
-
- method := c.Request.Method
- u := req.URL
- keyParts := strings.Split(apiKey, "|")
- if len(keyParts) != 2 {
- return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- secretKey := strings.TrimSpace(keyParts[1])
- t := time.Now().UTC()
- xDate := t.Format("20060102T150405Z")
- shortDate := t.Format("20060102")
-
- host := u.Host
- header.Set("Host", host)
- header.Set("X-Date", xDate)
- header.Set("X-Content-Sha256", hexPayloadHash)
-
- // Sort and encode query parameters to create canonical query string
- queryParams := u.Query()
- sortedKeys := make([]string, 0, len(queryParams))
- for k := range queryParams {
- sortedKeys = append(sortedKeys, k)
- }
- sort.Strings(sortedKeys)
- var queryParts []string
- for _, k := range sortedKeys {
- values := queryParams[k]
- sort.Strings(values)
- for _, v := range values {
- queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
- }
- }
- canonicalQueryString := strings.Join(queryParts, "&")
-
- headersToSign := map[string]string{
- "host": host,
- "x-date": xDate,
- "x-content-sha256": hexPayloadHash,
- }
- if header.Get("Content-Type") == "" {
- header.Set("Content-Type", "application/json")
- }
- headersToSign["content-type"] = header.Get("Content-Type")
-
- var signedHeaderKeys []string
- for k := range headersToSign {
- signedHeaderKeys = append(signedHeaderKeys, k)
- }
- sort.Strings(signedHeaderKeys)
-
- var canonicalHeaders strings.Builder
- for _, k := range signedHeaderKeys {
- canonicalHeaders.WriteString(k)
- canonicalHeaders.WriteString(":")
- canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
- canonicalHeaders.WriteString("\n")
- }
- signedHeaders := strings.Join(signedHeaderKeys, ";")
-
- canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
- method,
- u.Path,
- canonicalQueryString,
- canonicalHeaders.String(),
- signedHeaders,
- hexPayloadHash,
- )
-
- hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
- hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
-
- region := "cn-north-1"
- serviceName := "cv"
- credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
- stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
- xDate,
- credentialScope,
- hexHashedCanonicalRequest,
- )
-
- kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
- kRegion := hmacSHA256(kDate, []byte(region))
- kService := hmacSHA256(kRegion, []byte(serviceName))
- kSigning := hmacSHA256(kService, []byte("request"))
- signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
-
- authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- accessKey,
- credentialScope,
- signedHeaders,
- signature,
- )
- header.Set("Authorization", authorization)
- return nil
-}
-
-// hmacSHA256 计算 HMAC-SHA256
-func hmacSHA256(key []byte, data []byte) []byte {
- h := hmac.New(sha256.New, key)
- h.Write(data)
- return h.Sum(nil)
-}
diff --git a/new-api/relay/channel/jina/adaptor.go b/new-api/relay/channel/jina/adaptor.go
deleted file mode 100644
index 8e4cf7cb918aaa25a9519db3faf7bf7a01041513..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jina/adaptor.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package jina
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/common_handler"
- "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
- }
- return "", errors.New("invalid relay mode")
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- request.EncodingFormat = ""
- return request, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.RelayMode == constant.RelayModeRerank {
- usage, err = common_handler.RerankHandler(c, info, resp)
- } else if info.RelayMode == constant.RelayModeEmbeddings {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/jina/constant.go b/new-api/relay/channel/jina/constant.go
deleted file mode 100644
index 009cfaf27d6e8f0a7275c76f544da9429232e872..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jina/constant.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package jina
-
-var ModelList = []string{
- "jina-clip-v1",
- "jina-reranker-v2-base-multilingual",
- "jina-reranker-m0",
-}
-
-var ChannelName = "jina"
diff --git a/new-api/relay/channel/jina/relay-jina.go b/new-api/relay/channel/jina/relay-jina.go
deleted file mode 100644
index 783296fb6b2eec274a8a14c9b9527645c797aa13..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/jina/relay-jina.go
+++ /dev/null
@@ -1 +0,0 @@
-package jina
diff --git a/new-api/relay/channel/lingyiwanwu/constrants.go b/new-api/relay/channel/lingyiwanwu/constrants.go
deleted file mode 100644
index 12b390c897bfd2f84c1b5e703291279f40110dc7..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/lingyiwanwu/constrants.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package lingyiwanwu
-
-// https://platform.lingyiwanwu.com/docs
-
-var ModelList = []string{
- "yi-large", "yi-medium", "yi-vision", "yi-medium-200k", "yi-spark", "yi-large-rag", "yi-large-turbo", "yi-large-preview", "yi-large-rag-preview",
-}
-
-var ChannelName = "lingyiwanwu"
diff --git a/new-api/relay/channel/minimax/constants.go b/new-api/relay/channel/minimax/constants.go
deleted file mode 100644
index 9a1a8180a50b964300139336361c18b14afeea86..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/minimax/constants.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package minimax
-
-// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
-
-var ModelList = []string{
- "abab6.5-chat",
- "abab6.5s-chat",
- "abab6-chat",
- "abab5.5-chat",
- "abab5.5s-chat",
-}
-
-var ChannelName = "minimax"
diff --git a/new-api/relay/channel/minimax/relay-minimax.go b/new-api/relay/channel/minimax/relay-minimax.go
deleted file mode 100644
index 9581ba85a3623bffeb7db1aa94e050fdfc78ff54..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/minimax/relay-minimax.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package minimax
-
-import (
- "fmt"
- relaycommon "one-api/relay/common"
-)
-
-func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
-}
diff --git a/new-api/relay/channel/mistral/adaptor.go b/new-api/relay/channel/mistral/adaptor.go
deleted file mode 100644
index b8fe1890f8c85f70c4c6338480be9eacf8d7fcd0..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mistral/adaptor.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package mistral
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return requestOpenAI2Mistral(request), nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/mistral/constants.go b/new-api/relay/channel/mistral/constants.go
deleted file mode 100644
index 2fb408dbac7b72e26b9d17eb1ce9f60535993cfa..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mistral/constants.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package mistral
-
-var ModelList = []string{
- "open-mistral-7b",
- "open-mixtral-8x7b",
- "mistral-small-latest",
- "mistral-medium-latest",
- "mistral-large-latest",
- "mistral-embed",
-}
-
-var ChannelName = "mistral"
diff --git a/new-api/relay/channel/mistral/text.go b/new-api/relay/channel/mistral/text.go
deleted file mode 100644
index 0545e5e56f2243e978609e60f6ae2e93c44d9430..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mistral/text.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package mistral
-
-import (
- "one-api/common"
- "one-api/dto"
- "regexp"
-)
-
-var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
-
-func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
- messages := make([]dto.Message, 0, len(request.Messages))
- idMap := make(map[string]string)
- for _, message := range request.Messages {
- // 1. tool_calls.id
- toolCalls := message.ParseToolCalls()
- if toolCalls != nil {
- for i := range toolCalls {
- if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
- if newId, ok := idMap[toolCalls[i].ID]; ok {
- toolCalls[i].ID = newId
- } else {
- newId, err := common.GenerateRandomCharsKey(9)
- if err == nil {
- idMap[toolCalls[i].ID] = newId
- toolCalls[i].ID = newId
- }
- }
- }
- }
- message.SetToolCalls(toolCalls)
- }
-
- // 2. tool_call_id
- if message.ToolCallId != "" {
- if newId, ok := idMap[message.ToolCallId]; ok {
- message.ToolCallId = newId
- } else {
- if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
- newId, err := common.GenerateRandomCharsKey(9)
- if err == nil {
- idMap[message.ToolCallId] = newId
- message.ToolCallId = newId
- }
- }
- }
- }
-
- mediaMessages := message.ParseContent()
- if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
- mediaMessages = []dto.MediaContent{}
- }
- for j, mediaMessage := range mediaMessages {
- if mediaMessage.Type == dto.ContentTypeImageURL {
- imageUrl := mediaMessage.GetImageMedia()
- mediaMessage.ImageUrl = imageUrl.Url
- mediaMessages[j] = mediaMessage
- }
- }
- message.SetMediaContent(mediaMessages)
- messages = append(messages, dto.Message{
- Role: message.Role,
- Content: message.Content,
- ToolCalls: message.ToolCalls,
- ToolCallId: message.ToolCallId,
- })
- }
- return &dto.GeneralOpenAIRequest{
- Model: request.Model,
- Stream: request.Stream,
- Messages: messages,
- Temperature: request.Temperature,
- TopP: request.TopP,
- MaxTokens: request.GetMaxTokens(),
- Tools: request.Tools,
- ToolChoice: request.ToolChoice,
- }
-}
diff --git a/new-api/relay/channel/mokaai/adaptor.go b/new-api/relay/channel/mokaai/adaptor.go
deleted file mode 100644
index bdb7d21ea386f1ee7a7c5105715e2c129570ea5a..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mokaai/adaptor.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package mokaai
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return request, nil
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
- suffix := "chat/"
- if strings.HasPrefix(info.UpstreamModelName, "m3e") {
- suffix = "embeddings"
- }
- fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
- return fullRequestURL, nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
- return baiduEmbeddingRequest, nil
- default:
- return nil, errors.New("not implemented")
- }
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- return mokaEmbeddingHandler(c, info, resp)
- default:
- // err, usage = mokaHandler(c, resp)
-
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/mokaai/constants.go b/new-api/relay/channel/mokaai/constants.go
deleted file mode 100644
index 01eb60e7c913e8dc31bbec213732c6eb09f08b2c..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mokaai/constants.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package mokaai
-
-var ModelList = []string{
- "m3e-large",
- "m3e-base",
- "m3e-small",
-}
-
-var ChannelName = "mokaai"
diff --git a/new-api/relay/channel/mokaai/relay-mokaai.go b/new-api/relay/channel/mokaai/relay-mokaai.go
deleted file mode 100644
index 97c210e457268136ec23e38ea90d07275607667d..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/mokaai/relay-mokaai.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package mokaai
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
- var input []string // Change input to []string
-
- switch v := request.Input.(type) {
- case string:
- input = []string{v} // Convert string to []string
- case []string:
- input = v // Already a []string, no conversion needed
- case []interface{}:
- for _, part := range v {
- if str, ok := part.(string); ok {
- input = append(input, str) // Append each string to the slice
- }
- }
- }
- return &dto.EmbeddingRequest{
- Input: input,
- Model: request.Model,
- }
-}
-
-func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
- openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
- Model: "baidu-embedding",
- Usage: response.Usage,
- }
- for _, item := range response.Data {
- openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
- Object: item.Object,
- Index: item.Index,
- Embedding: item.Embedding,
- })
- }
- return &openAIEmbeddingResponse
-}
-
-func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var baiduResponse dto.EmbeddingResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &baiduResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- // if baiduResponse.ErrorMsg != "" {
- // return &dto.OpenAIErrorWithStatusCode{
- // Error: dto.OpenAIError{
- // Type: "baidu_error",
- // Param: "",
- // },
- // StatusCode: resp.StatusCode,
- // }, nil
- // }
- fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return &fullTextResponse.Usage, nil
-}
diff --git a/new-api/relay/channel/moonshot/adaptor.go b/new-api/relay/channel/moonshot/adaptor.go
deleted file mode 100644
index d4d047e55263e1359d755e5fc5fb2c1f1abd5409..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/moonshot/adaptor.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package moonshot
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := claude.Adaptor{}
- return adaptor.ConvertClaudeRequest(c, info, req)
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not supported")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- adaptor := openai.Adaptor{}
- return adaptor.ConvertImageRequest(c, info, request)
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
- default:
- if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeChatCompletions {
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeCompletions {
- return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
- }
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- if info.IsStream {
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- } else {
- return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
- }
- default:
- adaptor := openai.Adaptor{}
- return adaptor.DoResponse(c, resp, info)
- }
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/moonshot/constants.go b/new-api/relay/channel/moonshot/constants.go
deleted file mode 100644
index 6bb77a44f48476f0edf782304240565b621ffab3..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/moonshot/constants.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package moonshot
-
-var ModelList = []string{
- "moonshot-v1-8k",
- "moonshot-v1-32k",
- "moonshot-v1-128k",
-}
-
-var ChannelName = "moonshot"
diff --git a/new-api/relay/channel/ollama/adaptor.go b/new-api/relay/channel/ollama/adaptor.go
deleted file mode 100644
index e7b7ffabfacb71f065c637a81744a019512b7538..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ollama/adaptor.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package ollama
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- openaiAdaptor := openai.Adaptor{}
- openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
- if err != nil {
- return nil, err
- }
- openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
- IncludeUsage: true,
- }
- // map to ollama chat request (Claude -> OpenAI -> Ollama chat)
- return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
- if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
- return info.ChannelBaseUrl + "/api/chat", nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil { return nil, errors.New("request is nil") }
- // decide generate or chat
- if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
- return openAIToGenerate(c, request)
- }
- return openAIChatToOllamaChat(c, request)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return requestOpenAI2Embeddings(request), nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayMode {
- case relayconstant.RelayModeEmbeddings:
- return ollamaEmbeddingHandler(c, info, resp)
- default:
- if info.IsStream {
- return ollamaStreamHandler(c, info, resp)
- }
- return ollamaChatHandler(c, info, resp)
- }
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/ollama/constants.go b/new-api/relay/channel/ollama/constants.go
deleted file mode 100644
index e4f5702c26da2a237bf5aca6bb9b9c7f816e40c5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ollama/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package ollama
-
-var ModelList = []string{
- "llama3-7b",
-}
-
-var ChannelName = "ollama"
diff --git a/new-api/relay/channel/ollama/dto.go b/new-api/relay/channel/ollama/dto.go
deleted file mode 100644
index 01758ab6e625bef954ce0dec6e09cf7d407f774c..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ollama/dto.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package ollama
-
-import (
- "encoding/json"
-)
-
-type OllamaChatMessage struct {
- Role string `json:"role"`
- Content string `json:"content,omitempty"`
- Images []string `json:"images,omitempty"`
- ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
- ToolName string `json:"tool_name,omitempty"`
- Thinking json.RawMessage `json:"thinking,omitempty"`
-}
-
-type OllamaToolFunction struct {
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- Parameters interface{} `json:"parameters,omitempty"`
-}
-
-type OllamaTool struct {
- Type string `json:"type"`
- Function OllamaToolFunction `json:"function"`
-}
-
-type OllamaToolCall struct {
- Function struct {
- Name string `json:"name"`
- Arguments interface{} `json:"arguments"`
- } `json:"function"`
-}
-
-type OllamaChatRequest struct {
- Model string `json:"model"`
- Messages []OllamaChatMessage `json:"messages"`
- Tools interface{} `json:"tools,omitempty"`
- Format interface{} `json:"format,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Options map[string]any `json:"options,omitempty"`
- KeepAlive interface{} `json:"keep_alive,omitempty"`
- Think json.RawMessage `json:"think,omitempty"`
-}
-
-type OllamaGenerateRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt,omitempty"`
- Suffix string `json:"suffix,omitempty"`
- Images []string `json:"images,omitempty"`
- Format interface{} `json:"format,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Options map[string]any `json:"options,omitempty"`
- KeepAlive interface{} `json:"keep_alive,omitempty"`
- Think json.RawMessage `json:"think,omitempty"`
-}
-
-type OllamaEmbeddingRequest struct {
- Model string `json:"model"`
- Input interface{} `json:"input"`
- Options map[string]any `json:"options,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
-}
-
-type OllamaEmbeddingResponse struct {
- Error string `json:"error,omitempty"`
- Model string `json:"model"`
- Embeddings [][]float64 `json:"embeddings"`
- PromptEvalCount int `json:"prompt_eval_count,omitempty"`
-}
-
diff --git a/new-api/relay/channel/ollama/relay-ollama.go b/new-api/relay/channel/ollama/relay-ollama.go
deleted file mode 100644
index 359dfd4454b728c8fa190a90b8a947b29fa07ffa..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ollama/relay-ollama.go
+++ /dev/null
@@ -1,190 +0,0 @@
-package ollama
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
- chatReq := &OllamaChatRequest{
- Model: r.Model,
- Stream: r.Stream,
- Options: map[string]any{},
- Think: r.Think,
- }
- if r.ResponseFormat != nil {
- if r.ResponseFormat.Type == "json" {
- chatReq.Format = "json"
- } else if r.ResponseFormat.Type == "json_schema" {
- if len(r.ResponseFormat.JsonSchema) > 0 {
- var schema any
- _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
- chatReq.Format = schema
- }
- }
- }
-
- // options mapping
- if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
- if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
- if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
- if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
- if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
- if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
- if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
-
- if r.Stop != nil {
- switch v := r.Stop.(type) {
- case string:
- chatReq.Options["stop"] = []string{v}
- case []string:
- chatReq.Options["stop"] = v
- case []any:
- arr := make([]string,0,len(v))
- for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
- if len(arr)>0 { chatReq.Options["stop"] = arr }
- }
- }
-
- if len(r.Tools) > 0 {
- tools := make([]OllamaTool,0,len(r.Tools))
- for _, t := range r.Tools {
- tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
- }
- chatReq.Tools = tools
- }
-
- chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
- for _, m := range r.Messages {
- var textBuilder strings.Builder
- var images []string
- if m.IsStringContent() {
- textBuilder.WriteString(m.StringContent())
- } else {
- parts := m.ParseContent()
- for _, part := range parts {
- if part.Type == dto.ContentTypeImageURL {
- img := part.GetImageMedia()
- if img != nil && img.Url != "" {
- var base64Data string
- if strings.HasPrefix(img.Url, "http") {
- fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
- if err != nil { return nil, err }
- base64Data = fileData.Base64Data
- } else if strings.HasPrefix(img.Url, "data:") {
- if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
- } else {
- base64Data = img.Url
- }
- if base64Data != "" { images = append(images, base64Data) }
- }
- } else if part.Type == dto.ContentTypeText {
- textBuilder.WriteString(part.Text)
- }
- }
- }
- cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
- if len(images)>0 { cm.Images = images }
- if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
- if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
- parsed := m.ParseToolCalls()
- if len(parsed) > 0 {
- calls := make([]OllamaToolCall,0,len(parsed))
- for _, tc := range parsed {
- var args interface{}
- if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
- if args==nil { args = map[string]any{} }
- oc := OllamaToolCall{}
- oc.Function.Name = tc.Function.Name
- oc.Function.Arguments = args
- calls = append(calls, oc)
- }
- cm.ToolCalls = calls
- }
- }
- chatReq.Messages = append(chatReq.Messages, cm)
- }
- return chatReq, nil
-}
-
-// openAIToGenerate converts OpenAI completions request to Ollama generate
-func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
- gen := &OllamaGenerateRequest{
- Model: r.Model,
- Stream: r.Stream,
- Options: map[string]any{},
- Think: r.Think,
- }
- // Prompt may be in r.Prompt (string or []any)
- if r.Prompt != nil {
- switch v := r.Prompt.(type) {
- case string:
- gen.Prompt = v
- case []any:
- var sb strings.Builder
- for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
- gen.Prompt = sb.String()
- default:
- gen.Prompt = fmt.Sprintf("%v", r.Prompt)
- }
- }
- if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
- if r.ResponseFormat != nil {
- if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
- }
- if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
- if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
- if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
- if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
- if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
- if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
- if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
- if r.Stop != nil {
- switch v := r.Stop.(type) {
- case string: gen.Options["stop"] = []string{v}
- case []string: gen.Options["stop"] = v
- case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
- }
- }
- return gen, nil
-}
-
-func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
- opts := map[string]any{}
- if r.Temperature != nil { opts["temperature"] = r.Temperature }
- if r.TopP != 0 { opts["top_p"] = r.TopP }
- if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
- if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
- if r.Seed != 0 { opts["seed"] = int(r.Seed) }
- if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
- input := r.ParseInput()
- if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
- return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
-}
-
-func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var oResp OllamaEmbeddingResponse
- body, err := io.ReadAll(resp.Body)
- if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
- service.CloseResponseBodyGracefully(resp)
- if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
- if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
- data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
- for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
- usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
- embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
- out, _ := common.Marshal(embResp)
- service.IOCopyBytesGracefully(c, resp, out)
- return usage, nil
-}
-
diff --git a/new-api/relay/channel/ollama/stream.go b/new-api/relay/channel/ollama/stream.go
deleted file mode 100644
index 7b325c5d51829cd95c25d5187e3c3849fcae705e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/ollama/stream.go
+++ /dev/null
@@ -1,210 +0,0 @@
-package ollama
-
-import (
- "bufio"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-type ollamaChatStreamChunk struct {
- Model string `json:"model"`
- CreatedAt string `json:"created_at"`
- // chat
- Message *struct {
- Role string `json:"role"`
- Content string `json:"content"`
- Thinking json.RawMessage `json:"thinking"`
- ToolCalls []struct {
- Function struct {
- Name string `json:"name"`
- Arguments interface{} `json:"arguments"`
- } `json:"function"`
- } `json:"tool_calls"`
- } `json:"message"`
- // generate
- Response string `json:"response"`
- Done bool `json:"done"`
- DoneReason string `json:"done_reason"`
- TotalDuration int64 `json:"total_duration"`
- LoadDuration int64 `json:"load_duration"`
- PromptEvalCount int `json:"prompt_eval_count"`
- EvalCount int `json:"eval_count"`
- PromptEvalDuration int64 `json:"prompt_eval_duration"`
- EvalDuration int64 `json:"eval_duration"`
-}
-
-func toUnix(ts string) int64 {
- if ts == "" { return time.Now().Unix() }
- // try time.RFC3339 or with nanoseconds
- t, err := time.Parse(time.RFC3339Nano, ts)
- if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
- return t.Unix()
-}
-
-func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
- defer service.CloseResponseBodyGracefully(resp)
-
- helper.SetEventStreamHeaders(c)
- scanner := bufio.NewScanner(resp.Body)
- usage := &dto.Usage{}
- var model = info.UpstreamModelName
- var responseId = common.GetUUID()
- var created = time.Now().Unix()
- var toolCallIndex int
- start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
- if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
-
- for scanner.Scan() {
- line := scanner.Text()
- line = strings.TrimSpace(line)
- if line == "" { continue }
- var chunk ollamaChatStreamChunk
- if err := json.Unmarshal([]byte(line), &chunk); err != nil {
- logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
- return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if chunk.Model != "" { model = chunk.Model }
- created = toUnix(chunk.CreatedAt)
-
- if !chunk.Done {
- // delta content
- var content string
- if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
- delta := dto.ChatCompletionsStreamResponse{
- Id: responseId,
- Object: "chat.completion.chunk",
- Created: created,
- Model: model,
- Choices: []dto.ChatCompletionsStreamResponseChoice{ {
- Index: 0,
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
- } },
- }
- if content != "" { delta.Choices[0].Delta.SetContentString(content) }
- if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
- raw := strings.TrimSpace(string(chunk.Message.Thinking))
- if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
- }
- // tool calls
- if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
- delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
- for _, tc := range chunk.Message.ToolCalls {
- // arguments -> string
- argBytes, _ := json.Marshal(tc.Function.Arguments)
- toolId := fmt.Sprintf("call_%d", toolCallIndex)
- tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
- tr.SetIndex(toolCallIndex)
- toolCallIndex++
- delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
- }
- }
- if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
- continue
- }
- // done frame
- // finalize once and break loop
- usage.PromptTokens = chunk.PromptEvalCount
- usage.CompletionTokens = chunk.EvalCount
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- finishReason := chunk.DoneReason
- if finishReason == "" { finishReason = "stop" }
- // emit stop delta
- if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
- if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
- }
- // emit usage frame
- if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
- if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
- }
- // send [DONE]
- helper.Done(c)
- break
- }
- if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
- return usage, nil
-}
-
-// non-stream handler for chat/generate
-func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- body, err := io.ReadAll(resp.Body)
- if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
- service.CloseResponseBodyGracefully(resp)
- raw := string(body)
- if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
-
- lines := strings.Split(raw, "\n")
- var (
- aggContent strings.Builder
- reasoningBuilder strings.Builder
- lastChunk ollamaChatStreamChunk
- parsedAny bool
- )
- for _, ln := range lines {
- ln = strings.TrimSpace(ln)
- if ln == "" { continue }
- var ck ollamaChatStreamChunk
- if err := json.Unmarshal([]byte(ln), &ck); err != nil {
- if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
- continue
- }
- parsedAny = true
- lastChunk = ck
- if ck.Message != nil && len(ck.Message.Thinking) > 0 {
- raw := strings.TrimSpace(string(ck.Message.Thinking))
- if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
- }
- if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
- }
-
- if !parsedAny {
- var single ollamaChatStreamChunk
- if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
- lastChunk = single
- if single.Message != nil {
- if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
- aggContent.WriteString(single.Message.Content)
- } else { aggContent.WriteString(single.Response) }
- }
-
- model := lastChunk.Model
- if model == "" { model = info.UpstreamModelName }
- created := toUnix(lastChunk.CreatedAt)
- usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
- content := aggContent.String()
- finishReason := lastChunk.DoneReason
- if finishReason == "" { finishReason = "stop" }
-
- msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
- if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
- full := dto.OpenAITextResponse{
- Id: common.GetUUID(),
- Model: model,
- Object: "chat.completion",
- Created: created,
- Choices: []dto.OpenAITextResponseChoice{ {
- Index: 0,
- Message: msg,
- FinishReason: finishReason,
- } },
- Usage: *usage,
- }
- out, _ := common.Marshal(full)
- service.IOCopyBytesGracefully(c, resp, out)
- return usage, nil
-}
-
-func contentPtr(s string) *string { if s=="" { return nil }; return &s }
diff --git a/new-api/relay/channel/openai/adaptor.go b/new-api/relay/channel/openai/adaptor.go
deleted file mode 100644
index adea21f76b70ead34e4a0755a0577327b6b0b014..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openai/adaptor.go
+++ /dev/null
@@ -1,627 +0,0 @@
-package openai
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "net/textproto"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/ai360"
- "one-api/relay/channel/lingyiwanwu"
- "one-api/relay/channel/minimax"
- "one-api/relay/channel/openrouter"
- "one-api/relay/channel/xinference"
- relaycommon "one-api/relay/common"
- "one-api/relay/common_handler"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- "one-api/types"
- "path/filepath"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
- ChannelType int
- ResponseFormat string
-}
-
-// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
-// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
-// minimal effort only available in gpt-5
-func parseReasoningEffortFromModelSuffix(model string) (string, string) {
- effortSuffixes := []string{"-high", "-minimal", "-low", "-medium"}
- for _, suffix := range effortSuffixes {
- if strings.HasSuffix(model, suffix) {
- effort := strings.TrimPrefix(suffix, "-")
- originModel := strings.TrimSuffix(model, suffix)
- return effort, originModel
- }
- }
- return "", model
-}
-
-func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
- // 使用 service.GeminiToOpenAIRequest 转换请求格式
- openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
- if err != nil {
- return nil, err
- }
- return a.ConvertOpenAIRequest(c, info, openaiRequest)
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- //if !strings.Contains(request.Model, "claude") {
- // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
- //}
- //if common.DebugEnabled {
- // bodyBytes := []byte(common.GetJsonString(request))
- // err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
- // if err != nil {
- // println(fmt.Sprintf("failed to save request body to file: %v", err))
- // }
- //}
- aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
- if err != nil {
- return nil, err
- }
- //if common.DebugEnabled {
- // println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
- // // Save request body to file for debugging
- // bodyBytes := []byte(common.GetJsonString(aiRequest))
- // err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
- // if err != nil {
- // println(fmt.Sprintf("failed to save request body to file: %v", err))
- // }
- //}
- if info.SupportStreamOptions && info.IsStream {
- aiRequest.StreamOptions = &dto.StreamOptions{
- IncludeUsage: true,
- }
- }
- return a.ConvertOpenAIRequest(c, info, aiRequest)
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
-
- // initialize ThinkingContentInfo when thinking_to_content is enabled
- if info.ChannelSetting.ThinkingToContent {
- info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
- IsFirstThinkingContent: true,
- SendLastThinkingContent: false,
- HasSentThinkingContent: false,
- }
- }
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayMode == relayconstant.RelayModeRealtime {
- if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
- baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
- baseUrl = "wss://" + baseUrl
- info.ChannelBaseUrl = baseUrl
- } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
- baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
- baseUrl = "ws://" + baseUrl
- info.ChannelBaseUrl = baseUrl
- }
- }
- switch info.ChannelType {
- case constant.ChannelTypeAzure:
- apiVersion := info.ApiVersion
- if apiVersion == "" {
- apiVersion = constant.AzureDefaultAPIVersion
- }
- // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
- requestURL := strings.Split(info.RequestURLPath, "?")[0]
- requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
- task := strings.TrimPrefix(requestURL, "/v1/")
-
- if info.RelayFormat == types.RelayFormatClaude {
- task = strings.TrimPrefix(task, "messages")
- task = "chat/completions" + task
- }
-
- // 特殊处理 responses API
- if info.RelayMode == relayconstant.RelayModeResponses {
- responsesApiVersion := "preview"
-
- subUrl := "/openai/v1/responses"
- if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
- subUrl = "/openai/responses"
- responsesApiVersion = apiVersion
- }
-
- if info.ChannelOtherSettings.AzureResponsesVersion != "" {
- responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion
- }
-
- requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
- }
-
- model_ := info.UpstreamModelName
- // 2025年5月10日后创建的渠道不移除.
- if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
- model_ = strings.Replace(model_, ".", "", -1)
- }
- // https://github.com/songquanpeng/one-api/issues/67
- requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- if info.RelayMode == relayconstant.RelayModeRealtime {
- requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
- }
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
- case constant.ChannelTypeMiniMax:
- return minimax.GetRequestURL(info)
- case constant.ChannelTypeCustom:
- url := info.ChannelBaseUrl
- url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
- return url, nil
- default:
- if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
- }
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, header)
- if info.ChannelType == constant.ChannelTypeAzure {
- header.Set("api-key", info.ApiKey)
- return nil
- }
- if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
- header.Set("OpenAI-Organization", info.Organization)
- }
- if info.RelayMode == relayconstant.RelayModeRealtime {
- swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
- if swp != "" {
- items := []string{
- "realtime",
- "openai-insecure-api-key." + info.ApiKey,
- "openai-beta.realtime-v1",
- }
- header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
- //req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
- //req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
- //req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
- } else {
- header.Set("openai-beta", "realtime=v1")
- header.Set("Authorization", "Bearer "+info.ApiKey)
- }
- } else {
- header.Set("Authorization", "Bearer "+info.ApiKey)
- }
- if info.ChannelType == constant.ChannelTypeOpenRouter {
- header.Set("HTTP-Referer", "https://www.newapi.ai")
- header.Set("X-Title", "New API")
- }
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
- request.StreamOptions = nil
- }
- if info.ChannelType == constant.ChannelTypeOpenRouter {
- if len(request.Usage) == 0 {
- request.Usage = json.RawMessage(`{"include":true}`)
- }
- // 适配 OpenRouter 的 thinking 后缀
- if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- request.Model = info.UpstreamModelName
- if len(request.Reasoning) == 0 {
- reasoning := map[string]any{
- "enabled": true,
- }
- if request.ReasoningEffort != "" && request.ReasoningEffort != "none" {
- reasoning["effort"] = request.ReasoningEffort
- }
- marshal, err := common.Marshal(reasoning)
- if err != nil {
- return nil, fmt.Errorf("error marshalling reasoning: %w", err)
- }
- request.Reasoning = marshal
- }
- // 清空多余的ReasoningEffort
- request.ReasoningEffort = ""
- } else {
- if len(request.Reasoning) == 0 {
- // 适配 OpenAI 的 ReasoningEffort 格式
- if request.ReasoningEffort != "" {
- reasoning := map[string]any{
- "enabled": true,
- }
- if request.ReasoningEffort != "none" {
- reasoning["effort"] = request.ReasoningEffort
- marshal, err := common.Marshal(reasoning)
- if err != nil {
- return nil, fmt.Errorf("error marshalling reasoning: %w", err)
- }
- request.Reasoning = marshal
- }
- }
- }
- request.ReasoningEffort = ""
- }
-
- // https://docs.anthropic.com/en/api/openai-sdk#extended-thinking-support
- // 没有做排除3.5Haiku等,要出问题再加吧,最佳兼容性(不是
- if request.THINKING != nil && strings.HasPrefix(info.UpstreamModelName, "anthropic") {
- var thinking dto.Thinking // Claude标准Thinking格式
- if err := json.Unmarshal(request.THINKING, &thinking); err != nil {
- return nil, fmt.Errorf("error Unmarshal thinking: %w", err)
- }
-
- // 只有当 thinking.Type 是 "enabled" 时才处理
- if thinking.Type == "enabled" {
- // 检查 BudgetTokens 是否为 nil
- if thinking.BudgetTokens == nil {
- return nil, fmt.Errorf("BudgetTokens is nil when thinking is enabled")
- }
-
- reasoning := openrouter.RequestReasoning{
- MaxTokens: *thinking.BudgetTokens,
- }
-
- marshal, err := common.Marshal(reasoning)
- if err != nil {
- return nil, fmt.Errorf("error marshalling reasoning: %w", err)
- }
-
- request.Reasoning = marshal
- }
-
- // 清空 THINKING
- request.THINKING = nil
- }
-
- }
- if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
- if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
- request.MaxCompletionTokens = request.MaxTokens
- request.MaxTokens = 0
- }
-
- if strings.HasPrefix(info.UpstreamModelName, "o") {
- request.Temperature = nil
- }
-
- if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
- if info.UpstreamModelName != "gpt-5-chat-latest" {
- request.Temperature = nil
- }
- }
-
- // 转换模型推理力度后缀
- effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
- if effort != "" {
- request.ReasoningEffort = effort
- info.UpstreamModelName = originModel
- request.Model = originModel
- }
-
- info.ReasoningEffort = request.ReasoningEffort
-
- // o系列模型developer适配(o1-mini除外)
- if !strings.HasPrefix(info.UpstreamModelName, "o1-mini") && !strings.HasPrefix(info.UpstreamModelName, "o1-preview") {
- //修改第一个Message的内容,将system改为developer
- if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
- request.Messages[0].Role = "developer"
- }
- }
- }
-
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- a.ResponseFormat = request.ResponseFormat
- if info.RelayMode == relayconstant.RelayModeAudioSpeech {
- jsonData, err := json.Marshal(request)
- if err != nil {
- return nil, fmt.Errorf("error marshalling object: %w", err)
- }
- return bytes.NewReader(jsonData), nil
- } else {
- var requestBody bytes.Buffer
- writer := multipart.NewWriter(&requestBody)
-
- writer.WriteField("model", request.Model)
-
- // 获取所有表单字段
- formData := c.Request.PostForm
-
- // 遍历表单字段并打印输出
- for key, values := range formData {
- if key == "model" {
- continue
- }
- for _, value := range values {
- writer.WriteField(key, value)
- }
- }
-
- // 添加文件字段
- file, header, err := c.Request.FormFile("file")
- if err != nil {
- return nil, errors.New("file is required")
- }
- defer file.Close()
-
- part, err := writer.CreateFormFile("file", header.Filename)
- if err != nil {
- return nil, errors.New("create form file failed")
- }
- if _, err := io.Copy(part, file); err != nil {
- return nil, errors.New("copy file failed")
- }
-
- // 关闭 multipart 编写器以设置分界线
- writer.Close()
- c.Request.Header.Set("Content-Type", writer.FormDataContentType())
- return &requestBody, nil
- }
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- switch info.RelayMode {
- case relayconstant.RelayModeImagesEdits:
-
- var requestBody bytes.Buffer
- writer := multipart.NewWriter(&requestBody)
-
- writer.WriteField("model", request.Model)
- // 使用已解析的 multipart 表单,避免重复解析
- mf := c.Request.MultipartForm
- if mf == nil {
- if _, err := c.MultipartForm(); err != nil {
- return nil, errors.New("failed to parse multipart form")
- }
- mf = c.Request.MultipartForm
- }
-
- // 写入所有非文件字段
- if mf != nil {
- for key, values := range mf.Value {
- if key == "model" {
- continue
- }
- for _, value := range values {
- writer.WriteField(key, value)
- }
- }
- }
-
- if mf != nil && mf.File != nil {
- // Check if "image" field exists in any form, including array notation
- var imageFiles []*multipart.FileHeader
- var exists bool
-
- // First check for standard "image" field
- if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
- // If not found, check for "image[]" field
- if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
- // If still not found, iterate through all fields to find any that start with "image["
- foundArrayImages := false
- for fieldName, files := range mf.File {
- if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
- foundArrayImages = true
- imageFiles = append(imageFiles, files...)
- }
- }
-
- // If no image fields found at all
- if !foundArrayImages && (len(imageFiles) == 0) {
- return nil, errors.New("image is required")
- }
- }
- }
-
- // Process all image files
- for i, fileHeader := range imageFiles {
- file, err := fileHeader.Open()
- if err != nil {
- return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
- }
-
- // If multiple images, use image[] as the field name
- fieldName := "image"
- if len(imageFiles) > 1 {
- fieldName = "image[]"
- }
-
- // Determine MIME type based on file extension
- mimeType := detectImageMimeType(fileHeader.Filename)
-
- // Create a form file with the appropriate content type
- h := make(textproto.MIMEHeader)
- h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
- h.Set("Content-Type", mimeType)
-
- part, err := writer.CreatePart(h)
- if err != nil {
- return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
- }
-
- if _, err := io.Copy(part, file); err != nil {
- return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
- }
-
- // 复制完立即关闭,避免在循环内使用 defer 占用资源
- _ = file.Close()
- }
-
- // Handle mask file if present
- if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 {
- maskFile, err := maskFiles[0].Open()
- if err != nil {
- return nil, errors.New("failed to open mask file")
- }
- // 复制完立即关闭,避免在循环内使用 defer 占用资源
-
- // Determine MIME type for mask file
- mimeType := detectImageMimeType(maskFiles[0].Filename)
-
- // Create a form file with the appropriate content type
- h := make(textproto.MIMEHeader)
- h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
- h.Set("Content-Type", mimeType)
-
- maskPart, err := writer.CreatePart(h)
- if err != nil {
- return nil, errors.New("create form file failed for mask")
- }
-
- if _, err := io.Copy(maskPart, maskFile); err != nil {
- return nil, errors.New("copy mask file failed")
- }
- _ = maskFile.Close()
- }
- } else {
- return nil, errors.New("no multipart form data found")
- }
-
- // 关闭 multipart 编写器以设置分界线
- writer.Close()
- c.Request.Header.Set("Content-Type", writer.FormDataContentType())
- return &requestBody, nil
-
- default:
- return request, nil
- }
-}
-
-// detectImageMimeType determines the MIME type based on the file extension
-func detectImageMimeType(filename string) string {
- ext := strings.ToLower(filepath.Ext(filename))
- switch ext {
- case ".jpg", ".jpeg":
- return "image/jpeg"
- case ".png":
- return "image/png"
- case ".webp":
- return "image/webp"
- default:
- // Try to detect from extension if possible
- if strings.HasPrefix(ext, ".jp") {
- return "image/jpeg"
- }
- // Default to png as a fallback
- return "image/png"
- }
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // 转换模型推理力度后缀
- effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
- if effort != "" {
- if request.Reasoning == nil {
- request.Reasoning = &dto.Reasoning{
- Effort: effort,
- }
- } else {
- request.Reasoning.Effort = effort
- }
- request.Model = originModel
- }
- return request, nil
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
- info.RelayMode == relayconstant.RelayModeAudioTranslation ||
- info.RelayMode == relayconstant.RelayModeImagesEdits {
- return channel.DoFormRequest(a, c, info, requestBody)
- } else if info.RelayMode == relayconstant.RelayModeRealtime {
- return channel.DoWssRequest(a, c, info, requestBody)
- } else {
- return channel.DoApiRequest(a, c, info, requestBody)
- }
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayMode {
- case relayconstant.RelayModeRealtime:
- err, usage = OpenaiRealtimeHandler(c, info)
- case relayconstant.RelayModeAudioSpeech:
- usage = OpenaiTTSHandler(c, resp, info)
- case relayconstant.RelayModeAudioTranslation:
- fallthrough
- case relayconstant.RelayModeAudioTranscription:
- err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
- case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
- usage, err = OpenaiHandlerWithUsage(c, info, resp)
- case relayconstant.RelayModeRerank:
- usage, err = common_handler.RerankHandler(c, info, resp)
- case relayconstant.RelayModeResponses:
- if info.IsStream {
- usage, err = OaiResponsesStreamHandler(c, info, resp)
- } else {
- usage, err = OaiResponsesHandler(c, info, resp)
- }
- default:
- if info.IsStream {
- usage, err = OaiStreamHandler(c, info, resp)
- } else {
- usage, err = OpenaiHandler(c, info, resp)
- }
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- switch a.ChannelType {
- case constant.ChannelType360:
- return ai360.ModelList
- case constant.ChannelTypeLingYiWanWu:
- return lingyiwanwu.ModelList
- case constant.ChannelTypeMiniMax:
- return minimax.ModelList
- case constant.ChannelTypeXinference:
- return xinference.ModelList
- case constant.ChannelTypeOpenRouter:
- return openrouter.ModelList
- default:
- return ModelList
- }
-}
-
-func (a *Adaptor) GetChannelName() string {
- switch a.ChannelType {
- case constant.ChannelType360:
- return ai360.ChannelName
- case constant.ChannelTypeLingYiWanWu:
- return lingyiwanwu.ChannelName
- case constant.ChannelTypeMiniMax:
- return minimax.ChannelName
- case constant.ChannelTypeXinference:
- return xinference.ChannelName
- case constant.ChannelTypeOpenRouter:
- return openrouter.ChannelName
- default:
- return ChannelName
- }
-}
diff --git a/new-api/relay/channel/openai/constant.go b/new-api/relay/channel/openai/constant.go
deleted file mode 100644
index 1ae9efcdfc2f6a1485390cdeb17988f09f1ecb23..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openai/constant.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package openai
-
-var ModelList = []string{
- "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
- "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
- "gpt-3.5-turbo-instruct",
- "gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
- "gpt-4-32k", "gpt-4-32k-0613",
- "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
- "gpt-4-vision-preview",
- "chatgpt-4o-latest",
- "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
- "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
- "gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
- "gpt-4.1", "gpt-4.1-2025-04-14",
- "gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
- "gpt-4.1-nano", "gpt-4.1-nano-2025-04-14",
- "o1", "o1-2024-12-17",
- "o1-preview", "o1-preview-2024-09-12",
- "o1-mini", "o1-mini-2024-09-12",
- "o1-pro", "o1-pro-2025-03-19",
- "o3-mini", "o3-mini-2025-01-31",
- "o3-mini-high", "o3-mini-2025-01-31-high",
- "o3-mini-low", "o3-mini-2025-01-31-low",
- "o3-mini-medium", "o3-mini-2025-01-31-medium",
- "o3", "o3-2025-04-16",
- "o3-pro", "o3-pro-2025-06-10",
- "o3-deep-research", "o3-deep-research-2025-06-26",
- "o4-mini", "o4-mini-2025-04-16",
- "o4-mini-deep-research", "o4-mini-deep-research-2025-06-26",
- "gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
- "gpt-5-mini", "gpt-5-mini-2025-08-07",
- "gpt-5-nano", "gpt-5-nano-2025-08-07",
- "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
- "gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
- "gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
- "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
- "text-curie-001", "text-babbage-001", "text-ada-001",
- "text-moderation-latest", "text-moderation-stable",
- "text-davinci-edit-001",
- "davinci-002", "babbage-002",
- "dall-e-3", "gpt-image-1",
- "whisper-1",
- "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
-}
-
-var ChannelName = "openai"
diff --git a/new-api/relay/channel/openai/helper.go b/new-api/relay/channel/openai/helper.go
deleted file mode 100644
index c5ffe65552608f2ce85666c187808210b4045991..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openai/helper.go
+++ /dev/null
@@ -1,260 +0,0 @@
-package openai
-
-import (
- "encoding/json"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/samber/lo"
-
- "github.com/gin-gonic/gin"
-)
-
-// 辅助函数
-func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
- info.SendResponseCount++
-
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- return sendStreamData(c, info, data, forceFormat, thinkToContent)
- case types.RelayFormatClaude:
- return handleClaudeFormat(c, data, info)
- case types.RelayFormatGemini:
- return handleGeminiFormat(c, data, info)
- }
- return nil
-}
-
-func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
- var streamResponse dto.ChatCompletionsStreamResponse
- if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
- return err
- }
-
- if streamResponse.Usage != nil {
- info.ClaudeConvertInfo.Usage = streamResponse.Usage
- }
- claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
- for _, resp := range claudeResponses {
- helper.ClaudeData(c, *resp)
- }
- return nil
-}
-
-func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
- var streamResponse dto.ChatCompletionsStreamResponse
- if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
- logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
- return err
- }
-
- geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
-
- // 如果返回 nil,表示没有实际内容,跳过发送
- if geminiResponse == nil {
- return nil
- }
-
- geminiResponseStr, err := common.Marshal(geminiResponse)
- if err != nil {
- logger.LogError(c, "failed to marshal gemini response: "+err.Error())
- return err
- }
-
- // send gemini format response
- c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
- _ = helper.FlushWriter(c)
- return nil
-}
-
-func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Delta.GetContentString())
- responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
- if choice.Delta.ToolCalls != nil {
- if len(choice.Delta.ToolCalls) > *toolCount {
- *toolCount = len(choice.Delta.ToolCalls)
- }
- for _, tool := range choice.Delta.ToolCalls {
- responseTextBuilder.WriteString(tool.Function.Name)
- responseTextBuilder.WriteString(tool.Function.Arguments)
- }
- }
- }
- return nil
-}
-
-func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
- streamResp := "[" + strings.Join(streamItems, ",") + "]"
-
- switch relayMode {
- case relayconstant.RelayModeChatCompletions:
- return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
- case relayconstant.RelayModeCompletions:
- return processCompletions(streamResp, streamItems, responseTextBuilder)
- }
- return nil
-}
-
-func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
- var streamResponses []dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
- // 一次性解析失败,逐个解析
- common.SysLog("error unmarshalling stream response: " + err.Error())
- for _, item := range streamItems {
- var streamResponse dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
- return err
- }
- if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
- common.SysLog("error processing stream response: " + err.Error())
- }
- }
- return nil
- }
-
- // 批量处理所有响应
- for _, streamResponse := range streamResponses {
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Delta.GetContentString())
- responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
- if choice.Delta.ToolCalls != nil {
- if len(choice.Delta.ToolCalls) > *toolCount {
- *toolCount = len(choice.Delta.ToolCalls)
- }
- for _, tool := range choice.Delta.ToolCalls {
- responseTextBuilder.WriteString(tool.Function.Name)
- responseTextBuilder.WriteString(tool.Function.Arguments)
- }
- }
- }
- }
- return nil
-}
-
-func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
- var streamResponses []dto.CompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
- // 一次性解析失败,逐个解析
- common.SysLog("error unmarshalling stream response: " + err.Error())
- for _, item := range streamItems {
- var streamResponse dto.CompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
- continue
- }
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Text)
- }
- }
- return nil
- }
-
- // 批量处理所有响应
- for _, streamResponse := range streamResponses {
- for _, choice := range streamResponse.Choices {
- responseTextBuilder.WriteString(choice.Text)
- }
- }
- return nil
-}
-
-func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
- systemFingerprint *string, model *string, usage **dto.Usage,
- containStreamUsage *bool, info *relaycommon.RelayInfo,
- shouldSendLastResp *bool) error {
-
- var lastStreamResponse dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
- return err
- }
-
- *responseId = lastStreamResponse.Id
- *createAt = lastStreamResponse.Created
- *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
- *model = lastStreamResponse.Model
-
- if service.ValidUsage(lastStreamResponse.Usage) {
- *containStreamUsage = true
- *usage = lastStreamResponse.Usage
- if !info.ShouldIncludeUsage {
- *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
- return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
- })
- }
- }
-
- return nil
-}
-
-func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
- responseId string, createAt int64, model string, systemFingerprint string,
- usage *dto.Usage, containStreamUsage bool) {
-
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- if info.ShouldIncludeUsage && !containStreamUsage {
- response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
- response.SetSystemFingerprint(systemFingerprint)
- helper.ObjectData(c, response)
- }
- helper.Done(c)
-
- case types.RelayFormatClaude:
- info.ClaudeConvertInfo.Done = true
- var streamResponse dto.ChatCompletionsStreamResponse
- if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return
- }
-
- info.ClaudeConvertInfo.Usage = usage
-
- claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
- for _, resp := range claudeResponses {
- _ = helper.ClaudeData(c, *resp)
- }
-
- case types.RelayFormatGemini:
- var streamResponse dto.ChatCompletionsStreamResponse
- if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return
- }
-
- // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
- // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
- // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
- // 暂不知是否有程序会不兼容。
-
- geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
-
- // openai 流响应开头的空数据
- if geminiResponse == nil {
- return
- }
-
- geminiResponseStr, err := common.Marshal(geminiResponse)
- if err != nil {
- common.SysLog("error marshalling gemini response: " + err.Error())
- return
- }
-
- // 发送最终的 Gemini 响应
- c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
- _ = helper.FlushWriter(c)
- }
-}
-
-func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
- if data == "" {
- return
- }
- helper.ResponseChunkData(c, streamResponse, data)
-}
diff --git a/new-api/relay/channel/openai/relay-openai.go b/new-api/relay/channel/openai/relay-openai.go
deleted file mode 100644
index 6bc2c93693b2759fcc9d37038cfe72345e262e7b..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openai/relay-openai.go
+++ /dev/null
@@ -1,635 +0,0 @@
-package openai
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "math"
- "mime/multipart"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/relay/channel/openrouter"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "os"
- "path/filepath"
- "strings"
-
- "one-api/types"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "github.com/pkg/errors"
-)
-
-func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
- if data == "" {
- return nil
- }
-
- if !forceFormat && !thinkToContent {
- return helper.StringData(c, data)
- }
-
- var lastStreamResponse dto.ChatCompletionsStreamResponse
- if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
- return err
- }
-
- if !thinkToContent {
- return helper.ObjectData(c, lastStreamResponse)
- }
-
- hasThinkingContent := false
- hasContent := false
- var thinkingContent strings.Builder
- for _, choice := range lastStreamResponse.Choices {
- if len(choice.Delta.GetReasoningContent()) > 0 {
- hasThinkingContent = true
- thinkingContent.WriteString(choice.Delta.GetReasoningContent())
- }
- if len(choice.Delta.GetContentString()) > 0 {
- hasContent = true
- }
- }
-
- // Handle think to content conversion
- if info.ThinkingContentInfo.IsFirstThinkingContent {
- if hasThinkingContent {
- response := lastStreamResponse.Copy()
- for i := range response.Choices {
- // send `think` tag with thinking content
- response.Choices[i].Delta.SetContentString("\n" + thinkingContent.String())
- response.Choices[i].Delta.ReasoningContent = nil
- response.Choices[i].Delta.Reasoning = nil
- }
- info.ThinkingContentInfo.IsFirstThinkingContent = false
- info.ThinkingContentInfo.HasSentThinkingContent = true
- return helper.ObjectData(c, response)
- }
- }
-
- if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
- return helper.ObjectData(c, lastStreamResponse)
- }
-
- // Process each choice
- for i, choice := range lastStreamResponse.Choices {
- // Handle transition from thinking to content
- // only send `` tag when previous thinking content has been sent
- if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
- response := lastStreamResponse.Copy()
- for j := range response.Choices {
- response.Choices[j].Delta.SetContentString("\n\n")
- response.Choices[j].Delta.ReasoningContent = nil
- response.Choices[j].Delta.Reasoning = nil
- }
- info.ThinkingContentInfo.SendLastThinkingContent = true
- helper.ObjectData(c, response)
- }
-
- // Convert reasoning content to regular content if any
- if len(choice.Delta.GetReasoningContent()) > 0 {
- lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
- lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
- lastStreamResponse.Choices[i].Delta.Reasoning = nil
- } else if !hasThinkingContent && !hasContent {
- // flush thinking content
- lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
- lastStreamResponse.Choices[i].Delta.Reasoning = nil
- }
- }
-
- return helper.ObjectData(c, lastStreamResponse)
-}
-
-func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- if resp == nil || resp.Body == nil {
- logger.LogError(c, "invalid response or response body")
- return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
- }
-
- defer service.CloseResponseBodyGracefully(resp)
-
- model := info.UpstreamModelName
- var responseId string
- var createAt int64 = 0
- var systemFingerprint string
- var containStreamUsage bool
- var responseTextBuilder strings.Builder
- var toolCount int
- var usage = &dto.Usage{}
- var streamItems []string // store stream items
- var lastStreamData string
-
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- if lastStreamData != "" {
- err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
- if err != nil {
- common.SysLog("error handling stream format: " + err.Error())
- }
- }
- if len(data) > 0 {
- lastStreamData = data
- streamItems = append(streamItems, data)
- }
- return true
- })
-
- // 处理最后的响应
- shouldSendLastResp := true
- if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
- &containStreamUsage, info, &shouldSendLastResp); err != nil {
- logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
- }
-
- if info.RelayFormat == types.RelayFormatOpenAI {
- if shouldSendLastResp {
- _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
- }
- }
-
- // 处理token计算
- if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
- logger.LogError(c, "error processing tokens: "+err.Error())
- }
-
- if !containStreamUsage {
- usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- } else {
- if info.ChannelType == constant.ChannelTypeDeepSeek {
- if usage.PromptCacheHitTokens != 0 {
- usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
- }
- }
- }
- HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
-
- return usage, nil
-}
-
-func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- var simpleResponse dto.OpenAITextResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- if common.DebugEnabled {
- println("upstream response body:", string(responseBody))
- }
- // Unmarshal to simpleResponse
- if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
- // 尝试解析为 openrouter enterprise
- var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
- err = common.Unmarshal(responseBody, &enterpriseResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if enterpriseResponse.Success {
- responseBody = enterpriseResponse.Data
- } else {
- logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
- return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- }
-
- err = common.Unmarshal(responseBody, &simpleResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
- return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
- }
-
- forceFormat := false
- if info.ChannelSetting.ForceFormat {
- forceFormat = true
- }
-
- usageModified := false
- if simpleResponse.Usage.PromptTokens == 0 {
- completionTokens := simpleResponse.Usage.CompletionTokens
- if completionTokens == 0 {
- for _, choice := range simpleResponse.Choices {
- ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
- completionTokens += ctkm
- }
- }
- simpleResponse.Usage = dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: completionTokens,
- TotalTokens: info.PromptTokens + completionTokens,
- }
- usageModified = true
- }
-
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- if usageModified {
- var bodyMap map[string]interface{}
- err = common.Unmarshal(responseBody, &bodyMap)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- bodyMap["usage"] = simpleResponse.Usage
- responseBody, _ = common.Marshal(bodyMap)
- }
- if forceFormat {
- responseBody, err = common.Marshal(simpleResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- } else {
- break
- }
- case types.RelayFormatClaude:
- claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
- claudeRespStr, err := common.Marshal(claudeResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- responseBody = claudeRespStr
- case types.RelayFormatGemini:
- geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
- geminiRespStr, err := common.Marshal(geminiResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- responseBody = geminiRespStr
- }
-
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- return &simpleResponse.Usage, nil
-}
-
-func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
- // the status code has been judged before, if there is a body reading failure,
- // it should be regarded as a non-recoverable error, so it should not return err for external retry.
- // Analogous to nginx's load balancing, it will only retry if it can't be requested or
- // if the upstream returns a specific status code, once the upstream has already written the header,
- // the subsequent failure of the response body should be regarded as a non-recoverable error,
- // and can be terminated directly.
- defer service.CloseResponseBodyGracefully(resp)
- usage := &dto.Usage{}
- usage.PromptTokens = info.PromptTokens
- usage.TotalTokens = info.PromptTokens
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- c.Writer.WriteHeaderNow()
- _, err := io.Copy(c.Writer, resp.Body)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- return usage
-}
-
-func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
- defer service.CloseResponseBodyGracefully(resp)
-
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- var responseData struct {
- Usage *dto.Usage `json:"usage"`
- }
- if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
- if responseData.Usage.TotalTokens > 0 {
- usage := responseData.Usage
- if usage.PromptTokens == 0 {
- usage.PromptTokens = usage.InputTokens
- }
- if usage.CompletionTokens == 0 {
- usage.CompletionTokens = usage.OutputTokens
- }
- return nil, usage
- }
- }
-
- audioTokens, err := countAudioTokens(c)
- if err != nil {
- return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
- }
- usage := &dto.Usage{}
- usage.PromptTokens = audioTokens
- usage.CompletionTokens = 0
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return nil, usage
-}
-
-func countAudioTokens(c *gin.Context) (int, error) {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return 0, errors.WithStack(err)
- }
-
- var reqBody struct {
- File *multipart.FileHeader `form:"file" binding:"required"`
- }
- c.Request.Body = io.NopCloser(bytes.NewReader(body))
- if err = c.ShouldBind(&reqBody); err != nil {
- return 0, errors.WithStack(err)
- }
- ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
- reqFp, err := reqBody.File.Open()
- if err != nil {
- return 0, errors.WithStack(err)
- }
- defer reqFp.Close()
-
- tmpFp, err := os.CreateTemp("", "audio-*"+ext)
- if err != nil {
- return 0, errors.WithStack(err)
- }
- defer os.Remove(tmpFp.Name())
-
- _, err = io.Copy(tmpFp, reqFp)
- if err != nil {
- return 0, errors.WithStack(err)
- }
- if err = tmpFp.Close(); err != nil {
- return 0, errors.WithStack(err)
- }
-
- duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
- if err != nil {
- return 0, errors.WithStack(err)
- }
-
- return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
-}
-
-func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
- if info == nil || info.ClientWs == nil || info.TargetWs == nil {
- return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
- }
-
- info.IsStream = true
- clientConn := info.ClientWs
- targetConn := info.TargetWs
-
- clientClosed := make(chan struct{})
- targetClosed := make(chan struct{})
- sendChan := make(chan []byte, 100)
- receiveChan := make(chan []byte, 100)
- errChan := make(chan error, 2)
-
- usage := &dto.RealtimeUsage{}
- localUsage := &dto.RealtimeUsage{}
- sumUsage := &dto.RealtimeUsage{}
-
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in client reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := clientConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from client: %v", err)
- }
- close(clientClosed)
- return
- }
-
- realtimeEvent := &dto.RealtimeEvent{}
- err = common.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
-
- if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
- if realtimeEvent.Session != nil {
- if realtimeEvent.Session.Tools != nil {
- info.RealtimeTools = realtimeEvent.Session.Tools
- }
- }
- }
-
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
-
- err = helper.WssString(c, targetConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to target: %v", err)
- return
- }
-
- select {
- case sendChan <- message:
- default:
- }
- }
- }
- })
-
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in target reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := targetConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from target: %v", err)
- }
- close(targetClosed)
- return
- }
- info.SetFirstResponseTime()
- realtimeEvent := &dto.RealtimeEvent{}
- err = common.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
-
- if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
- realtimeUsage := realtimeEvent.Response.Usage
- if realtimeUsage != nil {
- usage.TotalTokens += realtimeUsage.TotalTokens
- usage.InputTokens += realtimeUsage.InputTokens
- usage.OutputTokens += realtimeUsage.OutputTokens
- usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
- usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
- usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
- usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
- usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
- err := preConsumeUsage(c, info, usage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- usage = &dto.RealtimeUsage{}
-
- localUsage = &dto.RealtimeUsage{}
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- info.IsFirstRequest = false
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
- err = preConsumeUsage(c, info, localUsage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- localUsage = &dto.RealtimeUsage{}
- // print now usage
- }
- logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
- logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
-
- } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
- realtimeSession := realtimeEvent.Session
- if realtimeSession != nil {
- // update audio format
- info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
- info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
- }
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.OutputTokens += textToken + audioToken
- localUsage.OutputTokenDetails.TextTokens += textToken
- localUsage.OutputTokenDetails.AudioTokens += audioToken
- }
-
- err = helper.WssString(c, clientConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to client: %v", err)
- return
- }
-
- select {
- case receiveChan <- message:
- default:
- }
- }
- }
- })
-
- select {
- case <-clientClosed:
- case <-targetClosed:
- case err := <-errChan:
- //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
- logger.LogError(c, "realtime error: "+err.Error())
- case <-c.Done():
- }
-
- if usage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, usage, sumUsage)
- }
-
- if localUsage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, localUsage, sumUsage)
- }
-
- // check usage total tokens, if 0, use local usage
-
- return nil, sumUsage
-}
-
-func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
- if usage == nil || totalUsage == nil {
- return fmt.Errorf("invalid usage pointer")
- }
-
- totalUsage.TotalTokens += usage.TotalTokens
- totalUsage.InputTokens += usage.InputTokens
- totalUsage.OutputTokens += usage.OutputTokens
- totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
- totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
- totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
- totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
- totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
- // clear usage
- err := service.PreWssConsumeQuota(ctx, info, usage)
- return err
-}
-
-func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
-
- var usageResp dto.SimpleResponse
- err = common.Unmarshal(responseBody, &usageResp)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
-
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- // Once we've written to the client, we should not return errors anymore
- // because the upstream has already consumed resources and returned content
- // We should still perform billing even if parsing fails
- // format
- if usageResp.InputTokens > 0 {
- usageResp.PromptTokens += usageResp.InputTokens
- }
- if usageResp.OutputTokens > 0 {
- usageResp.CompletionTokens += usageResp.OutputTokens
- }
- if usageResp.InputTokensDetails != nil {
- usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
- usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
- }
- return &usageResp.Usage, nil
-}
diff --git a/new-api/relay/channel/openai/relay_responses.go b/new-api/relay/channel/openai/relay_responses.go
deleted file mode 100644
index 78a0c2854cf0f85138d9c73790650dc2f6fcc74a..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openai/relay_responses.go
+++ /dev/null
@@ -1,149 +0,0 @@
-package openai
-
-import (
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- // read response body
- var responsesResponse dto.OpenAIResponsesResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- err = common.Unmarshal(responseBody, &responsesResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
- return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
- }
-
- if responsesResponse.HasImageGenerationCall() {
- c.Set("image_generation_call", true)
- c.Set("image_generation_call_quality", responsesResponse.GetQuality())
- c.Set("image_generation_call_size", responsesResponse.GetSize())
- }
-
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
-
- // compute usage
- usage := dto.Usage{}
- if responsesResponse.Usage != nil {
- usage.PromptTokens = responsesResponse.Usage.InputTokens
- usage.CompletionTokens = responsesResponse.Usage.OutputTokens
- usage.TotalTokens = responsesResponse.Usage.TotalTokens
- if responsesResponse.Usage.InputTokensDetails != nil {
- usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
- }
- }
- if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
- return &usage, nil
- }
- // 解析 Tools 用量
- for _, tool := range responsesResponse.Tools {
- buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
- if !ok || buildToolinfo == nil {
- logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
- continue
- }
- buildToolinfo.CallCount++
- }
- return &usage, nil
-}
-
-func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- if resp == nil || resp.Body == nil {
- logger.LogError(c, "invalid response or response body")
- return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
- }
-
- defer service.CloseResponseBodyGracefully(resp)
-
- var usage = &dto.Usage{}
- var responseTextBuilder strings.Builder
-
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-
- // 检查当前数据是否包含 completed 状态和 usage 信息
- var streamResponse dto.ResponsesStreamResponse
- if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
- sendResponsesStreamData(c, streamResponse, data)
- switch streamResponse.Type {
- case "response.completed":
- if streamResponse.Response != nil {
- if streamResponse.Response.Usage != nil {
- if streamResponse.Response.Usage.InputTokens != 0 {
- usage.PromptTokens = streamResponse.Response.Usage.InputTokens
- }
- if streamResponse.Response.Usage.OutputTokens != 0 {
- usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
- }
- if streamResponse.Response.Usage.TotalTokens != 0 {
- usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
- }
- if streamResponse.Response.Usage.InputTokensDetails != nil {
- usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
- }
- }
- if streamResponse.Response.HasImageGenerationCall() {
- c.Set("image_generation_call", true)
- c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
- c.Set("image_generation_call_size", streamResponse.Response.GetSize())
- }
- }
- case "response.output_text.delta":
- // 处理输出文本
- responseTextBuilder.WriteString(streamResponse.Delta)
- case dto.ResponsesOutputTypeItemDone:
- // 函数调用处理
- if streamResponse.Item != nil {
- switch streamResponse.Item.Type {
- case dto.BuildInCallWebSearchCall:
- if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
- if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
- webSearchTool.CallCount++
- }
- }
- }
- }
- }
- } else {
- logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
- }
- return true
- })
-
- if usage.CompletionTokens == 0 {
- // 计算输出文本的 token 数量
- tempStr := responseTextBuilder.String()
- if len(tempStr) > 0 {
- // 非正常结束,使用输出文本的 token 数量
- completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
- usage.CompletionTokens = completionTokens
- }
- }
-
- if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
- usage.PromptTokens = info.PromptTokens
- }
-
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-
- return usage, nil
-}
diff --git a/new-api/relay/channel/openrouter/constant.go b/new-api/relay/channel/openrouter/constant.go
deleted file mode 100644
index 26889beac27fb89a92e284d8dbde4dcb052947ee..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openrouter/constant.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package openrouter
-
-var ModelList = []string{}
-
-var ChannelName = "openrouter"
diff --git a/new-api/relay/channel/openrouter/dto.go b/new-api/relay/channel/openrouter/dto.go
deleted file mode 100644
index 31b18a6daafeb8fd84971f1b97da73b0db3e9752..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/openrouter/dto.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package openrouter
-
-import "encoding/json"
-
-type RequestReasoning struct {
- // One of the following (not both):
- Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
- MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
- // Optional: Default is false. All models support this.
- Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
-}
-
-type OpenRouterEnterpriseResponse struct {
- Data json.RawMessage `json:"data"`
- Success bool `json:"success"`
-}
diff --git a/new-api/relay/channel/palm/adaptor.go b/new-api/relay/channel/palm/adaptor.go
deleted file mode 100644
index 55e5f0ae09312b02f84ccbf220b63683c8134ab3..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/palm/adaptor.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package palm
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("x-goog-api-key", info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- var responseText string
- err, responseText = palmStreamHandler(c, resp)
- usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
- } else {
- usage, err = palmHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/palm/constants.go b/new-api/relay/channel/palm/constants.go
deleted file mode 100644
index 26f9fff94d286220646851be09d7bd0ac04b4771..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/palm/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package palm
-
-var ModelList = []string{
- "PaLM-2",
-}
-
-var ChannelName = "google palm"
diff --git a/new-api/relay/channel/palm/dto.go b/new-api/relay/channel/palm/dto.go
deleted file mode 100644
index 2a6d581d8af0cab7295e9050c5ac07daa96a9bf6..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/palm/dto.go
+++ /dev/null
@@ -1,38 +0,0 @@
-package palm
-
-import "one-api/dto"
-
-type PaLMChatMessage struct {
- Author string `json:"author"`
- Content string `json:"content"`
-}
-
-type PaLMFilter struct {
- Reason string `json:"reason"`
- Message string `json:"message"`
-}
-
-type PaLMPrompt struct {
- Messages []PaLMChatMessage `json:"messages"`
-}
-
-type PaLMChatRequest struct {
- Prompt PaLMPrompt `json:"prompt"`
- Temperature *float64 `json:"temperature,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK uint `json:"topK,omitempty"`
-}
-
-type PaLMError struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status string `json:"status"`
-}
-
-type PaLMChatResponse struct {
- Candidates []PaLMChatMessage `json:"candidates"`
- Messages []dto.Message `json:"messages"`
- Filters []PaLMFilter `json:"filters"`
- Error PaLMError `json:"error"`
-}
diff --git a/new-api/relay/channel/palm/relay-palm.go b/new-api/relay/channel/palm/relay-palm.go
deleted file mode 100644
index c9ec4c4a499c51edbc6fce2927f464be411d8e14..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/palm/relay-palm.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package palm
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
-// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-
-func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
- fullTextResponse := dto.OpenAITextResponse{
- Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
- }
- for i, candidate := range response.Candidates {
- choice := dto.OpenAITextResponseChoice{
- Index: i,
- Message: dto.Message{
- Role: "assistant",
- Content: candidate.Content,
- },
- FinishReason: "stop",
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- }
- return &fullTextResponse
-}
-
-func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
- var choice dto.ChatCompletionsStreamResponseChoice
- if len(palmResponse.Candidates) > 0 {
- choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
- }
- choice.FinishReason = &constant.FinishReasonStop
- var response dto.ChatCompletionsStreamResponse
- response.Object = "chat.completion.chunk"
- response.Model = "palm2"
- response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
- return &response
-}
-
-func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
- responseText := ""
- responseId := helper.GetResponseID(c)
- createdTime := common.GetTimestamp()
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- common.SysLog("error reading stream response: " + err.Error())
- stopChan <- true
- return
- }
- service.CloseResponseBodyGracefully(resp)
- var palmResponse PaLMChatResponse
- err = json.Unmarshal(responseBody, &palmResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- stopChan <- true
- return
- }
- fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
- fullTextResponse.Id = responseId
- fullTextResponse.Created = createdTime
- if len(palmResponse.Candidates) > 0 {
- responseText = palmResponse.Candidates[0].Content
- }
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- stopChan <- true
- return
- }
- dataChan <- string(jsonResponse)
- stopChan <- true
- }()
- helper.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- c.Render(-1, common.CustomEvent{Data: "data: " + data})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- service.CloseResponseBodyGracefully(resp)
- return nil, responseText
-}
-
-func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- var palmResponse PaLMChatResponse
- err = json.Unmarshal(responseBody, &palmResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
- return nil, types.WithOpenAIError(types.OpenAIError{
- Message: palmResponse.Error.Message,
- Type: palmResponse.Error.Status,
- Param: "",
- Code: palmResponse.Error.Code,
- }, resp.StatusCode)
- }
- fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
- usage := dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: completionTokens,
- TotalTokens: info.PromptTokens + completionTokens,
- }
- fullTextResponse.Usage = usage
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return &usage, nil
-}
diff --git a/new-api/relay/channel/perplexity/adaptor.go b/new-api/relay/channel/perplexity/adaptor.go
deleted file mode 100644
index 0bfa09f94a86585498a40042025d49b72f2dffa0..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/perplexity/adaptor.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package perplexity
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if request.TopP >= 1 {
- request.TopP = 0.99
- }
- return requestOpenAI2Perplexity(*request), nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/perplexity/constants.go b/new-api/relay/channel/perplexity/constants.go
deleted file mode 100644
index a692a2d9ef6dadf37a5cfef77c5df5b9a77a8bda..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/perplexity/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package perplexity
-
-var ModelList = []string{
- "llama-3-sonar-small-32k-chat", "llama-3-sonar-small-32k-online", "llama-3-sonar-large-32k-chat", "llama-3-sonar-large-32k-online", "llama-3-8b-instruct", "llama-3-70b-instruct", "mixtral-8x7b-instruct",
-}
-
-var ChannelName = "perplexity"
diff --git a/new-api/relay/channel/perplexity/relay-perplexity.go b/new-api/relay/channel/perplexity/relay-perplexity.go
deleted file mode 100644
index ebb2b078ab2ee7644d8c3c04d6af96279f2087dc..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/perplexity/relay-perplexity.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package perplexity
-
-import "one-api/dto"
-
-func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
- messages := make([]dto.Message, 0, len(request.Messages))
- for _, message := range request.Messages {
- messages = append(messages, dto.Message{
- Role: message.Role,
- Content: message.Content,
- })
- }
- return &dto.GeneralOpenAIRequest{
- Model: request.Model,
- Stream: request.Stream,
- Messages: messages,
- Temperature: request.Temperature,
- TopP: request.TopP,
- MaxTokens: request.GetMaxTokens(),
- }
-}
diff --git a/new-api/relay/channel/siliconflow/adaptor.go b/new-api/relay/channel/siliconflow/adaptor.go
deleted file mode 100644
index 0d2fe7d26f2a393c516f0f2acce930f79e7491bb..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/siliconflow/adaptor.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package siliconflow
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
- return adaptor.ConvertClaudeRequest(c, info, req)
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not supported")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- adaptor := openai.Adaptor{}
- return adaptor.ConvertImageRequest(c, info, request)
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeChatCompletions {
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
- } else if info.RelayMode == constant.RelayModeCompletions {
- return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
- }
- return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayMode {
- case constant.RelayModeRerank:
- usage, err = siliconflowRerankHandler(c, info, resp)
- case constant.RelayModeEmbeddings:
- usage, err = openai.OpenaiHandler(c, info, resp)
- case constant.RelayModeCompletions:
- fallthrough
- case constant.RelayModeChatCompletions:
- fallthrough
- default:
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
-
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/siliconflow/constant.go b/new-api/relay/channel/siliconflow/constant.go
deleted file mode 100644
index 9fc455dded031f6b5ba24f8bebdede469d3e9fdc..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/siliconflow/constant.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package siliconflow
-
-var ModelList = []string{
- "THUDM/glm-4-9b-chat",
- //"stabilityai/stable-diffusion-xl-base-1.0",
- //"TencentARC/PhotoMaker",
- "InstantX/InstantID",
- //"stabilityai/stable-diffusion-2-1",
- //"stabilityai/sd-turbo",
- //"stabilityai/sdxl-turbo",
- "ByteDance/SDXL-Lightning",
- "deepseek-ai/deepseek-llm-67b-chat",
- "Qwen/Qwen1.5-14B-Chat",
- "Qwen/Qwen1.5-7B-Chat",
- "Qwen/Qwen1.5-110B-Chat",
- "Qwen/Qwen1.5-32B-Chat",
- "01-ai/Yi-1.5-6B-Chat",
- "01-ai/Yi-1.5-9B-Chat-16K",
- "01-ai/Yi-1.5-34B-Chat-16K",
- "THUDM/chatglm3-6b",
- "deepseek-ai/DeepSeek-V2-Chat",
- "Qwen/Qwen2-72B-Instruct",
- "Qwen/Qwen2-7B-Instruct",
- "Qwen/Qwen2-57B-A14B-Instruct",
- //"stabilityai/stable-diffusion-3-medium",
- "deepseek-ai/DeepSeek-Coder-V2-Instruct",
- "Qwen/Qwen2-1.5B-Instruct",
- "internlm/internlm2_5-7b-chat",
- "BAAI/bge-large-en-v1.5",
- "BAAI/bge-large-zh-v1.5",
- "Pro/Qwen/Qwen2-7B-Instruct",
- "Pro/Qwen/Qwen2-1.5B-Instruct",
- "Pro/Qwen/Qwen1.5-7B-Chat",
- "Pro/THUDM/glm-4-9b-chat",
- "Pro/THUDM/chatglm3-6b",
- "Pro/01-ai/Yi-1.5-9B-Chat-16K",
- "Pro/01-ai/Yi-1.5-6B-Chat",
- "Pro/google/gemma-2-9b-it",
- "Pro/internlm/internlm2_5-7b-chat",
- "Pro/meta-llama/Meta-Llama-3-8B-Instruct",
- "Pro/mistralai/Mistral-7B-Instruct-v0.2",
- "black-forest-labs/FLUX.1-schnell",
- "FunAudioLLM/SenseVoiceSmall",
- "netease-youdao/bce-embedding-base_v1",
- "BAAI/bge-m3",
- "internlm/internlm2_5-20b-chat",
- "Qwen/Qwen2-Math-72B-Instruct",
- "netease-youdao/bce-reranker-base_v1",
- "BAAI/bge-reranker-v2-m3",
-}
-var ChannelName = "siliconflow"
diff --git a/new-api/relay/channel/siliconflow/dto.go b/new-api/relay/channel/siliconflow/dto.go
deleted file mode 100644
index 0683cc7e544c035966cab3b7e48f09cfd72cd162..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/siliconflow/dto.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package siliconflow
-
-import "one-api/dto"
-
-type SFTokens struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
-}
-
-type SFMeta struct {
- Tokens SFTokens `json:"tokens"`
-}
-
-type SFRerankResponse struct {
- Results []dto.RerankResponseResult `json:"results"`
- Meta SFMeta `json:"meta"`
-}
diff --git a/new-api/relay/channel/siliconflow/relay-siliconflow.go b/new-api/relay/channel/siliconflow/relay-siliconflow.go
deleted file mode 100644
index e5a9fd69207fcd4391b57d24350015c9d03888aa..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/siliconflow/relay-siliconflow.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package siliconflow
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- var siliconflowResp SFRerankResponse
- err = json.Unmarshal(responseBody, &siliconflowResp)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- usage := &dto.Usage{
- PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
- CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
- TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
- }
- rerankResp := &dto.RerankResponse{
- Results: siliconflowResp.Results,
- Usage: *usage,
- }
-
- jsonResponse, err := json.Marshal(rerankResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return usage, nil
-}
diff --git a/new-api/relay/channel/submodel/adaptor.go b/new-api/relay/channel/submodel/adaptor.go
deleted file mode 100644
index cff8bc08f06d6d9d33f98bb206ab773aa06ec9b5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/submodel/adaptor.go
+++ /dev/null
@@ -1,86 +0,0 @@
-package submodel
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- return nil, errors.New("submodel channel: endpoint not supported")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = openai.OaiStreamHandler(c, info, resp)
- } else {
- usage, err = openai.OpenaiHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/submodel/constants.go b/new-api/relay/channel/submodel/constants.go
deleted file mode 100644
index c59bcca3ea8efd3ef300bcc484e451668c9e3791..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/submodel/constants.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package submodel
-
-var ModelList = []string{
- "NousResearch/Hermes-4-405B-FP8",
- "Qwen/Qwen3-235B-A22B-Thinking-2507",
- "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8",
- "Qwen/Qwen3-235B-A22B-Instruct-2507",
- "zai-org/GLM-4.5-FP8",
- "openai/gpt-oss-120b",
- "deepseek-ai/DeepSeek-R1-0528",
- "deepseek-ai/DeepSeek-R1",
- "deepseek-ai/DeepSeek-V3-0324",
- "deepseek-ai/DeepSeek-V3.1",
-}
-
-const ChannelName = "submodel"
\ No newline at end of file
diff --git a/new-api/relay/channel/task/jimeng/adaptor.go b/new-api/relay/channel/task/jimeng/adaptor.go
deleted file mode 100644
index c7f648329daf86b65adda273ab6db027d54c5780..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/jimeng/adaptor.go
+++ /dev/null
@@ -1,404 +0,0 @@
-package jimeng
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "one-api/model"
- "sort"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
-
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/service"
-)
-
-// ============================
-// Request / Response structures
-// ============================
-
-type requestPayload struct {
- ReqKey string `json:"req_key"`
- BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
- ImageUrls []string `json:"image_urls,omitempty"`
- Prompt string `json:"prompt,omitempty"`
- Seed int64 `json:"seed"`
- AspectRatio string `json:"aspect_ratio"`
- Frames int `json:"frames,omitempty"`
-}
-
-type responsePayload struct {
- Code int `json:"code"`
- Message string `json:"message"`
- RequestId string `json:"request_id"`
- Data struct {
- TaskID string `json:"task_id"`
- } `json:"data"`
-}
-
-type responseTask struct {
- Code int `json:"code"`
- Data struct {
- BinaryDataBase64 []interface{} `json:"binary_data_base64"`
- ImageUrls interface{} `json:"image_urls"`
- RespData string `json:"resp_data"`
- Status string `json:"status"`
- VideoUrl string `json:"video_url"`
- } `json:"data"`
- Message string `json:"message"`
- RequestId string `json:"request_id"`
- Status int `json:"status"`
- TimeElapsed string `json:"time_elapsed"`
-}
-
-// ============================
-// Adaptor implementation
-// ============================
-
-type TaskAdaptor struct {
- ChannelType int
- accessKey string
- secretKey string
- baseURL string
-}
-
-func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
-
- // apiKey format: "access_key|secret_key"
- keyParts := strings.Split(info.ApiKey, "|")
- if len(keyParts) == 2 {
- a.accessKey = strings.TrimSpace(keyParts[0])
- a.secretKey = strings.TrimSpace(keyParts[1])
- }
-}
-
-// ValidateRequestAndSetAction parses body, validates fields and sets default action.
-func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // Accept only POST /v1/video/generations as "generate" action.
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
-}
-
-// BuildRequestURL constructs the upstream URL.
-func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if isNewAPIRelay(info.ApiKey) {
- return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
- }
- return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
-}
-
-// BuildRequestHeader sets required headers.
-func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- if isNewAPIRelay(info.ApiKey) {
- req.Header.Set("Authorization", "Bearer "+info.ApiKey)
- } else {
- return a.signRequest(req, a.accessKey, a.secretKey)
- }
- return nil
-}
-
-// BuildRequestBody converts request into Jimeng specific format.
-func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
- }
- req := v.(relaycommon.TaskSubmitReq)
-
- body, err := a.convertToRequestPayload(&req)
- if err != nil {
- return nil, errors.Wrap(err, "convert request payload failed")
- }
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-// DoRequest delegates to common helper.
-func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- return channel.DoTaskApiRequest(a, c, info, requestBody)
-}
-
-// DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- return
- }
- _ = resp.Body.Close()
-
- // Parse Jimeng response
- var jResp responsePayload
- if err := json.Unmarshal(responseBody, &jResp); err != nil {
- taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
- return
- }
-
- if jResp.Code != 10000 {
- taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
- return
- }
-
- c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
- return jResp.Data.TaskID, responseBody, nil
-}
-
-// FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
- taskID, ok := body["task_id"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid task_id")
- }
-
- uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
- if isNewAPIRelay(key) {
- uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
- }
- payload := map[string]string{
- "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
- "task_id": taskID,
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- return nil, errors.Wrap(err, "marshal fetch task payload failed")
- }
-
- req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
- if err != nil {
- return nil, err
- }
-
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Content-Type", "application/json")
-
- if isNewAPIRelay(key) {
- req.Header.Set("Authorization", "Bearer "+key)
- } else {
- keyParts := strings.Split(key, "|")
- if len(keyParts) != 2 {
- return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- secretKey := strings.TrimSpace(keyParts[1])
-
- if err := a.signRequest(req, accessKey, secretKey); err != nil {
- return nil, errors.Wrap(err, "sign request failed")
- }
- }
- return service.GetHttpClient().Do(req)
-}
-
-func (a *TaskAdaptor) GetModelList() []string {
- return []string{"jimeng_vgfm_t2v_l20"}
-}
-
-func (a *TaskAdaptor) GetChannelName() string {
- return "jimeng"
-}
-
-func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
- var bodyBytes []byte
- var err error
-
- if req.Body != nil {
- bodyBytes, err = io.ReadAll(req.Body)
- if err != nil {
- return errors.Wrap(err, "read request body failed")
- }
- _ = req.Body.Close()
- req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
- } else {
- bodyBytes = []byte{}
- }
-
- payloadHash := sha256.Sum256(bodyBytes)
- hexPayloadHash := hex.EncodeToString(payloadHash[:])
-
- t := time.Now().UTC()
- xDate := t.Format("20060102T150405Z")
- shortDate := t.Format("20060102")
-
- req.Header.Set("Host", req.URL.Host)
- req.Header.Set("X-Date", xDate)
- req.Header.Set("X-Content-Sha256", hexPayloadHash)
-
- // Sort and encode query parameters to create canonical query string
- queryParams := req.URL.Query()
- sortedKeys := make([]string, 0, len(queryParams))
- for k := range queryParams {
- sortedKeys = append(sortedKeys, k)
- }
- sort.Strings(sortedKeys)
- var queryParts []string
- for _, k := range sortedKeys {
- values := queryParams[k]
- sort.Strings(values)
- for _, v := range values {
- queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
- }
- }
- canonicalQueryString := strings.Join(queryParts, "&")
-
- headersToSign := map[string]string{
- "host": req.URL.Host,
- "x-date": xDate,
- "x-content-sha256": hexPayloadHash,
- }
- if req.Header.Get("Content-Type") != "" {
- headersToSign["content-type"] = req.Header.Get("Content-Type")
- }
-
- var signedHeaderKeys []string
- for k := range headersToSign {
- signedHeaderKeys = append(signedHeaderKeys, k)
- }
- sort.Strings(signedHeaderKeys)
-
- var canonicalHeaders strings.Builder
- for _, k := range signedHeaderKeys {
- canonicalHeaders.WriteString(k)
- canonicalHeaders.WriteString(":")
- canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
- canonicalHeaders.WriteString("\n")
- }
- signedHeaders := strings.Join(signedHeaderKeys, ";")
-
- canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
- req.Method,
- req.URL.Path,
- canonicalQueryString,
- canonicalHeaders.String(),
- signedHeaders,
- hexPayloadHash,
- )
-
- hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
- hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
-
- region := "cn-north-1"
- serviceName := "cv"
- credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
- stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
- xDate,
- credentialScope,
- hexHashedCanonicalRequest,
- )
-
- kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
- kRegion := hmacSHA256(kDate, []byte(region))
- kService := hmacSHA256(kRegion, []byte(serviceName))
- kSigning := hmacSHA256(kService, []byte("request"))
- signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
-
- authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- accessKey,
- credentialScope,
- signedHeaders,
- signature,
- )
- req.Header.Set("Authorization", authorization)
- return nil
-}
-
-func hmacSHA256(key []byte, data []byte) []byte {
- h := hmac.New(sha256.New, key)
- h.Write(data)
- return h.Sum(nil)
-}
-
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- r := requestPayload{
- ReqKey: req.Model,
- Prompt: req.Prompt,
- }
-
- switch req.Duration {
- case 10:
- r.Frames = 241 // 24*10+1 = 241
- default:
- r.Frames = 121 // 24*5+1 = 121
- }
-
- // Handle one-of image_urls or binary_data_base64
- if req.HasImage() {
- if strings.HasPrefix(req.Images[0], "http") {
- r.ImageUrls = req.Images
- } else {
- r.BinaryDataBase64 = req.Images
- }
- }
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
- return nil, errors.Wrap(err, "unmarshal metadata failed")
- }
-
- // 即梦视频3.0 ReqKey转换
- // https://www.volcengine.com/docs/85621/1792707
- if strings.Contains(r.ReqKey, "jimeng_v30") {
- if len(r.ImageUrls) > 1 {
- // 多张图片:首尾帧生成
- r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
- } else if len(r.ImageUrls) == 1 {
- // 单张图片:图生视频
- r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
- } else {
- // 无图片:文生视频
- r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
- }
- }
-
- return &r, nil
-}
-
-func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- resTask := responseTask{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
- return nil, errors.Wrap(err, "unmarshal task result failed")
- }
- taskResult := relaycommon.TaskInfo{}
- if resTask.Code == 10000 {
- taskResult.Code = 0
- } else {
- taskResult.Code = resTask.Code // todo uni code
- taskResult.Reason = resTask.Message
- taskResult.Status = model.TaskStatusFailure
- taskResult.Progress = "100%"
- }
- switch resTask.Data.Status {
- case "in_queue":
- taskResult.Status = model.TaskStatusQueued
- taskResult.Progress = "10%"
- case "done":
- taskResult.Status = model.TaskStatusSuccess
- taskResult.Progress = "100%"
- }
- taskResult.Url = resTask.Data.VideoUrl
- return &taskResult, nil
-}
-
-func isNewAPIRelay(apiKey string) bool {
- return strings.HasPrefix(apiKey, "sk-")
-}
diff --git a/new-api/relay/channel/task/kling/adaptor.go b/new-api/relay/channel/task/kling/adaptor.go
deleted file mode 100644
index 46e12f6d1a62b18930260c7ec9e9d0eb474c8fc6..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/kling/adaptor.go
+++ /dev/null
@@ -1,371 +0,0 @@
-package kling
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/model"
- "strings"
- "time"
-
- "github.com/samber/lo"
-
- "github.com/gin-gonic/gin"
- "github.com/golang-jwt/jwt"
- "github.com/pkg/errors"
-
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/service"
-)
-
-// ============================
-// Request / Response structures
-// ============================
-
-type TrajectoryPoint struct {
- X int `json:"x"`
- Y int `json:"y"`
-}
-
-type DynamicMask struct {
- Mask string `json:"mask,omitempty"`
- Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
-}
-
-type CameraConfig struct {
- Horizontal float64 `json:"horizontal,omitempty"`
- Vertical float64 `json:"vertical,omitempty"`
- Pan float64 `json:"pan,omitempty"`
- Tilt float64 `json:"tilt,omitempty"`
- Roll float64 `json:"roll,omitempty"`
- Zoom float64 `json:"zoom,omitempty"`
-}
-
-type CameraControl struct {
- Type string `json:"type,omitempty"`
- Config *CameraConfig `json:"config,omitempty"`
-}
-
-type requestPayload struct {
- Prompt string `json:"prompt,omitempty"`
- Image string `json:"image,omitempty"`
- ImageTail string `json:"image_tail,omitempty"`
- NegativePrompt string `json:"negative_prompt,omitempty"`
- Mode string `json:"mode,omitempty"`
- Duration string `json:"duration,omitempty"`
- AspectRatio string `json:"aspect_ratio,omitempty"`
- ModelName string `json:"model_name,omitempty"`
- Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
- CfgScale float64 `json:"cfg_scale,omitempty"`
- StaticMask string `json:"static_mask,omitempty"`
- DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
- CameraControl *CameraControl `json:"camera_control,omitempty"`
- CallbackUrl string `json:"callback_url,omitempty"`
- ExternalTaskId string `json:"external_task_id,omitempty"`
-}
-
-type responsePayload struct {
- Code int `json:"code"`
- Message string `json:"message"`
- TaskId string `json:"task_id"`
- RequestId string `json:"request_id"`
- Data struct {
- TaskId string `json:"task_id"`
- TaskStatus string `json:"task_status"`
- TaskStatusMsg string `json:"task_status_msg"`
- TaskResult struct {
- Videos []struct {
- Id string `json:"id"`
- Url string `json:"url"`
- Duration string `json:"duration"`
- } `json:"videos"`
- } `json:"task_result"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
- } `json:"data"`
-}
-
-// ============================
-// Adaptor implementation
-// ============================
-
-type TaskAdaptor struct {
- ChannelType int
- apiKey string
- baseURL string
-}
-
-func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
- a.apiKey = info.ApiKey
-
- // apiKey format: "access_key|secret_key"
-}
-
-// ValidateRequestAndSetAction parses body, validates fields and sets default action.
-func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // Use the standard validation method for TaskSubmitReq
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
-}
-
-// BuildRequestURL constructs the upstream URL.
-func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
-
- if isNewAPIRelay(info.ApiKey) {
- return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
- }
-
- return fmt.Sprintf("%s%s", a.baseURL, path), nil
-}
-
-// BuildRequestHeader sets required headers.
-func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- token, err := a.createJWTToken()
- if err != nil {
- return fmt.Errorf("failed to create JWT token: %w", err)
- }
-
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Bearer "+token)
- req.Header.Set("User-Agent", "kling-sdk/1.0")
- return nil
-}
-
-// BuildRequestBody converts request into Kling specific format.
-func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
- }
- req := v.(relaycommon.TaskSubmitReq)
-
- body, err := a.convertToRequestPayload(&req)
- if err != nil {
- return nil, err
- }
- if body.Image == "" && body.ImageTail == "" {
- c.Set("action", constant.TaskActionTextGenerate)
- }
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-// DoRequest delegates to common helper.
-func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- if action := c.GetString("action"); action != "" {
- info.Action = action
- }
- return channel.DoTaskApiRequest(a, c, info, requestBody)
-}
-
-// DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- return
- }
-
- var kResp responsePayload
- err = json.Unmarshal(responseBody, &kResp)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
- return
- }
- if kResp.Code != 0 {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
- return
- }
- kResp.TaskId = kResp.Data.TaskId
- c.JSON(http.StatusOK, kResp)
- return kResp.Data.TaskId, responseBody, nil
-}
-
-// FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
- taskID, ok := body["task_id"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid task_id")
- }
- action, ok := body["action"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid action")
- }
- path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
- url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
- if isNewAPIRelay(key) {
- url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
- }
-
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return nil, err
- }
-
- token, err := a.createJWTTokenWithKey(key)
- if err != nil {
- token = key
- }
-
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Bearer "+token)
- req.Header.Set("User-Agent", "kling-sdk/1.0")
-
- return service.GetHttpClient().Do(req)
-}
-
-func (a *TaskAdaptor) GetModelList() []string {
- return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
-}
-
-func (a *TaskAdaptor) GetChannelName() string {
- return "kling"
-}
-
-// ============================
-// helpers
-// ============================
-
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- r := requestPayload{
- Prompt: req.Prompt,
- Image: req.Image,
- Mode: defaultString(req.Mode, "std"),
- Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
- AspectRatio: a.getAspectRatio(req.Size),
- ModelName: req.Model,
- Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
- CfgScale: 0.5,
- StaticMask: "",
- DynamicMasks: []DynamicMask{},
- CameraControl: nil,
- CallbackUrl: "",
- ExternalTaskId: "",
- }
- if r.ModelName == "" {
- r.ModelName = "kling-v1"
- }
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
- return nil, errors.Wrap(err, "unmarshal metadata failed")
- }
- return &r, nil
-}
-
-func (a *TaskAdaptor) getAspectRatio(size string) string {
- switch size {
- case "1024x1024", "512x512":
- return "1:1"
- case "1280x720", "1920x1080":
- return "16:9"
- case "720x1280", "1080x1920":
- return "9:16"
- default:
- return "1:1"
- }
-}
-
-func defaultString(s, def string) string {
- if strings.TrimSpace(s) == "" {
- return def
- }
- return s
-}
-
-func defaultInt(v int, def int) int {
- if v == 0 {
- return def
- }
- return v
-}
-
-// ============================
-// JWT helpers
-// ============================
-
-func (a *TaskAdaptor) createJWTToken() (string, error) {
- return a.createJWTTokenWithKey(a.apiKey)
-}
-
-//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
-// parts := strings.Split(apiKey, "|")
-// if len(parts) != 2 {
-// return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
-// }
-// return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
-//}
-
-func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
- if isNewAPIRelay(apiKey) {
- return apiKey, nil // new api relay
- }
- keyParts := strings.Split(apiKey, "|")
- if len(keyParts) != 2 {
- return "", errors.New("invalid api_key, required format is accessKey|secretKey")
- }
- accessKey := strings.TrimSpace(keyParts[0])
- if len(keyParts) == 1 {
- return accessKey, nil
- }
- secretKey := strings.TrimSpace(keyParts[1])
- now := time.Now().Unix()
- claims := jwt.MapClaims{
- "iss": accessKey,
- "exp": now + 1800, // 30 minutes
- "nbf": now - 5,
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- token.Header["typ"] = "JWT"
- return token.SignedString([]byte(secretKey))
-}
-
-func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- taskInfo := &relaycommon.TaskInfo{}
- resPayload := responsePayload{}
- err := json.Unmarshal(respBody, &resPayload)
- if err != nil {
- return nil, errors.Wrap(err, "failed to unmarshal response body")
- }
- taskInfo.Code = resPayload.Code
- taskInfo.TaskID = resPayload.Data.TaskId
- taskInfo.Reason = resPayload.Message
- //任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
- status := resPayload.Data.TaskStatus
- switch status {
- case "submitted":
- taskInfo.Status = model.TaskStatusSubmitted
- case "processing":
- taskInfo.Status = model.TaskStatusInProgress
- case "succeed":
- taskInfo.Status = model.TaskStatusSuccess
- case "failed":
- taskInfo.Status = model.TaskStatusFailure
- default:
- return nil, fmt.Errorf("unknown task status: %s", status)
- }
- if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
- video := videos[0]
- taskInfo.Url = video.Url
- }
- return taskInfo, nil
-}
-
-func isNewAPIRelay(apiKey string) bool {
- return strings.HasPrefix(apiKey, "sk-")
-}
diff --git a/new-api/relay/channel/task/suno/adaptor.go b/new-api/relay/channel/task/suno/adaptor.go
deleted file mode 100644
index 5f49fd3e766e5d79ee553ef26834c6b63da40ec7..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/suno/adaptor.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package suno
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-type TaskAdaptor struct {
- ChannelType int
-}
-
-func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
- return nil, fmt.Errorf("not implement") // todo implement this method if needed
-}
-
-func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
-}
-
-func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- action := strings.ToUpper(c.Param("action"))
-
- var sunoRequest *dto.SunoSubmitReq
- err := common.UnmarshalBodyReusable(c, &sunoRequest)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
- return
- }
- err = actionValidate(c, sunoRequest, action)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
- return
- }
-
- if sunoRequest.ContinueClipId != "" {
- if sunoRequest.TaskID == "" {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
- return
- }
- info.OriginTaskID = sunoRequest.TaskID
- }
-
- info.Action = action
- c.Set("task_request", sunoRequest)
- return nil
-}
-
-func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- baseURL := info.ChannelBaseUrl
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
- return fullRequestURL, nil
-}
-
-func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- req.Header.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- sunoRequest, ok := c.Get("task_request")
- if !ok {
- err := common.UnmarshalBodyReusable(c, &sunoRequest)
- if err != nil {
- return nil, err
- }
- }
- data, err := json.Marshal(sunoRequest)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- return channel.DoTaskApiRequest(a, c, info, requestBody)
-}
-
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- return
- }
- var sunoResponse dto.TaskResponse[string]
- err = json.Unmarshal(responseBody, &sunoResponse)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- return
- }
- if !sunoResponse.IsSuccess() {
- taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
- return
- }
-
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
-
- _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
- }
-
- return sunoResponse.Data, nil, nil
-}
-
-func (a *TaskAdaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *TaskAdaptor) GetChannelName() string {
- return ChannelName
-}
-
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
- requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
- byteBody, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
-
- req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
- if err != nil {
- common.SysLog(fmt.Sprintf("Get Task error: %v", err))
- return nil, err
- }
- defer req.Body.Close()
- // 设置超时时间
- timeout := time.Second * 15
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+key)
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- return nil, err
- }
- return resp, nil
-}
-
-func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
- switch action {
- case constant.SunoActionMusic:
- if sunoRequest.Mv == "" {
- sunoRequest.Mv = "chirp-v3-0"
- }
- case constant.SunoActionLyrics:
- if sunoRequest.Prompt == "" {
- err = fmt.Errorf("prompt_empty")
- return
- }
- default:
- err = fmt.Errorf("invalid_action")
- }
- return
-}
diff --git a/new-api/relay/channel/task/suno/models.go b/new-api/relay/channel/task/suno/models.go
deleted file mode 100644
index 08dbc9caa6ac18b94f2e7d3daabc5317a3584e1c..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/suno/models.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package suno
-
-var ModelList = []string{
- "suno_music", "suno_lyrics",
-}
-
-var ChannelName = "suno"
diff --git a/new-api/relay/channel/task/vertex/adaptor.go b/new-api/relay/channel/task/vertex/adaptor.go
deleted file mode 100644
index 7326b44c6a13f14c385a6fd0bf2d75e9eedb0b79..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/vertex/adaptor.go
+++ /dev/null
@@ -1,355 +0,0 @@
-package vertex
-
-import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/model"
- "regexp"
- "strings"
-
- "github.com/gin-gonic/gin"
-
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- vertexcore "one-api/relay/channel/vertex"
- relaycommon "one-api/relay/common"
- "one-api/service"
-)
-
-// ============================
-// Request / Response structures
-// ============================
-
-type requestPayload struct {
- Instances []map[string]any `json:"instances"`
- Parameters map[string]any `json:"parameters,omitempty"`
-}
-
-type submitResponse struct {
- Name string `json:"name"`
-}
-
-type operationVideo struct {
- MimeType string `json:"mimeType"`
- BytesBase64Encoded string `json:"bytesBase64Encoded"`
- Encoding string `json:"encoding"`
-}
-
-type operationResponse struct {
- Name string `json:"name"`
- Done bool `json:"done"`
- Response struct {
- Type string `json:"@type"`
- RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
- Videos []operationVideo `json:"videos"`
- BytesBase64Encoded string `json:"bytesBase64Encoded"`
- Encoding string `json:"encoding"`
- Video string `json:"video"`
- } `json:"response"`
- Error struct {
- Message string `json:"message"`
- } `json:"error"`
-}
-
-// ============================
-// Adaptor implementation
-// ============================
-
-type TaskAdaptor struct {
- ChannelType int
- apiKey string
- baseURL string
-}
-
-func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
- a.apiKey = info.ApiKey
-}
-
-// ValidateRequestAndSetAction parses body, validates fields and sets default action.
-func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // Use the standard validation method for TaskSubmitReq
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
-}
-
-// BuildRequestURL constructs the upstream URL.
-func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
- return "", fmt.Errorf("failed to decode credentials: %w", err)
- }
- modelName := info.OriginModelName
- if modelName == "" {
- modelName = "veo-3.0-generate-001"
- }
-
- region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
- if strings.TrimSpace(region) == "" {
- region = "global"
- }
- if region == "global" {
- return fmt.Sprintf(
- "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
- adc.ProjectID,
- modelName,
- ), nil
- }
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
- region,
- adc.ProjectID,
- region,
- modelName,
- ), nil
-}
-
-// BuildRequestHeader sets required headers.
-func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
-
- adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
- return fmt.Errorf("failed to decode credentials: %w", err)
- }
-
- token, err := vertexcore.AcquireAccessToken(*adc, "")
- if err != nil {
- return fmt.Errorf("failed to acquire access token: %w", err)
- }
- req.Header.Set("Authorization", "Bearer "+token)
- req.Header.Set("x-goog-user-project", adc.ProjectID)
- return nil
-}
-
-// BuildRequestBody converts request into Vertex specific format.
-func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, ok := c.Get("task_request")
- if !ok {
- return nil, fmt.Errorf("request not found in context")
- }
- req := v.(relaycommon.TaskSubmitReq)
-
- body := requestPayload{
- Instances: []map[string]any{{"prompt": req.Prompt}},
- Parameters: map[string]any{},
- }
- if req.Metadata != nil {
- if v, ok := req.Metadata["storageUri"]; ok {
- body.Parameters["storageUri"] = v
- }
- if v, ok := req.Metadata["sampleCount"]; ok {
- body.Parameters["sampleCount"] = v
- }
- }
- if _, ok := body.Parameters["sampleCount"]; !ok {
- body.Parameters["sampleCount"] = 1
- }
-
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-// DoRequest delegates to common helper.
-func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- return channel.DoTaskApiRequest(a, c, info, requestBody)
-}
-
-// DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- _ = resp.Body.Close()
-
- var s submitResponse
- if err := json.Unmarshal(responseBody, &s); err != nil {
- return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
- }
- if strings.TrimSpace(s.Name) == "" {
- return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
- }
- localID := encodeLocalTaskID(s.Name)
- c.JSON(http.StatusOK, gin.H{"task_id": localID})
- return localID, responseBody, nil
-}
-
-func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
-func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
-
-// FetchTask fetch task status
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
- taskID, ok := body["task_id"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid task_id")
- }
- upstreamName, err := decodeLocalTaskID(taskID)
- if err != nil {
- return nil, fmt.Errorf("decode task_id failed: %w", err)
- }
- region := extractRegionFromOperationName(upstreamName)
- if region == "" {
- region = "us-central1"
- }
- project := extractProjectFromOperationName(upstreamName)
- modelName := extractModelFromOperationName(upstreamName)
- if project == "" || modelName == "" {
- return nil, fmt.Errorf("cannot extract project/model from operation name")
- }
- var url string
- if region == "global" {
- url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
- } else {
- url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
- }
- payload := map[string]string{"operationName": upstreamName}
- data, err := json.Marshal(payload)
- if err != nil {
- return nil, err
- }
- adc := &vertexcore.Credentials{}
- if err := json.Unmarshal([]byte(key), adc); err != nil {
- return nil, fmt.Errorf("failed to decode credentials: %w", err)
- }
- token, err := vertexcore.AcquireAccessToken(*adc, "")
- if err != nil {
- return nil, fmt.Errorf("failed to acquire access token: %w", err)
- }
- req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Bearer "+token)
- req.Header.Set("x-goog-user-project", adc.ProjectID)
- return service.GetHttpClient().Do(req)
-}
-
-func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- var op operationResponse
- if err := json.Unmarshal(respBody, &op); err != nil {
- return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
- }
- ti := &relaycommon.TaskInfo{}
- if op.Error.Message != "" {
- ti.Status = model.TaskStatusFailure
- ti.Reason = op.Error.Message
- ti.Progress = "100%"
- return ti, nil
- }
- if !op.Done {
- ti.Status = model.TaskStatusInProgress
- ti.Progress = "50%"
- return ti, nil
- }
- ti.Status = model.TaskStatusSuccess
- ti.Progress = "100%"
- if len(op.Response.Videos) > 0 {
- v0 := op.Response.Videos[0]
- if v0.BytesBase64Encoded != "" {
- mime := strings.TrimSpace(v0.MimeType)
- if mime == "" {
- enc := strings.TrimSpace(v0.Encoding)
- if enc == "" {
- enc = "mp4"
- }
- if strings.Contains(enc, "/") {
- mime = enc
- } else {
- mime = "video/" + enc
- }
- }
- ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
- return ti, nil
- }
- }
- if op.Response.BytesBase64Encoded != "" {
- enc := strings.TrimSpace(op.Response.Encoding)
- if enc == "" {
- enc = "mp4"
- }
- mime := enc
- if !strings.Contains(enc, "/") {
- mime = "video/" + enc
- }
- ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
- return ti, nil
- }
- if op.Response.Video != "" { // some variants use `video` as base64
- enc := strings.TrimSpace(op.Response.Encoding)
- if enc == "" {
- enc = "mp4"
- }
- mime := enc
- if !strings.Contains(enc, "/") {
- mime = "video/" + enc
- }
- ti.Url = "data:" + mime + ";base64," + op.Response.Video
- return ti, nil
- }
- return ti, nil
-}
-
-// ============================
-// helpers
-// ============================
-
-func encodeLocalTaskID(name string) string {
- return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
- b, err := base64.RawURLEncoding.DecodeString(local)
- if err != nil {
- return "", err
- }
- return string(b), nil
-}
-
-var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
-
-func extractRegionFromOperationName(name string) string {
- m := regionRe.FindStringSubmatch(name)
- if len(m) == 2 {
- return m[1]
- }
- return ""
-}
-
-var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
-
-func extractModelFromOperationName(name string) string {
- m := modelRe.FindStringSubmatch(name)
- if len(m) == 2 {
- return m[1]
- }
- idx := strings.Index(name, "models/")
- if idx >= 0 {
- s := name[idx+len("models/"):]
- if p := strings.Index(s, "/operations/"); p > 0 {
- return s[:p]
- }
- }
- return ""
-}
-
-var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
-
-func extractProjectFromOperationName(name string) string {
- m := projectRe.FindStringSubmatch(name)
- if len(m) == 2 {
- return m[1]
- }
- return ""
-}
diff --git a/new-api/relay/channel/task/vidu/adaptor.go b/new-api/relay/channel/task/vidu/adaptor.go
deleted file mode 100644
index c25221f1b798cb94636cfbdf7084964d8cf8bf2a..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/task/vidu/adaptor.go
+++ /dev/null
@@ -1,258 +0,0 @@
-package vidu
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
-
- "github.com/gin-gonic/gin"
-
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/service"
-
- "github.com/pkg/errors"
-)
-
-// ============================
-// Request / Response structures
-// ============================
-
-type requestPayload struct {
- Model string `json:"model"`
- Images []string `json:"images"`
- Prompt string `json:"prompt,omitempty"`
- Duration int `json:"duration,omitempty"`
- Seed int `json:"seed,omitempty"`
- Resolution string `json:"resolution,omitempty"`
- MovementAmplitude string `json:"movement_amplitude,omitempty"`
- Bgm bool `json:"bgm,omitempty"`
- Payload string `json:"payload,omitempty"`
- CallbackUrl string `json:"callback_url,omitempty"`
-}
-
-type responsePayload struct {
- TaskId string `json:"task_id"`
- State string `json:"state"`
- Model string `json:"model"`
- Images []string `json:"images"`
- Prompt string `json:"prompt"`
- Duration int `json:"duration"`
- Seed int `json:"seed"`
- Resolution string `json:"resolution"`
- Bgm bool `json:"bgm"`
- MovementAmplitude string `json:"movement_amplitude"`
- Payload string `json:"payload"`
- CreatedAt string `json:"created_at"`
-}
-
-type taskResultResponse struct {
- State string `json:"state"`
- ErrCode string `json:"err_code"`
- Credits int `json:"credits"`
- Payload string `json:"payload"`
- Creations []creation `json:"creations"`
-}
-
-type creation struct {
- ID string `json:"id"`
- URL string `json:"url"`
- CoverURL string `json:"cover_url"`
-}
-
-// ============================
-// Adaptor implementation
-// ============================
-
-type TaskAdaptor struct {
- ChannelType int
- baseURL string
-}
-
-func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
-}
-
-func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
-}
-
-func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
- }
- req := v.(relaycommon.TaskSubmitReq)
-
- body, err := a.convertToRequestPayload(&req)
- if err != nil {
- return nil, err
- }
-
- if len(body.Images) == 0 {
- c.Set("action", constant.TaskActionTextGenerate)
- }
-
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
-}
-
-func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- var path string
- switch info.Action {
- case constant.TaskActionGenerate:
- path = "/img2video"
- case constant.TaskActionFirstTailGenerate:
- path = "/start-end2video"
- case constant.TaskActionReferenceGenerate:
- path = "/reference2video"
- default:
- path = "/text2video"
- }
- return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
-}
-
-func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Token "+info.ApiKey)
- return nil
-}
-
-func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
- if action := c.GetString("action"); action != "" {
- info.Action = action
- }
- return channel.DoTaskApiRequest(a, c, info, requestBody)
-}
-
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- return
- }
-
- var vResp responsePayload
- err = json.Unmarshal(responseBody, &vResp)
- if err != nil {
- taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
- return
- }
-
- if vResp.State == "failed" {
- taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
- return
- }
-
- c.JSON(http.StatusOK, vResp)
- return vResp.TaskId, responseBody, nil
-}
-
-func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
- taskID, ok := body["task_id"].(string)
- if !ok {
- return nil, fmt.Errorf("invalid task_id")
- }
-
- url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
-
- req, err := http.NewRequest(http.MethodGet, url, nil)
- if err != nil {
- return nil, err
- }
-
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Authorization", "Token "+key)
-
- return service.GetHttpClient().Do(req)
-}
-
-func (a *TaskAdaptor) GetModelList() []string {
- return []string{"viduq1", "vidu2.0", "vidu1.5"}
-}
-
-func (a *TaskAdaptor) GetChannelName() string {
- return "vidu"
-}
-
-// ============================
-// helpers
-// ============================
-
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- r := requestPayload{
- Model: defaultString(req.Model, "viduq1"),
- Images: req.Images,
- Prompt: req.Prompt,
- Duration: defaultInt(req.Duration, 5),
- Resolution: defaultString(req.Size, "1080p"),
- MovementAmplitude: "auto",
- Bgm: false,
- }
- metadata := req.Metadata
- medaBytes, err := json.Marshal(metadata)
- if err != nil {
- return nil, errors.Wrap(err, "metadata marshal metadata failed")
- }
- err = json.Unmarshal(medaBytes, &r)
- if err != nil {
- return nil, errors.Wrap(err, "unmarshal metadata failed")
- }
- return &r, nil
-}
-
-func defaultString(value, defaultValue string) string {
- if value == "" {
- return defaultValue
- }
- return value
-}
-
-func defaultInt(value, defaultValue int) int {
- if value == 0 {
- return defaultValue
- }
- return value
-}
-
-func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- taskInfo := &relaycommon.TaskInfo{}
-
- var taskResp taskResultResponse
- err := json.Unmarshal(respBody, &taskResp)
- if err != nil {
- return nil, errors.Wrap(err, "failed to unmarshal response body")
- }
-
- state := taskResp.State
- switch state {
- case "created", "queueing":
- taskInfo.Status = model.TaskStatusSubmitted
- case "processing":
- taskInfo.Status = model.TaskStatusInProgress
- case "success":
- taskInfo.Status = model.TaskStatusSuccess
- if len(taskResp.Creations) > 0 {
- taskInfo.Url = taskResp.Creations[0].URL
- }
- case "failed":
- taskInfo.Status = model.TaskStatusFailure
- if taskResp.ErrCode != "" {
- taskInfo.Reason = taskResp.ErrCode
- }
- default:
- return nil, fmt.Errorf("unknown task state: %s", state)
- }
-
- return taskInfo, nil
-}
diff --git a/new-api/relay/channel/tencent/adaptor.go b/new-api/relay/channel/tencent/adaptor.go
deleted file mode 100644
index a4c630a5bdc32ed20600aafd995dcac298e0d73e..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/tencent/adaptor.go
+++ /dev/null
@@ -1,118 +0,0 @@
-package tencent
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/types"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
- Sign string
- AppID int64
- Action string
- Version string
- Timestamp int64
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- a.Action = "ChatCompletions"
- a.Version = "2023-09-01"
- a.Timestamp = common.GetTimestamp()
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", a.Sign)
- req.Set("X-TC-Action", a.Action)
- req.Set("X-TC-Version", a.Version)
- req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
- appId, secretId, secretKey, err := parseTencentConfig(apiKey)
- a.AppID = appId
- if err != nil {
- return nil, err
- }
- tencentRequest := requestOpenAI2Tencent(a, *request)
- // we have to calculate the sign here
- a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
- return tencentRequest, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = tencentStreamHandler(c, info, resp)
- } else {
- usage, err = tencentHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/tencent/constants.go b/new-api/relay/channel/tencent/constants.go
deleted file mode 100644
index 4fccfadf07f258f2fbea8d2bd15108daab036fcf..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/tencent/constants.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package tencent
-
-var ModelList = []string{
- "hunyuan-lite",
- "hunyuan-standard",
- "hunyuan-standard-256K",
- "hunyuan-pro",
-}
-
-var ChannelName = "tencent"
diff --git a/new-api/relay/channel/tencent/dto.go b/new-api/relay/channel/tencent/dto.go
deleted file mode 100644
index c50a2a3eba9c400b23ef62bfb5cfae3ce09f2634..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/tencent/dto.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package tencent
-
-type TencentMessage struct {
- Role string `json:"Role"`
- Content string `json:"Content"`
-}
-
-type TencentChatRequest struct {
- // 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
- // 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
- //
- // 注意:
- // 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
- Model *string `json:"Model"`
- // 聊天上下文信息。
- // 说明:
- // 1. 长度最多为 40,按对话时间从旧到新在数组中排列。
- // 2. Message.Role 可选值:system、user、assistant。
- // 其中,system 角色可选,如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system(可选) user assistant user assistant user ...]。
- // 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
- Messages []*TencentMessage `json:"Messages"`
- // 流式调用开关。
- // 说明:
- // 1. 未传值时默认为非流式调用(false)。
- // 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
- // 3. 非流式调用时:
- // 调用方式与普通 HTTP 请求无异。
- // 接口响应耗时较长,**如需更低时延建议设置为 true**。
- // 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
- //
- // 注意:
- // 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
- Stream *bool `json:"Stream,omitempty"`
- // 说明:
- // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
- // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
- // 3. 非必要不建议使用,不合理的取值会影响效果。
- TopP *float64 `json:"TopP,omitempty"`
- // 说明:
- // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
- // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
- // 3. 非必要不建议使用,不合理的取值会影响效果。
- Temperature *float64 `json:"Temperature,omitempty"`
-}
-
-type TencentError struct {
- Code int `json:"Code"`
- Message string `json:"Message"`
-}
-
-type TencentUsage struct {
- PromptTokens int `json:"PromptTokens"`
- CompletionTokens int `json:"CompletionTokens"`
- TotalTokens int `json:"TotalTokens"`
-}
-
-type TencentResponseChoices struct {
- FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
- Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
- Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
-}
-
-type TencentChatResponse struct {
- Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
- Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
- Id string `json:"Id,omitempty"` // 会话 id
- Usage TencentUsage `json:"Usage,omitempty"` // token 数量
- Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
- Note string `json:"Note,omitempty"` // 注释
- ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
-}
-
-type TencentChatResponseSB struct {
- Response TencentChatResponse `json:"Response,omitempty"`
-}
diff --git a/new-api/relay/channel/tencent/relay-tencent.go b/new-api/relay/channel/tencent/relay-tencent.go
deleted file mode 100644
index 784e7b5c0e237b7d7c41266fb4ce487bd2b5ee19..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/tencent/relay-tencent.go
+++ /dev/null
@@ -1,233 +0,0 @@
-package tencent
-
-import (
- "bufio"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-// https://cloud.tencent.com/document/product/1729/97732
-
-func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
- messages := make([]*TencentMessage, 0, len(request.Messages))
- for i := 0; i < len(request.Messages); i++ {
- message := request.Messages[i]
- messages = append(messages, &TencentMessage{
- Content: message.StringContent(),
- Role: message.Role,
- })
- }
- var req = TencentChatRequest{
- Stream: &request.Stream,
- Messages: messages,
- Model: &request.Model,
- }
- if request.TopP != 0 {
- req.TopP = &request.TopP
- }
- req.Temperature = request.Temperature
- return &req
-}
-
-func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
- fullTextResponse := dto.OpenAITextResponse{
- Id: response.Id,
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Usage: dto.Usage{
- PromptTokens: response.Usage.PromptTokens,
- CompletionTokens: response.Usage.CompletionTokens,
- TotalTokens: response.Usage.TotalTokens,
- },
- }
- if len(response.Choices) > 0 {
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: response.Choices[0].Messages.Content,
- },
- FinishReason: response.Choices[0].FinishReason,
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- }
- return &fullTextResponse
-}
-
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
- response := dto.ChatCompletionsStreamResponse{
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "tencent-hunyuan",
- }
- if len(TencentResponse.Choices) > 0 {
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
- if TencentResponse.Choices[0].FinishReason == "stop" {
- choice.FinishReason = &constant.FinishReasonStop
- }
- response.Choices = append(response.Choices, choice)
- }
- return &response
-}
-
-func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var responseText string
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
-
- helper.SetEventStreamHeaders(c)
-
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 5 || !strings.HasPrefix(data, "data:") {
- continue
- }
- data = strings.TrimPrefix(data, "data:")
-
- var tencentResponse TencentChatResponse
- err := json.Unmarshal([]byte(data), &tencentResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- continue
- }
-
- response := streamResponseTencent2OpenAI(&tencentResponse)
- if len(response.Choices) != 0 {
- responseText += response.Choices[0].Delta.GetContentString()
- }
-
- err = helper.ObjectData(c, response)
- if err != nil {
- common.SysLog(err.Error())
- }
- }
-
- if err := scanner.Err(); err != nil {
- common.SysLog("error reading stream: " + err.Error())
- }
-
- helper.Done(c)
-
- service.CloseResponseBodyGracefully(resp)
-
- return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
-}
-
-func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var tencentSb TencentChatResponseSB
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &tencentSb)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if tencentSb.Response.Error.Code != 0 {
- return nil, types.WithOpenAIError(types.OpenAIError{
- Message: tencentSb.Response.Error.Message,
- Code: tencentSb.Response.Error.Code,
- }, resp.StatusCode)
- }
- fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
- jsonResponse, err := common.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- service.IOCopyBytesGracefully(c, resp, jsonResponse)
- return &fullTextResponse.Usage, nil
-}
-
-func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
- parts := strings.Split(config, "|")
- if len(parts) != 3 {
- err = errors.New("invalid tencent config")
- return
- }
- appId, err = strconv.ParseInt(parts[0], 10, 64)
- secretId = parts[1]
- secretKey = parts[2]
- return
-}
-
-func sha256hex(s string) string {
- b := sha256.Sum256([]byte(s))
- return hex.EncodeToString(b[:])
-}
-
-func hmacSha256(s, key string) string {
- hashed := hmac.New(sha256.New, []byte(key))
- hashed.Write([]byte(s))
- return string(hashed.Sum(nil))
-}
-
-func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
- // build canonical request string
- host := "hunyuan.tencentcloudapi.com"
- httpRequestMethod := "POST"
- canonicalURI := "/"
- canonicalQueryString := ""
- canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
- "application/json", host, strings.ToLower(adaptor.Action))
- signedHeaders := "content-type;host;x-tc-action"
- payload, _ := json.Marshal(req)
- hashedRequestPayload := sha256hex(string(payload))
- canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
- httpRequestMethod,
- canonicalURI,
- canonicalQueryString,
- canonicalHeaders,
- signedHeaders,
- hashedRequestPayload)
- // build string to sign
- algorithm := "TC3-HMAC-SHA256"
- requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
- timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
- t := time.Unix(timestamp, 0).UTC()
- // must be the format 2006-01-02, ref to package time for more info
- date := t.Format("2006-01-02")
- credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
- hashedCanonicalRequest := sha256hex(canonicalRequest)
- string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
- algorithm,
- requestTimestamp,
- credentialScope,
- hashedCanonicalRequest)
-
- // sign string
- secretDate := hmacSha256(date, "TC3"+secKey)
- secretService := hmacSha256("hunyuan", secretDate)
- secretKey := hmacSha256("tc3_request", secretService)
- signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
-
- // build authorization
- authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
- algorithm,
- secId,
- credentialScope,
- signedHeaders,
- signature)
- return authorization
-}
diff --git a/new-api/relay/channel/vertex/adaptor.go b/new-api/relay/channel/vertex/adaptor.go
deleted file mode 100644
index 6dfaa33b3e15a332fbd4af1fa80c3eb3a08b11af..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/vertex/adaptor.go
+++ /dev/null
@@ -1,348 +0,0 @@
-package vertex
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/gemini"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- RequestModeClaude = 1
- RequestModeGemini = 2
- RequestModeLlama = 3
-)
-
-var claudeModelMap = map[string]string{
- "claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
- "claude-3-opus-20240229": "claude-3-opus@20240229",
- "claude-3-haiku-20240307": "claude-3-haiku@20240307",
- "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
- "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
- "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
- "claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
- "claude-opus-4-20250514": "claude-opus-4@20250514",
- "claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
- "claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
-}
-
-const anthropicVersion = "vertex-2023-10-16"
-
-type Adaptor struct {
- RequestMode int
- AccountCredentials Credentials
-}
-
-func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
- geminiAdaptor := gemini.Adaptor{}
- return geminiAdaptor.ConvertGeminiRequest(c, info, request)
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
- c.Set("request_model", v)
- } else {
- c.Set("request_model", request.Model)
- }
- vertexClaudeReq := copyRequest(request, anthropicVersion)
- return vertexClaudeReq, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- geminiAdaptor := gemini.Adaptor{}
- return geminiAdaptor.ConvertImageRequest(c, info, request)
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- if strings.HasPrefix(info.UpstreamModelName, "claude") {
- a.RequestMode = RequestModeClaude
- } else if strings.Contains(info.UpstreamModelName, "llama") {
- a.RequestMode = RequestModeLlama
- } else {
- a.RequestMode = RequestModeGemini
- }
-}
-
-func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
- region := GetModelRegion(info.ApiVersion, info.OriginModelName)
- if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
- adc := &Credentials{}
- if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
- return "", fmt.Errorf("failed to decode credentials file: %w", err)
- }
- a.AccountCredentials = *adc
-
- if a.RequestMode == RequestModeLlama {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
- region,
- adc.ProjectID,
- region,
- ), nil
- }
-
- if region == "global" {
- return fmt.Sprintf(
- "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
- adc.ProjectID,
- modelName,
- suffix,
- ), nil
- } else {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- modelName,
- suffix,
- ), nil
- }
- } else {
- if region == "global" {
- return fmt.Sprintf(
- "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
- modelName,
- suffix,
- info.ApiKey,
- ), nil
- } else {
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
- region,
- modelName,
- suffix,
- info.ApiKey,
- ), nil
- }
- }
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- suffix := ""
- if a.RequestMode == RequestModeGemini {
- if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // 新增逻辑:处理 -thinking- 格式
- if strings.Contains(info.UpstreamModelName, "-thinking-") {
- parts := strings.Split(info.UpstreamModelName, "-thinking-")
- info.UpstreamModelName = parts[0]
- } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
- }
- }
-
- if info.IsStream {
- suffix = "streamGenerateContent?alt=sse"
- } else {
- suffix = "generateContent"
- }
-
- if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- suffix = "predict"
- }
- return a.getRequestUrl(info, info.UpstreamModelName, suffix)
- } else if a.RequestMode == RequestModeClaude {
- if info.IsStream {
- suffix = "streamRawPredict?alt=sse"
- } else {
- suffix = "rawPredict"
- }
- model := info.UpstreamModelName
- if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
- model = v
- }
- return a.getRequestUrl(info, model, suffix)
- } else if a.RequestMode == RequestModeLlama {
- return a.getRequestUrl(info, "", "")
- }
- return "", errors.New("unsupported request mode")
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
- accessToken, err := getAccessToken(a, info)
- if err != nil {
- return err
- }
- req.Set("Authorization", "Bearer "+accessToken)
- }
- if a.AccountCredentials.ProjectID != "" {
- req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
- }
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
- prompt := ""
- for _, m := range request.Messages {
- if m.Role == "user" {
- prompt = m.StringContent()
- if prompt != "" {
- break
- }
- }
- }
- if prompt == "" {
- if p, ok := request.Prompt.(string); ok {
- prompt = p
- }
- }
- if prompt == "" {
- return nil, errors.New("prompt is required for image generation")
- }
-
- imgReq := dto.ImageRequest{
- Model: request.Model,
- Prompt: prompt,
- N: 1,
- Size: "1024x1024",
- }
- if request.N > 0 {
- imgReq.N = uint(request.N)
- }
- if request.Size != "" {
- imgReq.Size = request.Size
- }
- if len(request.ExtraBody) > 0 {
- var extra map[string]any
- if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
- if n, ok := extra["n"].(float64); ok && n > 0 {
- imgReq.N = uint(n)
- }
- if size, ok := extra["size"].(string); ok {
- imgReq.Size = size
- }
- // accept aspectRatio in extra body (top-level or under parameters)
- if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
- imgReq.Size = ar
- }
- if params, ok := extra["parameters"].(map[string]any); ok {
- if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
- imgReq.Size = ar
- }
- }
- }
- }
- c.Set("request_model", request.Model)
- return a.ConvertImageRequest(c, info, imgReq)
- }
- if a.RequestMode == RequestModeClaude {
- claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
- if err != nil {
- return nil, err
- }
- vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
- c.Set("request_model", claudeReq.Model)
- info.UpstreamModelName = claudeReq.Model
- return vertexClaudeReq, nil
- } else if a.RequestMode == RequestModeGemini {
- geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
- if err != nil {
- return nil, err
- }
- c.Set("request_model", request.Model)
- return geminiRequest, nil
- } else if a.RequestMode == RequestModeLlama {
- return request, nil
- }
- return nil, errors.New("unsupported request mode")
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- switch a.RequestMode {
- case RequestModeClaude:
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- case RequestModeGemini:
- if info.RelayMode == constant.RelayModeGemini {
- return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
- } else {
- return gemini.GeminiChatStreamHandler(c, info, resp)
- }
- case RequestModeLlama:
- return openai.OaiStreamHandler(c, info, resp)
- }
- } else {
- switch a.RequestMode {
- case RequestModeClaude:
- return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
- case RequestModeGemini:
- if info.RelayMode == constant.RelayModeGemini {
- return gemini.GeminiTextGenerationHandler(c, info, resp)
- } else {
- if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return gemini.GeminiImageHandler(c, info, resp)
- }
- return gemini.GeminiChatHandler(c, info, resp)
- }
- case RequestModeLlama:
- return openai.OpenaiHandler(c, info, resp)
- }
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- var modelList []string
- for i, s := range ModelList {
- modelList = append(modelList, s)
- ModelList[i] = s
- }
- for i, s := range claude.ModelList {
- modelList = append(modelList, s)
- claude.ModelList[i] = s
- }
- for i, s := range gemini.ModelList {
- modelList = append(modelList, s)
- gemini.ModelList[i] = s
- }
- return modelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/vertex/constants.go b/new-api/relay/channel/vertex/constants.go
deleted file mode 100644
index e47a43a2ebe209d9316c17fae4d696c8e2a4bace..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/vertex/constants.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package vertex
-
-var ModelList = []string{
- //"claude-3-sonnet-20240229",
- //"claude-3-opus-20240229",
- //"claude-3-haiku-20240307",
- //"claude-3-5-sonnet-20240620",
-
- //"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
- //"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
-
- "meta/llama3-405b-instruct-maas",
-}
-
-var ChannelName = "vertex-ai"
diff --git a/new-api/relay/channel/vertex/dto.go b/new-api/relay/channel/vertex/dto.go
deleted file mode 100644
index 97b4d35103a3bda7d9cd2f7cb27a07b23cc00007..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/vertex/dto.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package vertex
-
-import (
- "one-api/dto"
-)
-
-type VertexAIClaudeRequest struct {
- AnthropicVersion string `json:"anthropic_version"`
- Messages []dto.ClaudeMessage `json:"messages"`
- System any `json:"system,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- StopSequences []string `json:"stop_sequences,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Tools any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- Thinking *dto.Thinking `json:"thinking,omitempty"`
-}
-
-func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
- return &VertexAIClaudeRequest{
- AnthropicVersion: version,
- System: req.System,
- Messages: req.Messages,
- MaxTokens: req.MaxTokens,
- Stream: req.Stream,
- Temperature: req.Temperature,
- TopP: req.TopP,
- TopK: req.TopK,
- StopSequences: req.StopSequences,
- Tools: req.Tools,
- ToolChoice: req.ToolChoice,
- Thinking: req.Thinking,
- }
-}
diff --git a/new-api/relay/channel/vertex/relay-vertex.go b/new-api/relay/channel/vertex/relay-vertex.go
deleted file mode 100644
index 14b2ca857bc369280bf0a67f56702494a12f04d8..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/vertex/relay-vertex.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package vertex
-
-import "one-api/common"
-
-func GetModelRegion(other string, localModelName string) string {
- // if other is json string
- if common.IsJsonObject(other) {
- m, err := common.StrToMap(other)
- if err != nil {
- return other // return original if parsing fails
- }
- if m[localModelName] != nil {
- return m[localModelName].(string)
- } else {
- if v, ok := m["default"]; ok {
- return v.(string)
- }
- return "global"
- }
- }
- return other
-}
diff --git a/new-api/relay/channel/vertex/service_account.go b/new-api/relay/channel/vertex/service_account.go
deleted file mode 100644
index de52e37f4cd012efd7b5023e8aef60189c4f1632..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/vertex/service_account.go
+++ /dev/null
@@ -1,182 +0,0 @@
-package vertex
-
-import (
- "crypto/rsa"
- "crypto/x509"
- "encoding/json"
- "encoding/pem"
- "errors"
- "net/http"
- "net/url"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "strings"
-
- "github.com/bytedance/gopkg/cache/asynccache"
- "github.com/golang-jwt/jwt"
-
- "fmt"
- "time"
-)
-
-type Credentials struct {
- ProjectID string `json:"project_id"`
- PrivateKeyID string `json:"private_key_id"`
- PrivateKey string `json:"private_key"`
- ClientEmail string `json:"client_email"`
- ClientID string `json:"client_id"`
-}
-
-var Cache = asynccache.NewAsyncCache(asynccache.Options{
- RefreshDuration: time.Minute * 35,
- EnableExpire: true,
- ExpireDuration: time.Minute * 30,
- Fetcher: func(key string) (interface{}, error) {
- return nil, errors.New("not found")
- },
-})
-
-func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
- var cacheKey string
- if info.ChannelIsMultiKey {
- cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
- } else {
- cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
- }
- val, err := Cache.Get(cacheKey)
- if err == nil {
- return val.(string), nil
- }
-
- signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
- if err != nil {
- return "", fmt.Errorf("failed to create signed JWT: %w", err)
- }
- newToken, err := exchangeJwtForAccessToken(signedJWT, info)
- if err != nil {
- return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
- }
- if err := Cache.SetDefault(cacheKey, newToken); err {
- return newToken, nil
- }
- return newToken, nil
-}
-
-func createSignedJWT(email, privateKeyPEM string) (string, error) {
-
- privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
- privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
- privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
- privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
- privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
-
- block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
- if block == nil {
- return "", fmt.Errorf("failed to parse PEM block containing the private key")
- }
-
- privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
- if err != nil {
- return "", err
- }
-
- rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
- if !ok {
- return "", fmt.Errorf("not an RSA private key")
- }
-
- now := time.Now()
- claims := jwt.MapClaims{
- "iss": email,
- "scope": "https://www.googleapis.com/auth/cloud-platform",
- "aud": "https://www.googleapis.com/oauth2/v4/token",
- "exp": now.Add(time.Minute * 35).Unix(),
- "iat": now.Unix(),
- }
-
- token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
- signedToken, err := token.SignedString(rsaPrivateKey)
- if err != nil {
- return "", err
- }
-
- return signedToken, nil
-}
-
-func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
-
- authURL := "https://www.googleapis.com/oauth2/v4/token"
- data := url.Values{}
- data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
- data.Set("assertion", signedJWT)
-
- var client *http.Client
- var err error
- if info.ChannelSetting.Proxy != "" {
- client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
- if err != nil {
- return "", fmt.Errorf("new proxy http client failed: %w", err)
- }
- } else {
- client = service.GetHttpClient()
- }
-
- resp, err := client.PostForm(authURL, data)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- var result map[string]interface{}
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- return "", err
- }
-
- if accessToken, ok := result["access_token"].(string); ok {
- return accessToken, nil
- }
-
- return "", fmt.Errorf("failed to get access token: %v", result)
-}
-
-func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
- signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
- if err != nil {
- return "", fmt.Errorf("failed to create signed JWT: %w", err)
- }
- return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
-}
-
-func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
- authURL := "https://www.googleapis.com/oauth2/v4/token"
- data := url.Values{}
- data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
- data.Set("assertion", signedJWT)
-
- var client *http.Client
- var err error
- if proxy != "" {
- client, err = service.NewProxyHttpClient(proxy)
- if err != nil {
- return "", fmt.Errorf("new proxy http client failed: %w", err)
- }
- } else {
- client = service.GetHttpClient()
- }
-
- resp, err := client.PostForm(authURL, data)
- if err != nil {
- return "", err
- }
- defer resp.Body.Close()
-
- var result map[string]interface{}
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- return "", err
- }
-
- if accessToken, ok := result["access_token"].(string); ok {
- return accessToken, nil
- }
- return "", fmt.Errorf("failed to get access token: %v", result)
-}
diff --git a/new-api/relay/channel/volcengine/adaptor.go b/new-api/relay/channel/volcengine/adaptor.go
deleted file mode 100644
index 0284c9645b29523a7bbde134308a01ffc372711b..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/volcengine/adaptor.go
+++ /dev/null
@@ -1,273 +0,0 @@
-package volcengine
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "net/textproto"
- channelconstant "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/constant"
- "one-api/types"
- "path/filepath"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- adaptor := openai.Adaptor{}
- return adaptor.ConvertClaudeRequest(c, info, req)
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- switch info.RelayMode {
- case constant.RelayModeImagesGenerations:
- return request, nil
- case constant.RelayModeImagesEdits:
-
- var requestBody bytes.Buffer
- writer := multipart.NewWriter(&requestBody)
-
- writer.WriteField("model", request.Model)
- // 获取所有表单字段
- formData := c.Request.PostForm
- // 遍历表单字段并打印输出
- for key, values := range formData {
- if key == "model" {
- continue
- }
- for _, value := range values {
- writer.WriteField(key, value)
- }
- }
-
- // Parse the multipart form to handle both single image and multiple images
- if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
- return nil, errors.New("failed to parse multipart form")
- }
-
- if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
- // Check if "image" field exists in any form, including array notation
- var imageFiles []*multipart.FileHeader
- var exists bool
-
- // First check for standard "image" field
- if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
- // If not found, check for "image[]" field
- if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
- // If still not found, iterate through all fields to find any that start with "image["
- foundArrayImages := false
- for fieldName, files := range c.Request.MultipartForm.File {
- if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
- foundArrayImages = true
- for _, file := range files {
- imageFiles = append(imageFiles, file)
- }
- }
- }
-
- // If no image fields found at all
- if !foundArrayImages && (len(imageFiles) == 0) {
- return nil, errors.New("image is required")
- }
- }
- }
-
- // Process all image files
- for i, fileHeader := range imageFiles {
- file, err := fileHeader.Open()
- if err != nil {
- return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
- }
- defer file.Close()
-
- // If multiple images, use image[] as the field name
- fieldName := "image"
- if len(imageFiles) > 1 {
- fieldName = "image[]"
- }
-
- // Determine MIME type based on file extension
- mimeType := detectImageMimeType(fileHeader.Filename)
-
- // Create a form file with the appropriate content type
- h := make(textproto.MIMEHeader)
- h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
- h.Set("Content-Type", mimeType)
-
- part, err := writer.CreatePart(h)
- if err != nil {
- return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
- }
-
- if _, err := io.Copy(part, file); err != nil {
- return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
- }
- }
-
- // Handle mask file if present
- if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
- maskFile, err := maskFiles[0].Open()
- if err != nil {
- return nil, errors.New("failed to open mask file")
- }
- defer maskFile.Close()
-
- // Determine MIME type for mask file
- mimeType := detectImageMimeType(maskFiles[0].Filename)
-
- // Create a form file with the appropriate content type
- h := make(textproto.MIMEHeader)
- h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
- h.Set("Content-Type", mimeType)
-
- maskPart, err := writer.CreatePart(h)
- if err != nil {
- return nil, errors.New("create form file failed for mask")
- }
-
- if _, err := io.Copy(maskPart, maskFile); err != nil {
- return nil, errors.New("copy mask file failed")
- }
- }
- } else {
- return nil, errors.New("no multipart form data found")
- }
-
- // 关闭 multipart 编写器以设置分界线
- writer.Close()
- c.Request.Header.Set("Content-Type", writer.FormDataContentType())
- return bytes.NewReader(requestBody.Bytes()), nil
-
- default:
- return request, nil
- }
-}
-
-// detectImageMimeType determines the MIME type based on the file extension
-func detectImageMimeType(filename string) string {
- ext := strings.ToLower(filepath.Ext(filename))
- switch ext {
- case ".jpg", ".jpeg":
- return "image/jpeg"
- case ".png":
- return "image/png"
- case ".webp":
- return "image/webp"
- default:
- // Try to detect from extension if possible
- if strings.HasPrefix(ext, ".jp") {
- return "image/jpeg"
- }
- // Default to png as a fallback
- return "image/png"
- }
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- // 支持自定义域名,如果未设置则使用默认域名
- baseUrl := info.ChannelBaseUrl
- if baseUrl == "" {
- baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
- }
-
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- if strings.HasPrefix(info.UpstreamModelName, "bot") {
- return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
- }
- return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
- default:
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
- if strings.HasPrefix(info.UpstreamModelName, "bot") {
- return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
- }
- return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
- case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
- case constant.RelayModeImagesGenerations:
- return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
- case constant.RelayModeImagesEdits:
- return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
- case constant.RelayModeRerank:
- return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
- default:
- }
- }
- return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- // 适配 方舟deepseek混合模型 的 thinking 后缀
- if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- request.Model = info.UpstreamModelName
- request.THINKING = json.RawMessage(`{"type": "enabled"}`)
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- adaptor := openai.Adaptor{}
- usage, err = adaptor.DoResponse(c, resp, info)
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/volcengine/constants.go b/new-api/relay/channel/volcengine/constants.go
deleted file mode 100644
index 7edd9b30e6fc1f7b82d61f8d71bb9b3f82fcb5c0..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/volcengine/constants.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package volcengine
-
-var ModelList = []string{
- "Doubao-pro-128k",
- "Doubao-pro-32k",
- "Doubao-pro-4k",
- "Doubao-lite-128k",
- "Doubao-lite-32k",
- "Doubao-lite-4k",
- "Doubao-embedding",
- "doubao-seedream-4-0-250828",
- "seedream-4-0-250828",
- "doubao-seedance-1-0-pro-250528",
- "seedance-1-0-pro-250528",
- "doubao-seed-1-6-thinking-250715",
- "seed-1-6-thinking-250715",
-}
-
-var ChannelName = "volcengine"
diff --git a/new-api/relay/channel/xai/adaptor.go b/new-api/relay/channel/xai/adaptor.go
deleted file mode 100644
index e16d293585a98b644a221f45b98270ec80cc3ab5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xai/adaptor.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package xai
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/types"
- "strings"
-
- "one-api/relay/constant"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- //panic("implement me")
- return nil, errors.New("not available")
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //not available
- return nil, errors.New("not available")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- xaiRequest := ImageRequest{
- Model: request.Model,
- Prompt: request.Prompt,
- N: int(request.N),
- ResponseFormat: request.ResponseFormat,
- }
- return xaiRequest, nil
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if strings.HasSuffix(info.UpstreamModelName, "-search") {
- info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
- request.Model = info.UpstreamModelName
- toMap := request.ToMap()
- toMap["search_parameters"] = map[string]any{
- "mode": "on",
- }
- return toMap, nil
- }
- if strings.HasPrefix(request.Model, "grok-3-mini") {
- if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
- request.MaxCompletionTokens = request.MaxTokens
- request.MaxTokens = 0
- }
- if strings.HasSuffix(request.Model, "-high") {
- request.ReasoningEffort = "high"
- request.Model = strings.TrimSuffix(request.Model, "-high")
- } else if strings.HasSuffix(request.Model, "-low") {
- request.ReasoningEffort = "low"
- request.Model = strings.TrimSuffix(request.Model, "-low")
- } else if strings.HasSuffix(request.Model, "-medium") {
- request.ReasoningEffort = "medium"
- request.Model = strings.TrimSuffix(request.Model, "-medium")
- }
- info.ReasoningEffort = request.ReasoningEffort
- info.UpstreamModelName = request.Model
- }
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //not available
- return nil, errors.New("not available")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayMode {
- case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
- usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
- default:
- if info.IsStream {
- usage, err = xAIStreamHandler(c, info, resp)
- } else {
- usage, err = xAIHandler(c, info, resp)
- }
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/xai/constants.go b/new-api/relay/channel/xai/constants.go
deleted file mode 100644
index e66827dbe18249a51a8d5f713c0ba1766b5f50f5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xai/constants.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package xai
-
-var ModelList = []string{
- // grok-4
- "grok-4", "grok-4-0709", "grok-4-0709-search",
- // grok-3
- "grok-3-beta", "grok-3-mini-beta",
- // grok-3 mini
- "grok-3-fast-beta", "grok-3-mini-fast-beta",
- // extend grok-3-mini reasoning
- "grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
- "grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
- // image model
- "grok-2-image",
- // legacy models
- "grok-2", "grok-2-vision",
- "grok-beta", "grok-vision-beta",
-}
-
-var ChannelName = "xai"
diff --git a/new-api/relay/channel/xai/dto.go b/new-api/relay/channel/xai/dto.go
deleted file mode 100644
index 0444dd7e06d11862dd6fda1ab1acc07127d75acf..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xai/dto.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package xai
-
-import "one-api/dto"
-
-// ChatCompletionResponse represents the response from XAI chat completion API
-type ChatCompletionResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []dto.OpenAITextResponseChoice `json:"choices"`
- Usage *dto.Usage `json:"usage"`
- SystemFingerprint string `json:"system_fingerprint"`
-}
-
-// quality, size or style are not supported by xAI API at the moment.
-type ImageRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt" binding:"required"`
- N int `json:"n,omitempty"`
- // Size string `json:"size,omitempty"`
- // Quality string `json:"quality,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
- // Style string `json:"style,omitempty"`
- // User string `json:"user,omitempty"`
- // ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
-}
diff --git a/new-api/relay/channel/xai/text.go b/new-api/relay/channel/xai/text.go
deleted file mode 100644
index 88d34334a3db4953902f7d16efedad4c5258f80b..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xai/text.go
+++ /dev/null
@@ -1,107 +0,0 @@
-package xai
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
- if xAIResp == nil {
- return nil
- }
- if xAIResp.Usage != nil {
- xAIResp.Usage.CompletionTokens = usage.CompletionTokens
- }
- openAIResp := &dto.ChatCompletionsStreamResponse{
- Id: xAIResp.Id,
- Object: xAIResp.Object,
- Created: xAIResp.Created,
- Model: xAIResp.Model,
- Choices: xAIResp.Choices,
- Usage: xAIResp.Usage,
- }
-
- return openAIResp
-}
-
-func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- usage := &dto.Usage{}
- var responseTextBuilder strings.Builder
- var toolCount int
- var containStreamUsage bool
-
- helper.SetEventStreamHeaders(c)
-
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var xAIResp *dto.ChatCompletionsStreamResponse
- err := json.Unmarshal([]byte(data), &xAIResp)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
-
- // 把 xAI 的usage转换为 OpenAI 的usage
- if xAIResp.Usage != nil {
- containStreamUsage = true
- usage.PromptTokens = xAIResp.Usage.PromptTokens
- usage.TotalTokens = xAIResp.Usage.TotalTokens
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
- }
-
- openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
- _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
- err = helper.ObjectData(c, openaiResponse)
- if err != nil {
- common.SysLog(err.Error())
- }
- return true
- })
-
- if !containStreamUsage {
- usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- }
-
- helper.Done(c)
- service.CloseResponseBodyGracefully(resp)
- return usage, nil
-}
-
-func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
-
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- var xaiResponse ChatCompletionResponse
- err = common.Unmarshal(responseBody, &xaiResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- if xaiResponse.Usage != nil {
- xaiResponse.Usage.CompletionTokens = xaiResponse.Usage.TotalTokens - xaiResponse.Usage.PromptTokens
- xaiResponse.Usage.CompletionTokenDetails.TextTokens = xaiResponse.Usage.CompletionTokens - xaiResponse.Usage.CompletionTokenDetails.ReasoningTokens
- }
-
- // new body
- encodeJson, err := common.Marshal(xaiResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
-
- service.IOCopyBytesGracefully(c, resp, encodeJson)
-
- return xaiResponse.Usage, nil
-}
diff --git a/new-api/relay/channel/xinference/constant.go b/new-api/relay/channel/xinference/constant.go
deleted file mode 100644
index 5815af253e3803a1cbe2d3cc1e38dfd1fa709b28..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xinference/constant.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package xinference
-
-var ModelList = []string{
- "bge-reranker-v2-m3",
- "jina-reranker-v2",
-}
-
-var ChannelName = "xinference"
diff --git a/new-api/relay/channel/xinference/dto.go b/new-api/relay/channel/xinference/dto.go
deleted file mode 100644
index 6eb63ddb6476222488cd462c9e36f040235f3ab4..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xinference/dto.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package xinference
-
-type XinRerankResponseDocument struct {
- Document any `json:"document,omitempty"`
- Index int `json:"index"`
- RelevanceScore float64 `json:"relevance_score"`
-}
-
-type XinRerankResponse struct {
- Results []XinRerankResponseDocument `json:"results"`
-}
diff --git a/new-api/relay/channel/xunfei/adaptor.go b/new-api/relay/channel/xunfei/adaptor.go
deleted file mode 100644
index 6004a282ae289165e32fc77a5e885d245a0a7d65..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xunfei/adaptor.go
+++ /dev/null
@@ -1,104 +0,0 @@
-package xunfei
-
-import (
- "errors"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
- request *dto.GeneralOpenAIRequest
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return "", nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- a.request = request
- return request, nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- // xunfei's request is not http request, so we don't need to do anything here
- dummyResp := &http.Response{}
- dummyResp.StatusCode = http.StatusOK
- return dummyResp, nil
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- splits := strings.Split(info.ApiKey, "|")
- if len(splits) != 3 {
- return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
- }
- if a.request == nil {
- return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
- }
- if info.IsStream {
- usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
- } else {
- usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/xunfei/constants.go b/new-api/relay/channel/xunfei/constants.go
deleted file mode 100644
index 5095185d13b4c3a2485a513d7c0bd2ce4a5bbbbc..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xunfei/constants.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package xunfei
-
-var ModelList = []string{
- "SparkDesk",
- "SparkDesk-v1.1",
- "SparkDesk-v2.1",
- "SparkDesk-v3.1",
- "SparkDesk-v3.5",
- "SparkDesk-v4.0",
-}
-
-var ChannelName = "xunfei"
diff --git a/new-api/relay/channel/xunfei/dto.go b/new-api/relay/channel/xunfei/dto.go
deleted file mode 100644
index 41086aed043a56962410a513c164f5398aa5eff4..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xunfei/dto.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package xunfei
-
-import "one-api/dto"
-
-type XunfeiMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type XunfeiChatRequest struct {
- Header struct {
- AppId string `json:"app_id"`
- } `json:"header"`
- Parameter struct {
- Chat struct {
- Domain string `json:"domain,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopK int `json:"top_k,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- Auditing bool `json:"auditing,omitempty"`
- } `json:"chat"`
- } `json:"parameter"`
- Payload struct {
- Message struct {
- Text []XunfeiMessage `json:"text"`
- } `json:"message"`
- } `json:"payload"`
-}
-
-type XunfeiChatResponseTextItem struct {
- Content string `json:"content"`
- Role string `json:"role"`
- Index int `json:"index"`
-}
-
-type XunfeiChatResponse struct {
- Header struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Sid string `json:"sid"`
- Status int `json:"status"`
- } `json:"header"`
- Payload struct {
- Choices struct {
- Status int `json:"status"`
- Seq int `json:"seq"`
- Text []XunfeiChatResponseTextItem `json:"text"`
- } `json:"choices"`
- Usage struct {
- //Text struct {
- // QuestionTokens string `json:"question_tokens"`
- // PromptTokens string `json:"prompt_tokens"`
- // CompletionTokens string `json:"completion_tokens"`
- // TotalTokens string `json:"total_tokens"`
- //} `json:"text"`
- Text dto.Usage `json:"text"`
- } `json:"usage"`
- } `json:"payload"`
-}
diff --git a/new-api/relay/channel/xunfei/relay-xunfei.go b/new-api/relay/channel/xunfei/relay-xunfei.go
deleted file mode 100644
index 1f1fc6efd56458eeb29169fdc07d95fcd058e316..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/xunfei/relay-xunfei.go
+++ /dev/null
@@ -1,290 +0,0 @@
-package xunfei
-
-import (
- "crypto/hmac"
- "crypto/sha256"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/url"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/helper"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-// https://console.xfyun.cn/services/cbm
-// https://www.xfyun.cn/doc/spark/Web.html
-
-func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
- messages := make([]XunfeiMessage, 0, len(request.Messages))
- shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
- for _, message := range request.Messages {
- if message.Role == "system" && shouldCovertSystemMessage {
- messages = append(messages, XunfeiMessage{
- Role: "user",
- Content: message.StringContent(),
- })
- messages = append(messages, XunfeiMessage{
- Role: "assistant",
- Content: "Okay",
- })
- } else {
- messages = append(messages, XunfeiMessage{
- Role: message.Role,
- Content: message.StringContent(),
- })
- }
- }
- xunfeiRequest := XunfeiChatRequest{}
- xunfeiRequest.Header.AppId = xunfeiAppId
- xunfeiRequest.Parameter.Chat.Domain = domain
- xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
- xunfeiRequest.Parameter.Chat.TopK = request.N
- xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
- xunfeiRequest.Payload.Message.Text = messages
- return &xunfeiRequest
-}
-
-func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
- if len(response.Payload.Choices.Text) == 0 {
- response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
- {
- Content: "",
- },
- }
- }
- choice := dto.OpenAITextResponseChoice{
- Index: 0,
- Message: dto.Message{
- Role: "assistant",
- Content: response.Payload.Choices.Text[0].Content,
- },
- FinishReason: constant.FinishReasonStop,
- }
- fullTextResponse := dto.OpenAITextResponse{
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: []dto.OpenAITextResponseChoice{choice},
- Usage: response.Payload.Usage.Text,
- }
- return &fullTextResponse
-}
-
-func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
- if len(xunfeiResponse.Payload.Choices.Text) == 0 {
- xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
- {
- Content: "",
- },
- }
- }
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
- if xunfeiResponse.Payload.Choices.Status == 2 {
- choice.FinishReason = &constant.FinishReasonStop
- }
- response := dto.ChatCompletionsStreamResponse{
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "SparkDesk",
- Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
- }
- return &response
-}
-
-func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
- HmacWithShaToBase64 := func(algorithm, data, key string) string {
- mac := hmac.New(sha256.New, []byte(key))
- mac.Write([]byte(data))
- encodeData := mac.Sum(nil)
- return base64.StdEncoding.EncodeToString(encodeData)
- }
- ul, err := url.Parse(hostUrl)
- if err != nil {
- fmt.Println(err)
- }
- date := time.Now().UTC().Format(time.RFC1123)
- signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
- sign := strings.Join(signString, "\n")
- sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
- authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
- "hmac-sha256", "host date request-line", sha)
- authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
- v := url.Values{}
- v.Add("host", ul.Host)
- v.Add("date", date)
- v.Add("authorization", authorization)
- callUrl := hostUrl + "?" + v.Encode()
- return callUrl
-}
-
-func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
- dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
- }
- helper.SetEventStreamHeaders(c)
- var usage dto.Usage
- c.Stream(func(w io.Writer) bool {
- select {
- case xunfeiResponse := <-dataChan:
- usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
- usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
- usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
- response := streamResponseXunfei2OpenAI(&xunfeiResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- return &usage, nil
-}
-
-func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
- domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
- dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
- }
- var usage dto.Usage
- var content string
- var xunfeiResponse XunfeiChatResponse
- stop := false
- for !stop {
- select {
- case xunfeiResponse = <-dataChan:
- if len(xunfeiResponse.Payload.Choices.Text) == 0 {
- continue
- }
- content += xunfeiResponse.Payload.Choices.Text[0].Content
- usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
- usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
- usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
- case stop = <-stopChan:
- }
- }
- if len(xunfeiResponse.Payload.Choices.Text) == 0 {
- xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
- {
- Content: "",
- },
- }
- }
- xunfeiResponse.Payload.Choices.Text[0].Content = content
-
- response := responseXunfei2OpenAI(&xunfeiResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- _, _ = c.Writer.Write(jsonResponse)
- return &usage, nil
-}
-
-func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
- d := websocket.Dialer{
- HandshakeTimeout: 5 * time.Second,
- }
- conn, resp, err := d.Dial(authUrl, nil)
- if err != nil || resp.StatusCode != 101 {
- return nil, nil, err
- }
-
- data := requestOpenAI2Xunfei(textRequest, appId, domain)
- err = conn.WriteJSON(data)
- if err != nil {
- return nil, nil, err
- }
-
- dataChan := make(chan XunfeiChatResponse)
- stopChan := make(chan bool)
- go func() {
- defer func() {
- conn.Close()
- }()
- for {
- _, msg, err := conn.ReadMessage()
- if err != nil {
- common.SysLog("error reading stream response: " + err.Error())
- break
- }
- var response XunfeiChatResponse
- err = json.Unmarshal(msg, &response)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- break
- }
- dataChan <- response
- if response.Payload.Choices.Status == 2 {
- if err != nil {
- common.SysLog("error closing websocket connection: " + err.Error())
- }
- break
- }
- }
- stopChan <- true
- }()
-
- return dataChan, stopChan, nil
-}
-
-func apiVersion2domain(apiVersion string) string {
- switch apiVersion {
- case "v1.1":
- return "lite"
- case "v2.1":
- return "generalv2"
- case "v3.1":
- return "generalv3"
- case "v3.5":
- return "generalv3.5"
- case "v4.0":
- return "4.0Ultra"
- }
- return "general" + apiVersion
-}
-
-func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
- apiVersion := getAPIVersion(c, modelName)
- domain := apiVersion2domain(apiVersion)
- authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
- return domain, authUrl
-}
-
-func getAPIVersion(c *gin.Context, modelName string) string {
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion != "" {
- return apiVersion
- }
- parts := strings.Split(modelName, "-")
- if len(parts) == 2 {
- apiVersion = parts[1]
- return apiVersion
-
- }
- apiVersion = c.GetString("api_version")
- if apiVersion != "" {
- return apiVersion
- }
- apiVersion = "v1.1"
- common.SysLog("api_version not found, using default: " + apiVersion)
- return apiVersion
-}
diff --git a/new-api/relay/channel/zhipu/adaptor.go b/new-api/relay/channel/zhipu/adaptor.go
deleted file mode 100644
index 4899251898df46ee049af141b0f1d4206e1204b2..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu/adaptor.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package zhipu
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- method := "invoke"
- if info.IsStream {
- method = "sse-invoke"
- }
- return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- token := getZhipuToken(info.ApiKey)
- req.Set("Authorization", token)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if request.TopP >= 1 {
- request.TopP = 0.99
- }
- return requestOpenAI2Zhipu(*request), nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- if info.IsStream {
- usage, err = zhipuStreamHandler(c, info, resp)
- } else {
- usage, err = zhipuHandler(c, info, resp)
- }
- return
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/zhipu/constants.go b/new-api/relay/channel/zhipu/constants.go
deleted file mode 100644
index fd888cc2ee07083d27a56499f8d3f1fedb613fcd..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package zhipu
-
-var ModelList = []string{
- "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
-}
-
-var ChannelName = "zhipu"
diff --git a/new-api/relay/channel/zhipu/dto.go b/new-api/relay/channel/zhipu/dto.go
deleted file mode 100644
index 13e8a0ac51d9ea0e71930f49b2a9a27eb0de4aff..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu/dto.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package zhipu
-
-import (
- "one-api/dto"
- "time"
-)
-
-type ZhipuMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
-}
-
-type ZhipuRequest struct {
- Prompt []ZhipuMessage `json:"prompt"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- RequestId string `json:"request_id,omitempty"`
- Incremental bool `json:"incremental,omitempty"`
-}
-
-type ZhipuResponseData struct {
- TaskId string `json:"task_id"`
- RequestId string `json:"request_id"`
- TaskStatus string `json:"task_status"`
- Choices []ZhipuMessage `json:"choices"`
- dto.Usage `json:"usage"`
-}
-
-type ZhipuResponse struct {
- Code int `json:"code"`
- Msg string `json:"msg"`
- Success bool `json:"success"`
- Data ZhipuResponseData `json:"data"`
-}
-
-type ZhipuStreamMetaResponse struct {
- RequestId string `json:"request_id"`
- TaskId string `json:"task_id"`
- TaskStatus string `json:"task_status"`
- dto.Usage `json:"usage"`
-}
-
-type zhipuTokenData struct {
- Token string
- ExpiryTime time.Time
-}
diff --git a/new-api/relay/channel/zhipu/relay-zhipu.go b/new-api/relay/channel/zhipu/relay-zhipu.go
deleted file mode 100644
index 09c825cec882caa45e21145d28758f1b9388d5a7..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu/relay-zhipu.go
+++ /dev/null
@@ -1,246 +0,0 @@
-package zhipu
-
-import (
- "bufio"
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "sync"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/golang-jwt/jwt"
-)
-
-// https://open.bigmodel.cn/doc/api#chatglm_std
-// chatglm_std, chatglm_lite
-// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
-// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
-
-var zhipuTokens sync.Map
-var expSeconds int64 = 24 * 3600
-
-func getZhipuToken(apikey string) string {
- data, ok := zhipuTokens.Load(apikey)
- if ok {
- tokenData := data.(zhipuTokenData)
- if time.Now().Before(tokenData.ExpiryTime) {
- return tokenData.Token
- }
- }
-
- split := strings.Split(apikey, ".")
- if len(split) != 2 {
- common.SysLog("invalid zhipu key: " + apikey)
- return ""
- }
-
- id := split[0]
- secret := split[1]
-
- expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
- expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
-
- timestamp := time.Now().UnixNano() / 1e6
-
- payload := jwt.MapClaims{
- "api_key": id,
- "exp": expMillis,
- "timestamp": timestamp,
- }
-
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
-
- token.Header["alg"] = "HS256"
- token.Header["sign_type"] = "SIGN"
-
- tokenString, err := token.SignedString([]byte(secret))
- if err != nil {
- return ""
- }
-
- zhipuTokens.Store(apikey, zhipuTokenData{
- Token: tokenString,
- ExpiryTime: expiryTime,
- })
-
- return tokenString
-}
-
-func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
- messages := make([]ZhipuMessage, 0, len(request.Messages))
- for _, message := range request.Messages {
- if message.Role == "system" {
- messages = append(messages, ZhipuMessage{
- Role: "system",
- Content: message.StringContent(),
- })
- messages = append(messages, ZhipuMessage{
- Role: "user",
- Content: "Okay",
- })
- } else {
- messages = append(messages, ZhipuMessage{
- Role: message.Role,
- Content: message.StringContent(),
- })
- }
- }
- return &ZhipuRequest{
- Prompt: messages,
- Temperature: request.Temperature,
- TopP: request.TopP,
- Incremental: false,
- }
-}
-
-func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
- fullTextResponse := dto.OpenAITextResponse{
- Id: response.Data.TaskId,
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
- Usage: response.Data.Usage,
- }
- for i, choice := range response.Data.Choices {
- openaiChoice := dto.OpenAITextResponseChoice{
- Index: i,
- Message: dto.Message{
- Role: choice.Role,
- Content: strings.Trim(choice.Content, "\""),
- },
- FinishReason: "",
- }
- if i == len(response.Data.Choices)-1 {
- openaiChoice.FinishReason = "stop"
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
- }
- return &fullTextResponse
-}
-
-func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString(zhipuResponse)
- response := dto.ChatCompletionsStreamResponse{
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "chatglm",
- Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
- }
- return &response
-}
-
-func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
- var choice dto.ChatCompletionsStreamResponseChoice
- choice.Delta.SetContentString("")
- choice.FinishReason = &constant.FinishReasonStop
- response := dto.ChatCompletionsStreamResponse{
- Id: zhipuResponse.RequestId,
- Object: "chat.completion.chunk",
- Created: common.GetTimestamp(),
- Model: "chatglm",
- Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
- }
- return &response, &zhipuResponse.Usage
-}
-
-func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var usage *dto.Usage
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(bufio.ScanLines)
- dataChan := make(chan string)
- metaChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- lines := strings.Split(data, "\n")
- for i, line := range lines {
- if len(line) < 5 {
- continue
- }
- if line[:5] == "data:" {
- dataChan <- line[5:]
- if i != len(lines)-1 {
- dataChan <- "\n"
- }
- } else if line[:5] == "meta:" {
- metaChan <- line[5:]
- }
- }
- }
- stopChan <- true
- }()
- helper.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- response := streamResponseZhipu2OpenAI(data)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case data := <-metaChan:
- var zhipuResponse ZhipuStreamMetaResponse
- err := json.Unmarshal([]byte(data), &zhipuResponse)
- if err != nil {
- common.SysLog("error unmarshalling stream response: " + err.Error())
- return true
- }
- response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysLog("error marshalling stream response: " + err.Error())
- return true
- }
- usage = zhipuUsage
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
- }
- })
- service.CloseResponseBodyGracefully(resp)
- return usage, nil
-}
-
-func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- var zhipuResponse ZhipuResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- err = json.Unmarshal(responseBody, &zhipuResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if !zhipuResponse.Success {
- return nil, types.WithOpenAIError(types.OpenAIError{
- Message: zhipuResponse.Msg,
- Code: zhipuResponse.Code,
- }, resp.StatusCode)
- }
- fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return &fullTextResponse.Usage, nil
-}
diff --git a/new-api/relay/channel/zhipu_4v/adaptor.go b/new-api/relay/channel/zhipu_4v/adaptor.go
deleted file mode 100644
index 2a426c0913815fb644b0cad00b6699aa22c98cf5..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu_4v/adaptor.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package zhipu_4v
-
-import (
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/dto"
- "one-api/relay/channel"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-type Adaptor struct {
-}
-
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
- return req, nil
-}
-
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil
- default:
- switch info.RelayMode {
- case relayconstant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil
- default:
- return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil
- }
- }
-}
-
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- return nil
-}
-
-func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
- if request == nil {
- return nil, errors.New("request is nil")
- }
- if request.TopP >= 1 {
- request.TopP = 0.99
- }
- return requestOpenAI2Zhipu(*request), nil
-}
-
-func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, nil
-}
-
-func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return request, nil
-}
-
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
-}
-
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
-}
-
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
- switch info.RelayFormat {
- case types.RelayFormatClaude:
- if info.IsStream {
- return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
- } else {
- return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
- }
- default:
- adaptor := openai.Adaptor{}
- return adaptor.DoResponse(c, resp, info)
- }
-}
-
-func (a *Adaptor) GetModelList() []string {
- return ModelList
-}
-
-func (a *Adaptor) GetChannelName() string {
- return ChannelName
-}
diff --git a/new-api/relay/channel/zhipu_4v/constants.go b/new-api/relay/channel/zhipu_4v/constants.go
deleted file mode 100644
index 64e02401d828b67d060d7b502f8f313962160f7f..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu_4v/constants.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package zhipu_4v
-
-var ModelList = []string{
- "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus",
-}
-
-var ChannelName = "zhipu_4v"
diff --git a/new-api/relay/channel/zhipu_4v/dto.go b/new-api/relay/channel/zhipu_4v/dto.go
deleted file mode 100644
index bc2c7ffb5a322472efa49b244e6e8790d62ce9ef..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu_4v/dto.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package zhipu_4v
-
-import (
- "one-api/dto"
- "time"
-)
-
-// type ZhipuMessage struct {
-// Role string `json:"role,omitempty"`
-// Content string `json:"content,omitempty"`
-// ToolCalls any `json:"tool_calls,omitempty"`
-// ToolCallId any `json:"tool_call_id,omitempty"`
-// }
-//
-// type ZhipuRequest struct {
-// Model string `json:"model"`
-// Stream bool `json:"stream,omitempty"`
-// Messages []ZhipuMessage `json:"messages"`
-// Temperature float64 `json:"temperature,omitempty"`
-// TopP float64 `json:"top_p,omitempty"`
-// MaxTokens int `json:"max_tokens,omitempty"`
-// Stop []string `json:"stop,omitempty"`
-// RequestId string `json:"request_id,omitempty"`
-// Tools any `json:"tools,omitempty"`
-// ToolChoice any `json:"tool_choice,omitempty"`
-// }
-//
-// type ZhipuV4TextResponseChoice struct {
-// Index int `json:"index"`
-// ZhipuMessage `json:"message"`
-// FinishReason string `json:"finish_reason"`
-// }
-type ZhipuV4Response struct {
- Id string `json:"id"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"`
- Usage dto.Usage `json:"usage"`
- Error dto.OpenAIError `json:"error"`
-}
-
-//
-//type ZhipuV4StreamResponseChoice struct {
-// Index int `json:"index,omitempty"`
-// Delta ZhipuMessage `json:"delta"`
-// FinishReason *string `json:"finish_reason,omitempty"`
-//}
-
-type ZhipuV4StreamResponse struct {
- Id string `json:"id"`
- Created int64 `json:"created"`
- Choices []dto.ChatCompletionsStreamResponseChoice `json:"choices"`
- Usage dto.Usage `json:"usage"`
-}
-
-type tokenData struct {
- Token string
- ExpiryTime time.Time
-}
diff --git a/new-api/relay/channel/zhipu_4v/relay-zhipu_v4.go b/new-api/relay/channel/zhipu_4v/relay-zhipu_v4.go
deleted file mode 100644
index baafeb6fda47fe1c6054a95d1e64da771900af78..0000000000000000000000000000000000000000
--- a/new-api/relay/channel/zhipu_4v/relay-zhipu_v4.go
+++ /dev/null
@@ -1,55 +0,0 @@
-package zhipu_4v
-
-import (
- "one-api/dto"
- "strings"
-)
-
-func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
- messages := make([]dto.Message, 0, len(request.Messages))
- for _, message := range request.Messages {
- if !message.IsStringContent() {
- mediaMessages := message.ParseContent()
- for j, mediaMessage := range mediaMessages {
- if mediaMessage.Type == dto.ContentTypeImageURL {
- imageUrl := mediaMessage.GetImageMedia()
- // check if base64
- if strings.HasPrefix(imageUrl.Url, "data:image/") {
- // 去除base64数据的URL前缀(如果有)
- if idx := strings.Index(imageUrl.Url, ","); idx != -1 {
- imageUrl.Url = imageUrl.Url[idx+1:]
- }
- }
- mediaMessage.ImageUrl = imageUrl
- mediaMessages[j] = mediaMessage
- }
- }
- message.SetMediaContent(mediaMessages)
- }
- messages = append(messages, dto.Message{
- Role: message.Role,
- Content: message.Content,
- ToolCalls: message.ToolCalls,
- ToolCallId: message.ToolCallId,
- })
- }
- str, ok := request.Stop.(string)
- var Stop []string
- if ok {
- Stop = []string{str}
- } else {
- Stop, _ = request.Stop.([]string)
- }
- return &dto.GeneralOpenAIRequest{
- Model: request.Model,
- Stream: request.Stream,
- Messages: messages,
- Temperature: request.Temperature,
- TopP: request.TopP,
- MaxTokens: request.GetMaxTokens(),
- Stop: Stop,
- Tools: request.Tools,
- ToolChoice: request.ToolChoice,
- THINKING: request.THINKING,
- }
-}
diff --git a/new-api/relay/claude_handler.go b/new-api/relay/claude_handler.go
deleted file mode 100644
index 05093bddd162d4ef6c25969d65cf44e806ac3dec..0000000000000000000000000000000000000000
--- a/new-api/relay/claude_handler.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
-
- info.InitChannelMeta(c)
-
- claudeReq, ok := info.Request.(*dto.ClaudeRequest)
-
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(claudeReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- if request.MaxTokens == 0 {
- request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
- }
-
- if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
- strings.HasSuffix(request.Model, "-thinking") {
- if request.Thinking == nil {
- // 因为BudgetTokens 必须大于1024
- if request.MaxTokens < 1280 {
- request.MaxTokens = 1280
- }
-
- // BudgetTokens 为 max_tokens 的 80%
- request.Thinking = &dto.Thinking{
- Type: "enabled",
- BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
- }
- // TODO: 临时处理
- // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
- request.TopP = 0
- request.Temperature = common.GetPointer[float64](1.0)
- }
- request.Model = strings.TrimSuffix(request.Model, "-thinking")
- info.UpstreamModelName = request.Model
- }
-
- if info.ChannelSetting.SystemPrompt != "" {
- if request.System == nil {
- request.SetStringSystem(info.ChannelSetting.SystemPrompt)
- } else if info.ChannelSetting.SystemPromptOverride {
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- if request.IsStringSystem() {
- existing := strings.TrimSpace(request.GetStringSystem())
- if existing == "" {
- request.SetStringSystem(info.ChannelSetting.SystemPrompt)
- } else {
- request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
- }
- } else {
- systemContents := request.ParseSystem()
- newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
- newSystem.SetText(info.ChannelSetting.SystemPrompt)
- if len(systemContents) == 0 {
- request.System = []dto.ClaudeMediaMessage{newSystem}
- } else {
- request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
- }
- }
- }
- }
-
- var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- if resp != nil {
- httpResp = resp.(*http.Response)
- info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- //log.Printf("usage: %v", usage)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
-
- service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
- return nil
-}
diff --git a/new-api/relay/common/override.go b/new-api/relay/common/override.go
deleted file mode 100644
index d8ed9bb5a08352249682c8be4bf450dc2ce65d18..0000000000000000000000000000000000000000
--- a/new-api/relay/common/override.go
+++ /dev/null
@@ -1,435 +0,0 @@
-package common
-
-import (
- "encoding/json"
- "fmt"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "regexp"
- "strconv"
- "strings"
-)
-
-type ConditionOperation struct {
- Path string `json:"path"` // JSON路径
- Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
- Value interface{} `json:"value"` // 匹配的值
- Invert bool `json:"invert"` // 反选功能,true表示取反结果
- PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
-}
-
-type ParamOperation struct {
- Path string `json:"path"`
- Mode string `json:"mode"` // delete, set, move, prepend, append
- Value interface{} `json:"value"`
- KeepOrigin bool `json:"keep_origin"`
- From string `json:"from,omitempty"`
- To string `json:"to,omitempty"`
- Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
- Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
-}
-
-func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
- if len(paramOverride) == 0 {
- return jsonData, nil
- }
-
- // 尝试断言为操作格式
- if operations, ok := tryParseOperations(paramOverride); ok {
- // 使用新方法
- result, err := applyOperations(string(jsonData), operations)
- return []byte(result), err
- }
-
- // 直接使用旧方法
- return applyOperationsLegacy(jsonData, paramOverride)
-}
-
-func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
- // 检查是否包含 "operations" 字段
- if opsValue, exists := paramOverride["operations"]; exists {
- if opsSlice, ok := opsValue.([]interface{}); ok {
- var operations []ParamOperation
- for _, op := range opsSlice {
- if opMap, ok := op.(map[string]interface{}); ok {
- operation := ParamOperation{}
-
- // 断言必要字段
- if path, ok := opMap["path"].(string); ok {
- operation.Path = path
- }
- if mode, ok := opMap["mode"].(string); ok {
- operation.Mode = mode
- } else {
- return nil, false // mode 是必需的
- }
-
- // 可选字段
- if value, exists := opMap["value"]; exists {
- operation.Value = value
- }
- if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
- operation.KeepOrigin = keepOrigin
- }
- if from, ok := opMap["from"].(string); ok {
- operation.From = from
- }
- if to, ok := opMap["to"].(string); ok {
- operation.To = to
- }
- if logic, ok := opMap["logic"].(string); ok {
- operation.Logic = logic
- } else {
- operation.Logic = "OR" // 默认为OR
- }
-
- // 解析条件
- if conditions, exists := opMap["conditions"]; exists {
- if condSlice, ok := conditions.([]interface{}); ok {
- for _, cond := range condSlice {
- if condMap, ok := cond.(map[string]interface{}); ok {
- condition := ConditionOperation{}
- if path, ok := condMap["path"].(string); ok {
- condition.Path = path
- }
- if mode, ok := condMap["mode"].(string); ok {
- condition.Mode = mode
- }
- if value, ok := condMap["value"]; ok {
- condition.Value = value
- }
- if invert, ok := condMap["invert"].(bool); ok {
- condition.Invert = invert
- }
- if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
- condition.PassMissingKey = passMissingKey
- }
- operation.Conditions = append(operation.Conditions, condition)
- }
- }
- }
- }
-
- operations = append(operations, operation)
- } else {
- return nil, false
- }
- }
- return operations, true
- }
- }
-
- return nil, false
-}
-
-func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
- if len(conditions) == 0 {
- return true, nil // 没有条件,直接通过
- }
- results := make([]bool, len(conditions))
- for i, condition := range conditions {
- result, err := checkSingleCondition(jsonStr, condition)
- if err != nil {
- return false, err
- }
- results[i] = result
- }
-
- if strings.ToUpper(logic) == "AND" {
- for _, result := range results {
- if !result {
- return false, nil
- }
- }
- return true, nil
- } else {
- for _, result := range results {
- if result {
- return true, nil
- }
- }
- return false, nil
- }
-}
-
-func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
- // 处理负数索引
- path := processNegativeIndex(jsonStr, condition.Path)
- value := gjson.Get(jsonStr, path)
- if !value.Exists() {
- if condition.PassMissingKey {
- return true, nil
- }
- return false, nil
- }
-
- // 利用gjson的类型解析
- targetBytes, err := json.Marshal(condition.Value)
- if err != nil {
- return false, fmt.Errorf("failed to marshal condition value: %v", err)
- }
- targetValue := gjson.ParseBytes(targetBytes)
-
- result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
- if err != nil {
- return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
- }
-
- if condition.Invert {
- result = !result
- }
- return result, nil
-}
-
-func processNegativeIndex(jsonStr string, path string) string {
- re := regexp.MustCompile(`\.(-\d+)`)
- matches := re.FindAllStringSubmatch(path, -1)
-
- if len(matches) == 0 {
- return path
- }
-
- result := path
- for _, match := range matches {
- negIndex := match[1]
- index, _ := strconv.Atoi(negIndex)
-
- arrayPath := strings.Split(path, negIndex)[0]
- if strings.HasSuffix(arrayPath, ".") {
- arrayPath = arrayPath[:len(arrayPath)-1]
- }
-
- array := gjson.Get(jsonStr, arrayPath)
- if array.IsArray() {
- length := len(array.Array())
- actualIndex := length + index
- if actualIndex >= 0 && actualIndex < length {
- result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
- }
- }
- }
-
- return result
-}
-
-// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
-func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
- switch mode {
- case "full":
- return compareEqual(jsonValue, targetValue)
- case "prefix":
- return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
- case "suffix":
- return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
- case "contains":
- return strings.Contains(jsonValue.String(), targetValue.String()), nil
- case "gt":
- return compareNumeric(jsonValue, targetValue, "gt")
- case "gte":
- return compareNumeric(jsonValue, targetValue, "gte")
- case "lt":
- return compareNumeric(jsonValue, targetValue, "lt")
- case "lte":
- return compareNumeric(jsonValue, targetValue, "lte")
- default:
- return false, fmt.Errorf("unsupported comparison mode: %s", mode)
- }
-}
-
-func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
- // 对布尔值特殊处理
- if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
- (targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
- return jsonValue.Bool() == targetValue.Bool(), nil
- }
-
- // 如果类型不同,报错
- if jsonValue.Type != targetValue.Type {
- return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
- }
-
- switch jsonValue.Type {
- case gjson.True, gjson.False:
- return jsonValue.Bool() == targetValue.Bool(), nil
- case gjson.Number:
- return jsonValue.Num == targetValue.Num, nil
- case gjson.String:
- return jsonValue.String() == targetValue.String(), nil
- default:
- return jsonValue.String() == targetValue.String(), nil
- }
-}
-
-func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
- // 只有数字类型才支持数值比较
- if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
- return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
- }
-
- jsonNum := jsonValue.Num
- targetNum := targetValue.Num
-
- switch operator {
- case "gt":
- return jsonNum > targetNum, nil
- case "gte":
- return jsonNum >= targetNum, nil
- case "lt":
- return jsonNum < targetNum, nil
- case "lte":
- return jsonNum <= targetNum, nil
- default:
- return false, fmt.Errorf("unsupported numeric operator: %s", operator)
- }
-}
-
-// applyOperationsLegacy 原参数覆盖方法
-func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
- reqMap := make(map[string]interface{})
- err := json.Unmarshal(jsonData, &reqMap)
- if err != nil {
- return nil, err
- }
-
- for key, value := range paramOverride {
- reqMap[key] = value
- }
-
- return json.Marshal(reqMap)
-}
-
-func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
- result := jsonStr
- for _, op := range operations {
- // 检查条件是否满足
- ok, err := checkConditions(result, op.Conditions, op.Logic)
- if err != nil {
- return "", err
- }
- if !ok {
- continue // 条件不满足,跳过当前操作
- }
- // 处理路径中的负数索引
- opPath := processNegativeIndex(result, op.Path)
- opFrom := processNegativeIndex(result, op.From)
- opTo := processNegativeIndex(result, op.To)
-
- switch op.Mode {
- case "delete":
- result, err = sjson.Delete(result, opPath)
- case "set":
- if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
- continue
- }
- result, err = sjson.Set(result, opPath, op.Value)
- case "move":
- result, err = moveValue(result, opFrom, opTo)
- case "prepend":
- result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
- case "append":
- result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
- default:
- return "", fmt.Errorf("unknown operation: %s", op.Mode)
- }
- if err != nil {
- return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
- }
- }
- return result, nil
-}
-
-func moveValue(jsonStr, fromPath, toPath string) (string, error) {
- sourceValue := gjson.Get(jsonStr, fromPath)
- if !sourceValue.Exists() {
- return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
- }
- result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
- if err != nil {
- return "", err
- }
- return sjson.Delete(result, fromPath)
-}
-
-func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
- current := gjson.Get(jsonStr, path)
- switch {
- case current.IsArray():
- return modifyArray(jsonStr, path, value, isPrepend)
- case current.Type == gjson.String:
- return modifyString(jsonStr, path, value, isPrepend)
- case current.Type == gjson.JSON:
- return mergeObjects(jsonStr, path, value, keepOrigin)
- }
- return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
-}
-
-func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
- current := gjson.Get(jsonStr, path)
- var newArray []interface{}
- // 添加新值
- addValue := func() {
- if arr, ok := value.([]interface{}); ok {
- newArray = append(newArray, arr...)
- } else {
- newArray = append(newArray, value)
- }
- }
- // 添加原值
- addOriginal := func() {
- current.ForEach(func(_, val gjson.Result) bool {
- newArray = append(newArray, val.Value())
- return true
- })
- }
- if isPrepend {
- addValue()
- addOriginal()
- } else {
- addOriginal()
- addValue()
- }
- return sjson.Set(jsonStr, path, newArray)
-}
-
-func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
- current := gjson.Get(jsonStr, path)
- valueStr := fmt.Sprintf("%v", value)
- var newStr string
- if isPrepend {
- newStr = valueStr + current.String()
- } else {
- newStr = current.String() + valueStr
- }
- return sjson.Set(jsonStr, path, newStr)
-}
-
-func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
- current := gjson.Get(jsonStr, path)
- var currentMap, newMap map[string]interface{}
-
- // 解析当前值
- if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
- return "", err
- }
- // 解析新值
- switch v := value.(type) {
- case map[string]interface{}:
- newMap = v
- default:
- jsonBytes, _ := json.Marshal(v)
- if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
- return "", err
- }
- }
- // 合并
- result := make(map[string]interface{})
- for k, v := range currentMap {
- result[k] = v
- }
- for k, v := range newMap {
- if !keepOrigin || result[k] == nil {
- result[k] = v
- }
- }
- return sjson.Set(jsonStr, path, result)
-}
diff --git a/new-api/relay/common/relay_info.go b/new-api/relay/common/relay_info.go
deleted file mode 100644
index 52162ad64668c40e44758edd29d721fc4260cfe7..0000000000000000000000000000000000000000
--- a/new-api/relay/common/relay_info.go
+++ /dev/null
@@ -1,509 +0,0 @@
-package common
-
-import (
- "errors"
- "fmt"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relayconstant "one-api/relay/constant"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-type ThinkingContentInfo struct {
- IsFirstThinkingContent bool
- SendLastThinkingContent bool
- HasSentThinkingContent bool
-}
-
-const (
- LastMessageTypeNone = "none"
- LastMessageTypeText = "text"
- LastMessageTypeTools = "tools"
- LastMessageTypeThinking = "thinking"
-)
-
-type ClaudeConvertInfo struct {
- LastMessagesType string
- Index int
- Usage *dto.Usage
- FinishReason string
- Done bool
-}
-
-type RerankerInfo struct {
- Documents []any
- ReturnDocuments bool
-}
-
-type BuildInToolInfo struct {
- ToolName string
- CallCount int
- SearchContextSize string
-}
-
-type ResponsesUsageInfo struct {
- BuiltInTools map[string]*BuildInToolInfo
-}
-
-type ChannelMeta struct {
- ChannelType int
- ChannelId int
- ChannelIsMultiKey bool
- ChannelMultiKeyIndex int
- ChannelBaseUrl string
- ApiType int
- ApiVersion string
- ApiKey string
- Organization string
- ChannelCreateTime int64
- ParamOverride map[string]interface{}
- HeadersOverride map[string]interface{}
- ChannelSetting dto.ChannelSettings
- ChannelOtherSettings dto.ChannelOtherSettings
- UpstreamModelName string
- IsModelMapped bool
- SupportStreamOptions bool // 是否支持流式选项
-}
-
-type RelayInfo struct {
- TokenId int
- TokenKey string
- UserId int
- UsingGroup string // 使用的分组
- UserGroup string // 用户所在分组
- TokenUnlimited bool
- StartTime time.Time
- FirstResponseTime time.Time
- isFirstResponse bool
- //SendLastReasoningResponse bool
- IsStream bool
- IsGeminiBatchEmbedding bool
- IsPlayground bool
- UsePrice bool
- RelayMode int
- OriginModelName string
- RequestURLPath string
- PromptTokens int
- ShouldIncludeUsage bool
- DisablePing bool // 是否禁止向下游发送自定义 Ping
- ClientWs *websocket.Conn
- TargetWs *websocket.Conn
- InputAudioFormat string
- OutputAudioFormat string
- RealtimeTools []dto.RealTimeTool
- IsFirstRequest bool
- AudioUsage bool
- ReasoningEffort string
- UserSetting dto.UserSetting
- UserEmail string
- UserQuota int
- RelayFormat types.RelayFormat
- SendResponseCount int
- FinalPreConsumedQuota int // 最终预消耗的配额
- IsClaudeBetaQuery bool // /v1/messages?beta=true
-
- PriceData types.PriceData
-
- Request dto.Request
-
- ThinkingContentInfo
- *ClaudeConvertInfo
- *RerankerInfo
- *ResponsesUsageInfo
- *ChannelMeta
- *TaskRelayInfo
-}
-
-func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
- channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
- paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
- headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
- apiType, _ := common.ChannelType2APIType(channelType)
- channelMeta := &ChannelMeta{
- ChannelType: channelType,
- ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
- ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
- ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
- ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
- ApiType: apiType,
- ApiVersion: c.GetString("api_version"),
- ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
- Organization: c.GetString("channel_organization"),
- ChannelCreateTime: c.GetInt64("channel_create_time"),
- ParamOverride: paramOverride,
- HeadersOverride: headerOverride,
- UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
- IsModelMapped: false,
- SupportStreamOptions: false,
- }
-
- if channelType == constant.ChannelTypeAzure {
- channelMeta.ApiVersion = GetAPIVersion(c)
- }
- if channelType == constant.ChannelTypeVertexAi {
- channelMeta.ApiVersion = c.GetString("region")
- }
-
- channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
- if ok {
- channelMeta.ChannelSetting = channelSetting
- }
-
- channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
- if ok {
- channelMeta.ChannelOtherSettings = channelOtherSettings
- }
-
- if streamSupportedChannels[channelMeta.ChannelType] {
- channelMeta.SupportStreamOptions = true
- }
-
- info.ChannelMeta = channelMeta
-
- // reset some fields based on channel meta
- // 重置某些字段,例如模型名称等
- if info.Request != nil {
- info.Request.SetModelName(info.OriginModelName)
- }
-}
-
-func (info *RelayInfo) ToString() string {
- if info == nil {
- return "RelayInfo"
- }
-
- // Basic info
- b := &strings.Builder{}
- fmt.Fprintf(b, "RelayInfo{ ")
- fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
- fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
- fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
- fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
- fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
- fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
- fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
- fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
- fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
- fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
- fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
-
- // User & token info (mask secrets)
- fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
- info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
- fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
-
- // Time info
- latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
- fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
- info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
-
- // Audio / realtime
- if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
- fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
- info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
- }
-
- // Reasoning
- if info.ReasoningEffort != "" {
- fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
- }
-
- // Price data (non-sensitive)
- if info.PriceData.UsePrice {
- fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
- }
-
- // Channel metadata (mask ApiKey)
- if info.ChannelMeta != nil {
- cm := info.ChannelMeta
- fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
- cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
- }
-
- // Responses usage info (non-sensitive)
- if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
- fmt.Fprintf(b, "ResponsesTools{ ")
- first := true
- for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
- if !first {
- fmt.Fprintf(b, ", ")
- }
- first = false
- if tool != nil {
- fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
- } else {
- fmt.Fprintf(b, "%s: calls=0", name)
- }
- }
- fmt.Fprintf(b, " }, ")
- }
-
- fmt.Fprintf(b, "}")
- return b.String()
-}
-
-// 定义支持流式选项的通道类型
-var streamSupportedChannels = map[int]bool{
- constant.ChannelTypeOpenAI: true,
- constant.ChannelTypeAnthropic: true,
- constant.ChannelTypeAws: true,
- constant.ChannelTypeGemini: true,
- constant.ChannelCloudflare: true,
- constant.ChannelTypeAzure: true,
- constant.ChannelTypeVolcEngine: true,
- constant.ChannelTypeOllama: true,
- constant.ChannelTypeXai: true,
- constant.ChannelTypeDeepSeek: true,
- constant.ChannelTypeBaiduV2: true,
-}
-
-func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
- info := genBaseRelayInfo(c, nil)
- info.RelayFormat = types.RelayFormatOpenAIRealtime
- info.ClientWs = ws
- info.InputAudioFormat = "pcm16"
- info.OutputAudioFormat = "pcm16"
- info.IsFirstRequest = true
- return info
-}
-
-func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatClaude
- info.ShouldIncludeUsage = false
- info.ClaudeConvertInfo = &ClaudeConvertInfo{
- LastMessagesType: LastMessageTypeNone,
- }
- if c.Query("beta") == "true" {
- info.IsClaudeBetaQuery = true
- }
- return info
-}
-
-func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayMode = relayconstant.RelayModeRerank
- info.RelayFormat = types.RelayFormatRerank
- info.RerankerInfo = &RerankerInfo{
- Documents: request.Documents,
- ReturnDocuments: request.GetReturnDocuments(),
- }
- return info
-}
-
-func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatOpenAIAudio
- return info
-}
-
-func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatEmbedding
- return info
-}
-
-func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayMode = relayconstant.RelayModeResponses
- info.RelayFormat = types.RelayFormatOpenAIResponses
-
- info.ResponsesUsageInfo = &ResponsesUsageInfo{
- BuiltInTools: make(map[string]*BuildInToolInfo),
- }
- if len(request.Tools) > 0 {
- for _, tool := range request.GetToolsMap() {
- toolType := common.Interface2String(tool["type"])
- info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
- ToolName: toolType,
- CallCount: 0,
- }
- switch toolType {
- case dto.BuildInToolWebSearchPreview:
- searchContextSize := common.Interface2String(tool["search_context_size"])
- if searchContextSize == "" {
- searchContextSize = "medium"
- }
- info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
- }
- }
- }
- return info
-}
-
-func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatGemini
- info.ShouldIncludeUsage = false
-
- return info
-}
-
-func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatOpenAIImage
- return info
-}
-
-func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
- info := genBaseRelayInfo(c, request)
- info.RelayFormat = types.RelayFormatOpenAI
- return info
-}
-
-func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
-
- //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
- //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
- //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
-
- startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
- if startTime.IsZero() {
- startTime = time.Now()
- }
-
- isStream := false
-
- if request != nil {
- isStream = request.IsStream(c)
- }
-
- // firstResponseTime = time.Now() - 1 second
-
- info := &RelayInfo{
- Request: request,
-
- UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
- UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
- UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
- UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
- UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
-
- OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
- PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
-
- TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
- TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
- TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
-
- isFirstResponse: true,
- RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
- RequestURLPath: c.Request.URL.String(),
- IsStream: isStream,
-
- StartTime: startTime,
- FirstResponseTime: startTime.Add(-time.Second),
- ThinkingContentInfo: ThinkingContentInfo{
- IsFirstThinkingContent: true,
- SendLastThinkingContent: false,
- },
- }
-
- if info.RelayMode == relayconstant.RelayModeUnknown {
- info.RelayMode = c.GetInt("relay_mode")
- }
-
- if strings.HasPrefix(c.Request.URL.Path, "/pg") {
- info.IsPlayground = true
- info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
- info.RequestURLPath = "/v1" + info.RequestURLPath
- }
-
- userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
- if ok {
- info.UserSetting = userSetting
- }
-
- return info
-}
-
-func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
- switch relayFormat {
- case types.RelayFormatOpenAI:
- return GenRelayInfoOpenAI(c, request), nil
- case types.RelayFormatOpenAIAudio:
- return GenRelayInfoOpenAIAudio(c, request), nil
- case types.RelayFormatOpenAIImage:
- return GenRelayInfoImage(c, request), nil
- case types.RelayFormatOpenAIRealtime:
- return GenRelayInfoWs(c, ws), nil
- case types.RelayFormatClaude:
- return GenRelayInfoClaude(c, request), nil
- case types.RelayFormatRerank:
- if request, ok := request.(*dto.RerankRequest); ok {
- return GenRelayInfoRerank(c, request), nil
- }
- return nil, errors.New("request is not a RerankRequest")
- case types.RelayFormatGemini:
- return GenRelayInfoGemini(c, request), nil
- case types.RelayFormatEmbedding:
- return GenRelayInfoEmbedding(c, request), nil
- case types.RelayFormatOpenAIResponses:
- if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
- return GenRelayInfoResponses(c, request), nil
- }
- return nil, errors.New("request is not a OpenAIResponsesRequest")
- case types.RelayFormatTask:
- return genBaseRelayInfo(c, nil), nil
- case types.RelayFormatMjProxy:
- return genBaseRelayInfo(c, nil), nil
- default:
- return nil, errors.New("invalid relay format")
- }
-}
-
-func (info *RelayInfo) SetPromptTokens(promptTokens int) {
- info.PromptTokens = promptTokens
-}
-
-func (info *RelayInfo) SetFirstResponseTime() {
- if info.isFirstResponse {
- info.FirstResponseTime = time.Now()
- info.isFirstResponse = false
- }
-}
-
-func (info *RelayInfo) HasSendResponse() bool {
- return info.FirstResponseTime.After(info.StartTime)
-}
-
-type TaskRelayInfo struct {
- Action string
- OriginTaskID string
-
- ConsumeQuota bool
-}
-
-type TaskSubmitReq struct {
- Prompt string `json:"prompt"`
- Model string `json:"model,omitempty"`
- Mode string `json:"mode,omitempty"`
- Image string `json:"image,omitempty"`
- Images []string `json:"images,omitempty"`
- Size string `json:"size,omitempty"`
- Duration int `json:"duration,omitempty"`
- Metadata map[string]interface{} `json:"metadata,omitempty"`
-}
-
-func (t TaskSubmitReq) GetPrompt() string {
- return t.Prompt
-}
-
-func (t TaskSubmitReq) HasImage() bool {
- return len(t.Images) > 0
-}
-
-type TaskInfo struct {
- Code int `json:"code"`
- TaskID string `json:"task_id"`
- Status string `json:"status"`
- Reason string `json:"reason,omitempty"`
- Url string `json:"url,omitempty"`
- Progress string `json:"progress,omitempty"`
-}
diff --git a/new-api/relay/common/relay_utils.go b/new-api/relay/common/relay_utils.go
deleted file mode 100644
index 3b27242a27510d42fe1673ea9b30d408a7abb7c4..0000000000000000000000000000000000000000
--- a/new-api/relay/common/relay_utils.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package common
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-type HasPrompt interface {
- GetPrompt() string
-}
-
-type HasImage interface {
- HasImage() bool
-}
-
-func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
- switch channelType {
- case constant.ChannelTypeOpenAI:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- case constant.ChannelTypeAzure:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
- }
- }
- return fullRequestURL
-}
-
-func GetAPIVersion(c *gin.Context) string {
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString("api_version")
- }
- return apiVersion
-}
-
-func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
- return &dto.TaskError{
- Code: code,
- Message: err.Error(),
- StatusCode: statusCode,
- LocalError: localError,
- Error: err,
- }
-}
-
-func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
- info.Action = action
- c.Set("task_request", requestObj)
-}
-
-func validatePrompt(prompt string) *dto.TaskError {
- if strings.TrimSpace(prompt) == "" {
- return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
- }
- return nil
-}
-
-func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
- var req TaskSubmitReq
- if err := common.UnmarshalBodyReusable(c, &req); err != nil {
- return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
- }
-
- if taskErr := validatePrompt(req.Prompt); taskErr != nil {
- return taskErr
- }
-
- if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
- // 兼容单图上传
- req.Images = []string{req.Image}
- }
-
- if req.HasImage() {
- action = constant.TaskActionGenerate
- if info.ChannelType == constant.ChannelTypeVidu {
- // vidu 增加 首尾帧生视频和参考图生视频
- if len(req.Images) == 2 {
- action = constant.TaskActionFirstTailGenerate
- } else if len(req.Images) > 2 {
- action = constant.TaskActionReferenceGenerate
- }
- }
- }
-
- storeTaskRequest(c, info, action, req)
- return nil
-}
diff --git a/new-api/relay/common_handler/rerank.go b/new-api/relay/common_handler/rerank.go
deleted file mode 100644
index 8c13e2cee07796952aab2f67dddf707222f64287..0000000000000000000000000000000000000000
--- a/new-api/relay/common_handler/rerank.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package common_handler
-
-import (
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel/xinference"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- if common.DebugEnabled {
- println("reranker response body: ", string(responseBody))
- }
- var jinaResp dto.RerankResponse
- if info.ChannelType == constant.ChannelTypeXinference {
- var xinRerankResponse xinference.XinRerankResponse
- err = common.Unmarshal(responseBody, &xinRerankResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
- for i, result := range xinRerankResponse.Results {
- respResult := dto.RerankResponseResult{
- Index: result.Index,
- RelevanceScore: result.RelevanceScore,
- }
- if info.ReturnDocuments {
- var document any
- if result.Document != nil {
- if doc, ok := result.Document.(string); ok {
- if doc == "" {
- document = info.Documents[result.Index]
- } else {
- document = doc
- }
- } else {
- document = result.Document
- }
- }
- respResult.Document = document
- }
- jinaRespResults[i] = respResult
- }
- jinaResp = dto.RerankResponse{
- Results: jinaRespResults,
- Usage: dto.Usage{
- PromptTokens: info.PromptTokens,
- TotalTokens: info.PromptTokens,
- },
- }
- } else {
- err = common.Unmarshal(responseBody, &jinaResp)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
- c.JSON(http.StatusOK, jinaResp)
- return &jinaResp.Usage, nil
-}
diff --git a/new-api/relay/compatible_handler.go b/new-api/relay/compatible_handler.go
deleted file mode 100644
index 846d818b269db9bc4044c621d7e4da1bfea38630..0000000000000000000000000000000000000000
--- a/new-api/relay/compatible_handler.go
+++ /dev/null
@@ -1,461 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/setting/operation_setting"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/shopspring/decimal"
-
- "github.com/gin-gonic/gin"
-)
-
-func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(textReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- if request.WebSearchOptions != nil {
- c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- includeUsage := true
- // 判断用户是否需要返回使用情况
- if request.StreamOptions != nil {
- includeUsage = request.StreamOptions.IncludeUsage
- }
-
- // 如果不支持StreamOptions,将StreamOptions设置为nil
- if !info.SupportStreamOptions || !request.Stream {
- request.StreamOptions = nil
- } else {
- // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
- if constant.ForceStreamOption {
- request.StreamOptions = &dto.StreamOptions{
- IncludeUsage: true,
- }
- }
- }
-
- info.ShouldIncludeUsage = includeUsage
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
- var requestBody io.Reader
-
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- if common.DebugEnabled {
- println("requestBody: ", string(body))
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- if info.ChannelSetting.SystemPrompt != "" {
- // 如果有系统提示,则将其添加到请求中
- request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
- if ok {
- containSystemPrompt := false
- for _, message := range request.Messages {
- if message.Role == request.GetSystemRoleName() {
- containSystemPrompt = true
- break
- }
- }
- if !containSystemPrompt {
- // 如果没有系统提示,则添加系统提示
- systemMessage := dto.Message{
- Role: request.GetSystemRoleName(),
- Content: info.ChannelSetting.SystemPrompt,
- }
- request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
- } else if info.ChannelSetting.SystemPromptOverride {
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- // 如果有系统提示,且允许覆盖,则拼接到前面
- for i, message := range request.Messages {
- if message.Role == request.GetSystemRoleName() {
- if message.IsStringContent() {
- request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
- } else {
- contents := message.ParseContent()
- contents = append([]dto.MediaContent{
- {
- Type: dto.ContentTypeText,
- Text: info.ChannelSetting.SystemPrompt,
- },
- }, contents...)
- request.Messages[i].Content = contents
- }
- break
- }
- }
- }
- }
- }
-
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
-
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- if resp != nil {
- httpResp = resp.(*http.Response)
- info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newApiErr, statusCodeMappingStr)
- return newApiErr
- }
- }
-
- usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
- if newApiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newApiErr, statusCodeMappingStr)
- return newApiErr
- }
-
- if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
- } else {
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- }
- return nil
-}
-
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
- if usage == nil {
- usage = &dto.Usage{
- PromptTokens: relayInfo.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: relayInfo.PromptTokens,
- }
- extraContent += "(可能是请求出错)"
- }
- useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
- promptTokens := usage.PromptTokens
- cacheTokens := usage.PromptTokensDetails.CachedTokens
- imageTokens := usage.PromptTokensDetails.ImageTokens
- audioTokens := usage.PromptTokensDetails.AudioTokens
- completionTokens := usage.CompletionTokens
- cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
-
- modelName := relayInfo.OriginModelName
-
- tokenName := ctx.GetString("token_name")
- completionRatio := relayInfo.PriceData.CompletionRatio
- cacheRatio := relayInfo.PriceData.CacheRatio
- imageRatio := relayInfo.PriceData.ImageRatio
- modelRatio := relayInfo.PriceData.ModelRatio
- groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
- modelPrice := relayInfo.PriceData.ModelPrice
- cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
-
- // Convert values to decimal for precise calculation
- dPromptTokens := decimal.NewFromInt(int64(promptTokens))
- dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
- dImageTokens := decimal.NewFromInt(int64(imageTokens))
- dAudioTokens := decimal.NewFromInt(int64(audioTokens))
- dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
- dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
- dCompletionRatio := decimal.NewFromFloat(completionRatio)
- dCacheRatio := decimal.NewFromFloat(cacheRatio)
- dImageRatio := decimal.NewFromFloat(imageRatio)
- dModelRatio := decimal.NewFromFloat(modelRatio)
- dGroupRatio := decimal.NewFromFloat(groupRatio)
- dModelPrice := decimal.NewFromFloat(modelPrice)
- dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
-
- ratio := dModelRatio.Mul(dGroupRatio)
-
- // openai web search 工具计费
- var dWebSearchQuota decimal.Decimal
- var webSearchPrice float64
- // response api 格式工具计费
- if relayInfo.ResponsesUsageInfo != nil {
- if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
- // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
- webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
- dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
- Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
- webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
- }
- } else if strings.HasSuffix(modelName, "search-preview") {
- // search-preview 模型不支持 response api
- searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
- if searchContextSize == "" {
- searchContextSize = "medium"
- }
- webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
- dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
- searchContextSize, dWebSearchQuota.String())
- }
- // claude web search tool 计费
- var dClaudeWebSearchQuota decimal.Decimal
- var claudeWebSearchPrice float64
- claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
- if claudeWebSearchCallCount > 0 {
- claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
- dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
- extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
- claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
- }
- // file search tool 计费
- var dFileSearchQuota decimal.Decimal
- var fileSearchPrice float64
- if relayInfo.ResponsesUsageInfo != nil {
- if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
- fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
- dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
- Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
- fileSearchTool.CallCount, dFileSearchQuota.String())
- }
- }
- var dImageGenerationCallQuota decimal.Decimal
- var imageGenerationCallPrice float64
- if ctx.GetBool("image_generation_call") {
- imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
- dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
- }
-
- var quotaCalculateDecimal decimal.Decimal
-
- var audioInputQuota decimal.Decimal
- var audioInputPrice float64
- if !relayInfo.PriceData.UsePrice {
- baseTokens := dPromptTokens
- // 减去 cached tokens
- var cachedTokensWithRatio decimal.Decimal
- if !dCacheTokens.IsZero() {
- baseTokens = baseTokens.Sub(dCacheTokens)
- cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
- }
- var dCachedCreationTokensWithRatio decimal.Decimal
- if !dCachedCreationTokens.IsZero() {
- baseTokens = baseTokens.Sub(dCachedCreationTokens)
- dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
- }
-
- // 减去 image tokens
- var imageTokensWithRatio decimal.Decimal
- if !dImageTokens.IsZero() {
- baseTokens = baseTokens.Sub(dImageTokens)
- imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
- }
-
- // 减去 Gemini audio tokens
- if !dAudioTokens.IsZero() {
- audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
- if audioInputPrice > 0 {
- // 重新计算 base tokens
- baseTokens = baseTokens.Sub(dAudioTokens)
- audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
- }
- }
- promptQuota := baseTokens.Add(cachedTokensWithRatio).
- Add(imageTokensWithRatio).
- Add(dCachedCreationTokensWithRatio)
-
- completionQuota := dCompletionTokens.Mul(dCompletionRatio)
-
- quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
-
- if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
- quotaCalculateDecimal = decimal.NewFromInt(1)
- }
- } else {
- quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
- }
- // 添加 responses tools call 调用的配额
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
- // 添加 audio input 独立计费
- quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
- // 添加 image generation call 计费
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
-
- quota := int(quotaCalculateDecimal.Round(0).IntPart())
- totalTokens := promptTokens + completionTokens
-
- var logContent string
-
- // record all the consume log even if quota is 0
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- logContent += fmt.Sprintf("(可能是上游超时)")
- logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
- } else {
- if !ratio.IsZero() && quota == 0 {
- quota = 1
- }
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
-
- quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
- //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta)))
-
- if quotaDelta > 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- } else if quotaDelta < 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(-quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- }
-
- if quotaDelta != 0 {
- err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
- if err != nil {
- logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- }
-
- logModel := modelName
- if strings.HasPrefix(logModel, "gpt-4-gizmo") {
- logModel = "gpt-4-gizmo-*"
- logContent += fmt.Sprintf(",模型 %s", modelName)
- }
- if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
- logModel = "gpt-4o-gizmo-*"
- logContent += fmt.Sprintf(",模型 %s", modelName)
- }
- if extraContent != "" {
- logContent += ", " + extraContent
- }
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
- if imageTokens != 0 {
- other["image"] = true
- other["image_ratio"] = imageRatio
- other["image_output"] = imageTokens
- }
- if cachedCreationTokens != 0 {
- other["cache_creation_tokens"] = cachedCreationTokens
- other["cache_creation_ratio"] = cachedCreationRatio
- }
- if !dWebSearchQuota.IsZero() {
- if relayInfo.ResponsesUsageInfo != nil {
- if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
- other["web_search"] = true
- other["web_search_call_count"] = webSearchTool.CallCount
- other["web_search_price"] = webSearchPrice
- }
- } else if strings.HasSuffix(modelName, "search-preview") {
- other["web_search"] = true
- other["web_search_call_count"] = 1
- other["web_search_price"] = webSearchPrice
- }
- } else if !dClaudeWebSearchQuota.IsZero() {
- other["web_search"] = true
- other["web_search_call_count"] = claudeWebSearchCallCount
- other["web_search_price"] = claudeWebSearchPrice
- }
- if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
- if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
- other["file_search"] = true
- other["file_search_call_count"] = fileSearchTool.CallCount
- other["file_search_price"] = fileSearchPrice
- }
- }
- if !audioInputQuota.IsZero() {
- other["audio_input_seperate_price"] = true
- other["audio_input_token_count"] = audioTokens
- other["audio_input_price"] = audioInputPrice
- }
- if !dImageGenerationCallQuota.IsZero() {
- other["image_generation_call"] = true
- other["image_generation_call_price"] = imageGenerationCallPrice
- }
- model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- PromptTokens: promptTokens,
- CompletionTokens: completionTokens,
- ModelName: logModel,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: relayInfo.TokenId,
- UseTimeSeconds: int(useTimeSeconds),
- IsStream: relayInfo.IsStream,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
-}
diff --git a/new-api/relay/constant/relay_mode.go b/new-api/relay/constant/relay_mode.go
deleted file mode 100644
index 503f0fcc12bf246859d01dfac7aefef6d7ebf6a6..0000000000000000000000000000000000000000
--- a/new-api/relay/constant/relay_mode.go
+++ /dev/null
@@ -1,146 +0,0 @@
-package constant
-
-import (
- "net/http"
- "strings"
-)
-
-const (
- RelayModeUnknown = iota
- RelayModeChatCompletions
- RelayModeCompletions
- RelayModeEmbeddings
- RelayModeModerations
- RelayModeImagesGenerations
- RelayModeImagesEdits
- RelayModeEdits
-
- RelayModeMidjourneyImagine
- RelayModeMidjourneyDescribe
- RelayModeMidjourneyBlend
- RelayModeMidjourneyChange
- RelayModeMidjourneySimpleChange
- RelayModeMidjourneyNotify
- RelayModeMidjourneyTaskFetch
- RelayModeMidjourneyTaskImageSeed
- RelayModeMidjourneyTaskFetchByCondition
- RelayModeMidjourneyAction
- RelayModeMidjourneyModal
- RelayModeMidjourneyShorten
- RelayModeSwapFace
- RelayModeMidjourneyUpload
- RelayModeMidjourneyVideo
- RelayModeMidjourneyEdits
-
- RelayModeAudioSpeech // tts
- RelayModeAudioTranscription // whisper
- RelayModeAudioTranslation // whisper
-
- RelayModeSunoFetch
- RelayModeSunoFetchByID
- RelayModeSunoSubmit
-
- RelayModeVideoFetchByID
- RelayModeVideoSubmit
-
- RelayModeRerank
-
- RelayModeResponses
-
- RelayModeRealtime
-
- RelayModeGemini
-)
-
-func Path2RelayMode(path string) int {
- relayMode := RelayModeUnknown
- if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
- relayMode = RelayModeChatCompletions
- } else if strings.HasPrefix(path, "/v1/completions") {
- relayMode = RelayModeCompletions
- } else if strings.HasPrefix(path, "/v1/embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasSuffix(path, "embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasPrefix(path, "/v1/moderations") {
- relayMode = RelayModeModerations
- } else if strings.HasPrefix(path, "/v1/images/generations") {
- relayMode = RelayModeImagesGenerations
- } else if strings.HasPrefix(path, "/v1/images/edits") {
- relayMode = RelayModeImagesEdits
- } else if strings.HasPrefix(path, "/v1/edits") {
- relayMode = RelayModeEdits
- } else if strings.HasPrefix(path, "/v1/responses") {
- relayMode = RelayModeResponses
- } else if strings.HasPrefix(path, "/v1/audio/speech") {
- relayMode = RelayModeAudioSpeech
- } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
- relayMode = RelayModeAudioTranscription
- } else if strings.HasPrefix(path, "/v1/audio/translations") {
- relayMode = RelayModeAudioTranslation
- } else if strings.HasPrefix(path, "/v1/rerank") {
- relayMode = RelayModeRerank
- } else if strings.HasPrefix(path, "/v1/realtime") {
- relayMode = RelayModeRealtime
- } else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
- relayMode = RelayModeGemini
- } else if strings.HasPrefix(path, "/mj") {
- relayMode = Path2RelayModeMidjourney(path)
- }
- return relayMode
-}
-
-func Path2RelayModeMidjourney(path string) int {
- relayMode := RelayModeUnknown
- if strings.HasSuffix(path, "/mj/submit/action") {
- // midjourney plus
- relayMode = RelayModeMidjourneyAction
- } else if strings.HasSuffix(path, "/mj/submit/modal") {
- // midjourney plus
- relayMode = RelayModeMidjourneyModal
- } else if strings.HasSuffix(path, "/mj/submit/shorten") {
- // midjourney plus
- relayMode = RelayModeMidjourneyShorten
- } else if strings.HasSuffix(path, "/mj/insight-face/swap") {
- // midjourney plus
- relayMode = RelayModeSwapFace
- } else if strings.HasSuffix(path, "/submit/upload-discord-images") {
- // midjourney plus
- relayMode = RelayModeMidjourneyUpload
- } else if strings.HasSuffix(path, "/mj/submit/imagine") {
- relayMode = RelayModeMidjourneyImagine
- } else if strings.HasSuffix(path, "/mj/submit/video") {
- relayMode = RelayModeMidjourneyVideo
- } else if strings.HasSuffix(path, "/mj/submit/edits") {
- relayMode = RelayModeMidjourneyEdits
- } else if strings.HasSuffix(path, "/mj/submit/blend") {
- relayMode = RelayModeMidjourneyBlend
- } else if strings.HasSuffix(path, "/mj/submit/describe") {
- relayMode = RelayModeMidjourneyDescribe
- } else if strings.HasSuffix(path, "/mj/notify") {
- relayMode = RelayModeMidjourneyNotify
- } else if strings.HasSuffix(path, "/mj/submit/change") {
- relayMode = RelayModeMidjourneyChange
- } else if strings.HasSuffix(path, "/mj/submit/simple-change") {
- relayMode = RelayModeMidjourneyChange
- } else if strings.HasSuffix(path, "/fetch") {
- relayMode = RelayModeMidjourneyTaskFetch
- } else if strings.HasSuffix(path, "/image-seed") {
- relayMode = RelayModeMidjourneyTaskImageSeed
- } else if strings.HasSuffix(path, "/list-by-condition") {
- relayMode = RelayModeMidjourneyTaskFetchByCondition
- }
- return relayMode
-}
-
-func Path2RelaySuno(method, path string) int {
- relayMode := RelayModeUnknown
- if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
- relayMode = RelayModeSunoFetch
- } else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
- relayMode = RelayModeSunoFetchByID
- } else if strings.Contains(path, "/submit/") {
- relayMode = RelayModeSunoSubmit
- }
- return relayMode
-}
diff --git a/new-api/relay/embedding_handler.go b/new-api/relay/embedding_handler.go
deleted file mode 100644
index d26b9c0a87d855bd6bb5a98d9352c15ae533ac28..0000000000000000000000000000000000000000
--- a/new-api/relay/embedding_handler.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(embeddingReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- requestBody := bytes.NewBuffer(jsonData)
- statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- return nil
-}
diff --git a/new-api/relay/gemini_handler.go b/new-api/relay/gemini_handler.go
deleted file mode 100644
index be59227231e559b70f3236bc98d9603d2a2d2f57..0000000000000000000000000000000000000000
--- a/new-api/relay/gemini_handler.go
+++ /dev/null
@@ -1,293 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/relay/channel/gemini"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
- if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
- configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
- if configBudget != nil && *configBudget == 0 {
- // 如果思考预算为 0,则认为是非思考请求
- return true
- }
- }
- return false
-}
-
-func trimModelThinking(modelName string) string {
- // 去除模型名称中的 -nothinking 后缀
- if strings.HasSuffix(modelName, "-nothinking") {
- return strings.TrimSuffix(modelName, "-nothinking")
- }
- // 去除模型名称中的 -thinking 后缀
- if strings.HasSuffix(modelName, "-thinking") {
- return strings.TrimSuffix(modelName, "-thinking")
- }
-
- // 去除模型名称中的 -thinking-number
- if strings.Contains(modelName, "-thinking-") {
- parts := strings.Split(modelName, "-thinking-")
- if len(parts) > 1 {
- return parts[0] + "-thinking"
- }
- }
- return modelName
-}
-
-func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(geminiReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- // model mapped 模型映射
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- if isNoThinkingRequest(request) {
- // check is thinking
- if !strings.Contains(info.OriginModelName, "-nothinking") {
- // try to get no thinking model price
- noThinkingModelName := info.OriginModelName + "-nothinking"
- containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
- if containPrice {
- info.OriginModelName = noThinkingModelName
- info.UpstreamModelName = noThinkingModelName
- }
- }
- }
- if request.GenerationConfig.ThinkingConfig == nil {
- gemini.ThinkingAdaptor(request, info)
- }
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
-
- adaptor.Init(info)
-
- if info.ChannelSetting.SystemPrompt != "" {
- if request.SystemInstructions == nil {
- request.SystemInstructions = &dto.GeminiChatContent{
- Parts: []dto.GeminiPart{
- {Text: info.ChannelSetting.SystemPrompt},
- },
- }
- } else if len(request.SystemInstructions.Parts) == 0 {
- request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
- } else if info.ChannelSetting.SystemPromptOverride {
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- merged := false
- for i := range request.SystemInstructions.Parts {
- if request.SystemInstructions.Parts[i].Text == "" {
- continue
- }
- request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
- merged = true
- break
- }
- if !merged {
- request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
- }
- }
- }
-
- // Clean up empty system instruction
- if request.SystemInstructions != nil {
- hasContent := false
- for _, part := range request.SystemInstructions.Parts {
- if part.Text != "" {
- hasContent = true
- break
- }
- }
- if !hasContent {
- request.SystemInstructions = nil
- }
- }
-
- var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- requestBody = bytes.NewReader(body)
- } else {
- // 使用 ConvertGeminiRequest 转换请求格式
- convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- logger.LogDebug(c, "Gemini request body: "+string(jsonData))
-
- requestBody = bytes.NewReader(jsonData)
- }
-
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- logger.LogError(c, "Do gemini request failed: "+err.Error())
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
- if openaiErr != nil {
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- return nil
-}
-
-func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
- info.IsGeminiBatchEmbedding = isBatch
-
- var req dto.Request
- var err error
- var inputTexts []string
-
- if isBatch {
- batchRequest := &dto.GeminiBatchEmbeddingRequest{}
- err = common.UnmarshalBodyReusable(c, batchRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
- req = batchRequest
- for _, r := range batchRequest.Requests {
- for _, part := range r.Content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
- } else {
- singleRequest := &dto.GeminiEmbeddingRequest{}
- err = common.UnmarshalBodyReusable(c, singleRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
- req = singleRequest
- for _, part := range singleRequest.Content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
-
- err = helper.ModelMappedHelper(c, info, req)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- var requestBody io.Reader
- jsonData, err := common.Marshal(req)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- reqMap := make(map[string]interface{})
- _ = common.Unmarshal(jsonData, &reqMap)
- for key, value := range info.ParamOverride {
- reqMap[key] = value
- }
- jsonData, err = common.Marshal(reqMap)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
- requestBody = bytes.NewReader(jsonData)
-
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- logger.LogError(c, "Do gemini request failed: "+err.Error())
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
- if openaiErr != nil {
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- return nil
-}
diff --git a/new-api/relay/helper/common.go b/new-api/relay/helper/common.go
deleted file mode 100644
index e85907e662bb39e40a995d9784d7068f8778d2f6..0000000000000000000000000000000000000000
--- a/new-api/relay/helper/common.go
+++ /dev/null
@@ -1,183 +0,0 @@
-package helper
-
-import (
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-func FlushWriter(c *gin.Context) error {
- if c.Writer == nil {
- return nil
- }
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- return nil
- }
- return errors.New("streaming error: flusher not found")
-}
-
-func SetEventStreamHeaders(c *gin.Context) {
- // 检查是否已经设置过头部
- if _, exists := c.Get("event_stream_headers_set"); exists {
- return
- }
-
- // 设置标志,表示头部已经设置过
- c.Set("event_stream_headers_set", true)
-
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("Transfer-Encoding", "chunked")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
-}
-
-func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
- jsonData, err := common.Marshal(resp)
- if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
- } else {
- c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
- }
- _ = FlushWriter(c)
- return nil
-}
-
-func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
- c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
- c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
- _ = FlushWriter(c)
-}
-
-func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
- c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
- c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
- _ = FlushWriter(c)
-}
-
-func StringData(c *gin.Context, str string) error {
- //str = strings.TrimPrefix(str, "data: ")
- //str = strings.TrimSuffix(str, "\r")
- c.Render(-1, common.CustomEvent{Data: "data: " + str})
- _ = FlushWriter(c)
- return nil
-}
-
-func PingData(c *gin.Context) error {
- c.Writer.Write([]byte(": PING\n\n"))
- _ = FlushWriter(c)
- return nil
-}
-
-func ObjectData(c *gin.Context, object interface{}) error {
- if object == nil {
- return errors.New("object is nil")
- }
- jsonData, err := common.Marshal(object)
- if err != nil {
- return fmt.Errorf("error marshalling object: %w", err)
- }
- return StringData(c, string(jsonData))
-}
-
-func Done(c *gin.Context) {
- _ = StringData(c, "[DONE]")
-}
-
-func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
- if ws == nil {
- logger.LogError(c, "websocket connection is nil")
- return errors.New("websocket connection is nil")
- }
- //common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
- return ws.WriteMessage(1, []byte(str))
-}
-
-func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
- jsonData, err := common.Marshal(object)
- if err != nil {
- return fmt.Errorf("error marshalling object: %w", err)
- }
- if ws == nil {
- logger.LogError(c, "websocket connection is nil")
- return errors.New("websocket connection is nil")
- }
- //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
- return ws.WriteMessage(1, jsonData)
-}
-
-func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
- if ws == nil {
- return
- }
- errorObj := &dto.RealtimeEvent{
- Type: "error",
- EventId: GetLocalRealtimeID(c),
- Error: &openaiError,
- }
- _ = WssObject(c, ws, errorObj)
-}
-
-func GetResponseID(c *gin.Context) string {
- logID := c.GetString(common.RequestIdKey)
- return fmt.Sprintf("chatcmpl-%s", logID)
-}
-
-func GetLocalRealtimeID(c *gin.Context) string {
- logID := c.GetString(common.RequestIdKey)
- return fmt.Sprintf("evt_%s", logID)
-}
-
-func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse {
- return &dto.ChatCompletionsStreamResponse{
- Id: id,
- Object: "chat.completion.chunk",
- Created: createAt,
- Model: model,
- SystemFingerprint: systemFingerprint,
- Choices: []dto.ChatCompletionsStreamResponseChoice{
- {
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
- Role: "assistant",
- Content: common.GetPointer(""),
- },
- },
- },
- }
-}
-
-func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
- return &dto.ChatCompletionsStreamResponse{
- Id: id,
- Object: "chat.completion.chunk",
- Created: createAt,
- Model: model,
- SystemFingerprint: nil,
- Choices: []dto.ChatCompletionsStreamResponseChoice{
- {
- FinishReason: &finishReason,
- },
- },
- }
-}
-
-func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
- return &dto.ChatCompletionsStreamResponse{
- Id: id,
- Object: "chat.completion.chunk",
- Created: createAt,
- Model: model,
- SystemFingerprint: nil,
- Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
- Usage: &usage,
- }
-}
diff --git a/new-api/relay/helper/model_mapped.go b/new-api/relay/helper/model_mapped.go
deleted file mode 100644
index cda81d4ceec42f81f577343823da7c7b59546daa..0000000000000000000000000000000000000000
--- a/new-api/relay/helper/model_mapped.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package helper
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "one-api/dto"
- "one-api/relay/common"
-)
-
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
- // map model name
- modelMapping := c.GetString("model_mapping")
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return fmt.Errorf("unmarshal_model_mapping_failed")
- }
-
- // 支持链式模型重定向,最终使用链尾的模型
- currentModel := info.OriginModelName
- visitedModels := map[string]bool{
- currentModel: true,
- }
- for {
- if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" {
- // 模型重定向循环检测,避免无限循环
- if visitedModels[mappedModel] {
- if mappedModel == currentModel {
- if currentModel == info.OriginModelName {
- info.IsModelMapped = false
- return nil
- } else {
- info.IsModelMapped = true
- break
- }
- }
- return errors.New("model_mapping_contains_cycle")
- }
- visitedModels[mappedModel] = true
- currentModel = mappedModel
- info.IsModelMapped = true
- } else {
- break
- }
- }
- if info.IsModelMapped {
- info.UpstreamModelName = currentModel
- }
- }
- if request != nil {
- request.SetModelName(info.UpstreamModelName)
- }
- return nil
-}
diff --git a/new-api/relay/helper/price.go b/new-api/relay/helper/price.go
deleted file mode 100644
index 63e3a596728f21843e2c2503ff9604b0875b6679..0000000000000000000000000000000000000000
--- a/new-api/relay/helper/price.go
+++ /dev/null
@@ -1,143 +0,0 @@
-package helper
-
-import (
- "fmt"
- "one-api/common"
- relaycommon "one-api/relay/common"
- "one-api/setting/ratio_setting"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
-func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
- groupRatioInfo := types.GroupRatioInfo{
- GroupRatio: 1.0, // default ratio
- GroupSpecialRatio: -1,
- }
-
- // check auto group
- autoGroup, exists := ctx.Get("auto_group")
- if exists {
- if common.DebugEnabled {
- println(fmt.Sprintf("final group: %s", autoGroup))
- }
- relayInfo.UsingGroup = autoGroup.(string)
- }
-
- // check user group special ratio
- userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
- if ok {
- // user group special ratio
- groupRatioInfo.GroupSpecialRatio = userGroupRatio
- groupRatioInfo.GroupRatio = userGroupRatio
- groupRatioInfo.HasSpecialRatio = true
- } else {
- // normal group ratio
- groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
- }
-
- return groupRatioInfo
-}
-
-func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
- modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
-
- groupRatioInfo := HandleGroupRatio(c, info)
-
- var preConsumedQuota int
- var modelRatio float64
- var completionRatio float64
- var cacheRatio float64
- var imageRatio float64
- var cacheCreationRatio float64
- var audioRatio float64
- var audioCompletionRatio float64
- if !usePrice {
- preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
- if meta.MaxTokens != 0 {
- preConsumedTokens += meta.MaxTokens
- }
- var success bool
- var matchName string
- modelRatio, success, matchName = ratio_setting.GetModelRatio(info.OriginModelName)
- if !success {
- acceptUnsetRatio := false
- if info.UserSetting.AcceptUnsetRatioModel {
- acceptUnsetRatio = true
- }
- if !acceptUnsetRatio {
- return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
- }
- }
- completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
- cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
- cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
- imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
- audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
- audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
- ratio := modelRatio * groupRatioInfo.GroupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- if meta.ImagePriceRatio != 0 {
- modelPrice = modelPrice * meta.ImagePriceRatio
- }
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
- }
-
- priceData := types.PriceData{
- ModelPrice: modelPrice,
- ModelRatio: modelRatio,
- CompletionRatio: completionRatio,
- GroupRatioInfo: groupRatioInfo,
- UsePrice: usePrice,
- CacheRatio: cacheRatio,
- ImageRatio: imageRatio,
- AudioRatio: audioRatio,
- AudioCompletionRatio: audioCompletionRatio,
- CacheCreationRatio: cacheCreationRatio,
- ShouldPreConsumedQuota: preConsumedQuota,
- }
-
- if common.DebugEnabled {
- println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
- }
- info.PriceData = priceData
- return priceData, nil
-}
-
-// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
- groupRatioInfo := HandleGroupRatio(c, info)
-
- modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
- // 如果没有配置价格,则使用默认价格
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
- priceData := types.PerCallPriceData{
- ModelPrice: modelPrice,
- Quota: quota,
- GroupRatioInfo: groupRatioInfo,
- }
- return priceData
-}
-
-func ContainPriceOrRatio(modelName string) bool {
- _, ok := ratio_setting.GetModelPrice(modelName, false)
- if ok {
- return true
- }
- _, ok, _ = ratio_setting.GetModelRatio(modelName)
- if ok {
- return true
- }
- return false
-}
diff --git a/new-api/relay/helper/stream_scanner.go b/new-api/relay/helper/stream_scanner.go
deleted file mode 100644
index 14778286c3379fc4a175571c2424576b927ecd9d..0000000000000000000000000000000000000000
--- a/new-api/relay/helper/stream_scanner.go
+++ /dev/null
@@ -1,262 +0,0 @@
-package helper
-
-import (
- "bufio"
- "context"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/setting/operation_setting"
- "strings"
- "sync"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
-
- "github.com/gin-gonic/gin"
-)
-
-const (
- InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
- MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
- DefaultPingInterval = 10 * time.Second
-)
-
-func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
-
- if resp == nil || dataHandler == nil {
- return
- }
-
- // 确保响应体总是被关闭
- defer func() {
- if resp.Body != nil {
- resp.Body.Close()
- }
- }()
-
- streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
-
- var (
- stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
- scanner = bufio.NewScanner(resp.Body)
- ticker = time.NewTicker(streamingTimeout)
- pingTicker *time.Ticker
- writeMutex sync.Mutex // Mutex to protect concurrent writes
- wg sync.WaitGroup // 用于等待所有 goroutine 退出
- )
-
- generalSettings := operation_setting.GetGeneralSetting()
- pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
- pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
- if pingInterval <= 0 {
- pingInterval = DefaultPingInterval
- }
-
- if pingEnabled {
- pingTicker = time.NewTicker(pingInterval)
- }
-
- if common.DebugEnabled {
- // print timeout and ping interval for debugging
- println("relay timeout seconds:", common.RelayTimeout)
- println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
- println("ping interval seconds:", int64(pingInterval.Seconds()))
- }
-
- // 改进资源清理,确保所有 goroutine 正确退出
- defer func() {
- // 通知所有 goroutine 停止
- common.SafeSendBool(stopChan, true)
-
- ticker.Stop()
- if pingTicker != nil {
- pingTicker.Stop()
- }
-
- // 等待所有 goroutine 退出,最多等待5秒
- done := make(chan struct{})
- go func() {
- wg.Wait()
- close(done)
- }()
-
- select {
- case <-done:
- case <-time.After(5 * time.Second):
- logger.LogError(c, "timeout waiting for goroutines to exit")
- }
-
- close(stopChan)
- }()
-
- scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
- scanner.Split(bufio.ScanLines)
- SetEventStreamHeaders(c)
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- ctx = context.WithValue(ctx, "stop_chan", stopChan)
-
- // Handle ping data sending with improved error handling
- if pingEnabled && pingTicker != nil {
- wg.Add(1)
- gopool.Go(func() {
- defer func() {
- wg.Done()
- if r := recover(); r != nil {
- logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
- common.SafeSendBool(stopChan, true)
- }
- if common.DebugEnabled {
- println("ping goroutine exited")
- }
- }()
-
- // 添加超时保护,防止 goroutine 无限运行
- maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
- pingTimeout := time.NewTimer(maxPingDuration)
- defer pingTimeout.Stop()
-
- for {
- select {
- case <-pingTicker.C:
- // 使用超时机制防止写操作阻塞
- done := make(chan error, 1)
- go func() {
- writeMutex.Lock()
- defer writeMutex.Unlock()
- done <- PingData(c)
- }()
-
- select {
- case err := <-done:
- if err != nil {
- logger.LogError(c, "ping data error: "+err.Error())
- return
- }
- if common.DebugEnabled {
- println("ping data sent")
- }
- case <-time.After(10 * time.Second):
- logger.LogError(c, "ping data send timeout")
- return
- case <-ctx.Done():
- return
- case <-stopChan:
- return
- }
- case <-ctx.Done():
- return
- case <-stopChan:
- return
- case <-c.Request.Context().Done():
- // 监听客户端断开连接
- return
- case <-pingTimeout.C:
- logger.LogError(c, "ping goroutine max duration reached")
- return
- }
- }
- })
- }
-
- // Scanner goroutine with improved error handling
- wg.Add(1)
- common.RelayCtxGo(ctx, func() {
- defer func() {
- wg.Done()
- if r := recover(); r != nil {
- logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
- }
- common.SafeSendBool(stopChan, true)
- if common.DebugEnabled {
- println("scanner goroutine exited")
- }
- }()
-
- for scanner.Scan() {
- // 检查是否需要停止
- select {
- case <-stopChan:
- return
- case <-ctx.Done():
- return
- case <-c.Request.Context().Done():
- return
- default:
- }
-
- ticker.Reset(streamingTimeout)
- data := scanner.Text()
- if common.DebugEnabled {
- println(data)
- }
-
- if len(data) < 6 {
- continue
- }
- if data[:5] != "data:" && data[:6] != "[DONE]" {
- continue
- }
- data = data[5:]
- data = strings.TrimLeft(data, " ")
- data = strings.TrimSuffix(data, "\r")
- if !strings.HasPrefix(data, "[DONE]") {
- info.SetFirstResponseTime()
-
- // 使用超时机制防止写操作阻塞
- done := make(chan bool, 1)
- go func() {
- writeMutex.Lock()
- defer writeMutex.Unlock()
- done <- dataHandler(data)
- }()
-
- select {
- case success := <-done:
- if !success {
- return
- }
- case <-time.After(10 * time.Second):
- logger.LogError(c, "data handler timeout")
- return
- case <-ctx.Done():
- return
- case <-stopChan:
- return
- }
- } else {
- // done, 处理完成标志,直接退出停止读取剩余数据防止出错
- if common.DebugEnabled {
- println("received [DONE], stopping scanner")
- }
- return
- }
- }
-
- if err := scanner.Err(); err != nil {
- if err != io.EOF {
- logger.LogError(c, "scanner error: "+err.Error())
- }
- }
- })
-
- // 主循环等待完成或超时
- select {
- case <-ticker.C:
- // 超时处理逻辑
- logger.LogError(c, "streaming timeout")
- case <-stopChan:
- // 正常结束
- logger.LogInfo(c, "streaming finished")
- case <-c.Request.Context().Done():
- // 客户端断开连接
- logger.LogInfo(c, "client disconnected")
- }
-}
diff --git a/new-api/relay/helper/valid_request.go b/new-api/relay/helper/valid_request.go
deleted file mode 100644
index 66213b01fdfd3d639163faf4a9189a9677410fde..0000000000000000000000000000000000000000
--- a/new-api/relay/helper/valid_request.go
+++ /dev/null
@@ -1,318 +0,0 @@
-package helper
-
-import (
- "errors"
- "fmt"
- "math"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relayconstant "one-api/relay/constant"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
- relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
-
- switch format {
- case types.RelayFormatOpenAI:
- request, err = GetAndValidateTextRequest(c, relayMode)
- case types.RelayFormatGemini:
- if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
- request, err = GetAndValidateGeminiEmbeddingRequest(c)
- } else {
- request, err = GetAndValidateGeminiRequest(c)
- }
- case types.RelayFormatClaude:
- request, err = GetAndValidateClaudeRequest(c)
- case types.RelayFormatOpenAIResponses:
- request, err = GetAndValidateResponsesRequest(c)
-
- case types.RelayFormatOpenAIImage:
- request, err = GetAndValidOpenAIImageRequest(c, relayMode)
- case types.RelayFormatEmbedding:
- request, err = GetAndValidateEmbeddingRequest(c, relayMode)
- case types.RelayFormatRerank:
- request, err = GetAndValidateRerankRequest(c)
- case types.RelayFormatOpenAIAudio:
- request, err = GetAndValidAudioRequest(c, relayMode)
- case types.RelayFormatOpenAIRealtime:
- request = &dto.BaseRequest{}
- default:
- return nil, fmt.Errorf("unsupported relay format: %s", format)
- }
- return request, err
-}
-
-func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
- audioRequest := &dto.AudioRequest{}
- err := common.UnmarshalBodyReusable(c, audioRequest)
- if err != nil {
- return nil, err
- }
- switch relayMode {
- case relayconstant.RelayModeAudioSpeech:
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- default:
- err = c.Request.ParseForm()
- if err != nil {
- return nil, err
- }
- formData := c.Request.PostForm
- if audioRequest.Model == "" {
- audioRequest.Model = formData.Get("model")
- }
-
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- audioRequest.ResponseFormat = formData.Get("response_format")
- if audioRequest.ResponseFormat == "" {
- audioRequest.ResponseFormat = "json"
- }
- }
- return audioRequest, nil
-}
-
-func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
- var rerankRequest *dto.RerankRequest
- err := common.UnmarshalBodyReusable(c, &rerankRequest)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- if rerankRequest.Query == "" {
- return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
- if len(rerankRequest.Documents) == 0 {
- return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
- return rerankRequest, nil
-}
-
-func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
- var embeddingRequest *dto.EmbeddingRequest
- err := common.UnmarshalBodyReusable(c, &embeddingRequest)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- if embeddingRequest.Input == nil {
- return nil, fmt.Errorf("input is empty")
- }
- if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
- embeddingRequest.Model = "omni-moderation-latest"
- }
- if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
- embeddingRequest.Model = c.Param("model")
- }
- return embeddingRequest, nil
-}
-
-func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
- request := &dto.OpenAIResponsesRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if request.Model == "" {
- return nil, errors.New("model is required")
- }
- if request.Input == nil {
- return nil, errors.New("input is required")
- }
- return request, nil
-}
-
-func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
- imageRequest := &dto.ImageRequest{}
-
- switch relayMode {
- case relayconstant.RelayModeImagesEdits:
- if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
- _, err := c.MultipartForm()
- if err != nil {
- return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
- }
- formData := c.Request.PostForm
- imageRequest.Prompt = formData.Get("prompt")
- imageRequest.Model = formData.Get("model")
- imageRequest.N = uint(common.String2Int(formData.Get("n")))
- imageRequest.Quality = formData.Get("quality")
- imageRequest.Size = formData.Get("size")
-
- if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
-
- watermark := formData.Has("watermark")
- if watermark {
- imageRequest.Watermark = &watermark
- }
- break
- }
- fallthrough
- default:
- err := common.UnmarshalBodyReusable(c, imageRequest)
- if err != nil {
- return nil, err
- }
-
- if imageRequest.Model == "" {
- //imageRequest.Model = "dall-e-3"
- return nil, errors.New("model is required")
- }
-
- if strings.Contains(imageRequest.Size, "×") {
- return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
- }
-
- // Not "256x256", "512x512", or "1024x1024"
- if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
- if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
- return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "dall-e-3" {
- if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
- return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
- }
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "auto"
- }
- }
-
- //if imageRequest.Prompt == "" {
- // return nil, errors.New("prompt is required")
- //}
-
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- }
-
- return imageRequest, nil
-}
-
-func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
- textRequest = &dto.ClaudeRequest{}
- err = c.ShouldBindJSON(textRequest)
- if err != nil {
- return nil, err
- }
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- if textRequest.Model == "" {
- return nil, errors.New("field model is required")
- }
-
- //if textRequest.Stream {
- // relayInfo.IsStream = true
- //}
-
- return textRequest, nil
-}
-
-func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
- textRequest := &dto.GeneralOpenAIRequest{}
- err := common.UnmarshalBodyReusable(c, textRequest)
- if err != nil {
- return nil, err
- }
-
- if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
- textRequest.Model = c.Param("model")
- }
-
- if textRequest.MaxTokens > math.MaxInt32/2 {
- return nil, errors.New("max_tokens is invalid")
- }
- if textRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if textRequest.WebSearchOptions != nil {
- if textRequest.WebSearchOptions.SearchContextSize != "" {
- validSizes := map[string]bool{
- "high": true,
- "medium": true,
- "low": true,
- }
- if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
- return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
- }
- } else {
- textRequest.WebSearchOptions.SearchContextSize = "medium"
- }
- }
- switch relayMode {
- case relayconstant.RelayModeCompletions:
- if textRequest.Prompt == "" {
- return nil, errors.New("field prompt is required")
- }
- case relayconstant.RelayModeChatCompletions:
- if len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- case relayconstant.RelayModeEmbeddings:
- case relayconstant.RelayModeModerations:
- if textRequest.Input == nil || textRequest.Input == "" {
- return nil, errors.New("field input is required")
- }
- case relayconstant.RelayModeEdits:
- if textRequest.Instruction == "" {
- return nil, errors.New("field instruction is required")
- }
- }
- return textRequest, nil
-}
-
-func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
- request := &dto.GeminiChatRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if len(request.Contents) == 0 {
- return nil, errors.New("contents is required")
- }
-
- //if c.Query("alt") == "sse" {
- // relayInfo.IsStream = true
- //}
-
- return request, nil
-}
-
-func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
- request := &dto.GeminiEmbeddingRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- return request, nil
-}
diff --git a/new-api/relay/image_handler.go b/new-api/relay/image_handler.go
deleted file mode 100644
index 98db3e856c459b79db00d9fdfd0e883479f6e737..0000000000000000000000000000000000000000
--- a/new-api/relay/image_handler.go
+++ /dev/null
@@ -1,128 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- imageReq, ok := info.Request.(*dto.ImageRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(imageReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- var requestBody io.Reader
-
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed)
- }
-
- switch convertedRequest.(type) {
- case *bytes.Buffer:
- requestBody = convertedRequest.(io.Reader)
- default:
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- if common.DebugEnabled {
- logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData)))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
-
- if usage.(*dto.Usage).TotalTokens == 0 {
- usage.(*dto.Usage).TotalTokens = int(request.N)
- }
- if usage.(*dto.Usage).PromptTokens == 0 {
- usage.(*dto.Usage).PromptTokens = int(request.N)
- }
-
- quality := "standard"
- if request.Quality == "hd" {
- quality = "hd"
- }
-
- var logContent string
-
- if len(request.Size) > 0 {
- logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
- }
-
- postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
- return nil
-}
diff --git a/new-api/relay/mjproxy_handler.go b/new-api/relay/mjproxy_handler.go
deleted file mode 100644
index 08d390e7db3e88d42e3dca2307e4dc3ce24a67d1..0000000000000000000000000000000000000000
--- a/new-api/relay/mjproxy_handler.go
+++ /dev/null
@@ -1,659 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "log"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/system_setting"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayMidjourneyImage(c *gin.Context) {
- taskId := c.Param("id")
- midjourneyTask := model.GetByOnlyMJId(taskId)
- if midjourneyTask == nil {
- c.JSON(400, gin.H{
- "error": "midjourney_task_not_found",
- })
- return
- }
- var httpClient *http.Client
- if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
- proxy := channel.GetSetting().Proxy
- if proxy != "" {
- if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
- c.JSON(400, gin.H{
- "error": "proxy_url_invalid",
- })
- return
- }
- }
- }
- if httpClient == nil {
- httpClient = service.GetHttpClient()
- }
- resp, err := httpClient.Get(midjourneyTask.ImageUrl)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": "http_get_image_failed",
- })
- return
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- c.JSON(resp.StatusCode, gin.H{
- "error": string(responseBody),
- })
- return
- }
- // 从Content-Type头获取MIME类型
- contentType := resp.Header.Get("Content-Type")
- if contentType == "" {
- // 如果无法确定内容类型,则默认为jpeg
- contentType = "image/jpeg"
- }
- // 设置响应的内容类型
- c.Writer.Header().Set("Content-Type", contentType)
- // 将图片流式传输到响应体
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- log.Println("Failed to stream image:", err)
- }
- return
-}
-
-func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
- var midjRequest dto.MidjourneyDto
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "bind_request_body_failed",
- Properties: nil,
- Result: "",
- }
- }
- midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
- if midjourneyTask == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "midjourney_task_not_found",
- Properties: nil,
- Result: "",
- }
- }
- midjourneyTask.Progress = midjRequest.Progress
- midjourneyTask.PromptEn = midjRequest.PromptEn
- midjourneyTask.State = midjRequest.State
- midjourneyTask.SubmitTime = midjRequest.SubmitTime
- midjourneyTask.StartTime = midjRequest.StartTime
- midjourneyTask.FinishTime = midjRequest.FinishTime
- midjourneyTask.ImageUrl = midjRequest.ImageUrl
- midjourneyTask.VideoUrl = midjRequest.VideoUrl
- videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
- midjourneyTask.VideoUrls = string(videoUrlsStr)
- midjourneyTask.Status = midjRequest.Status
- midjourneyTask.FailReason = midjRequest.FailReason
- err = midjourneyTask.Update()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "update_midjourney_task_failed",
- }
- }
-
- return nil
-}
-
-func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
- midjourneyTask.MjId = originTask.MjId
- midjourneyTask.Progress = originTask.Progress
- midjourneyTask.PromptEn = originTask.PromptEn
- midjourneyTask.State = originTask.State
- midjourneyTask.SubmitTime = originTask.SubmitTime
- midjourneyTask.StartTime = originTask.StartTime
- midjourneyTask.FinishTime = originTask.FinishTime
- midjourneyTask.ImageUrl = ""
- if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
- midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
- if originTask.Status != "SUCCESS" {
- midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
- }
- } else {
- midjourneyTask.ImageUrl = originTask.ImageUrl
- }
- if originTask.VideoUrl != "" {
- midjourneyTask.VideoUrl = originTask.VideoUrl
- }
- midjourneyTask.Status = originTask.Status
- midjourneyTask.FailReason = originTask.FailReason
- midjourneyTask.Action = originTask.Action
- midjourneyTask.Description = originTask.Description
- midjourneyTask.Prompt = originTask.Prompt
- if originTask.Buttons != "" {
- var buttons []dto.ActionButton
- err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
- if err == nil {
- midjourneyTask.Buttons = buttons
- }
- }
- if originTask.VideoUrls != "" {
- var videoUrls []dto.ImgUrls
- err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
- if err == nil {
- midjourneyTask.VideoUrls = videoUrls
- }
- }
- if originTask.Properties != "" {
- var properties dto.Properties
- err := json.Unmarshal([]byte(originTask.Properties), &properties)
- if err == nil {
- midjourneyTask.Properties = &properties
- }
- }
- return
-}
-
-func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
- var swapFaceRequest dto.SwapFaceRequest
- err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
- }
-
- info.InitChannelMeta(c)
-
- if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
- }
- modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
-
- priceData := helper.ModelPriceHelperPerCall(c, info)
-
- userQuota, err := model.GetUserQuota(info.UserId, false)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: err.Error(),
- }
- }
-
- if userQuota-priceData.Quota < 0 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "quota_not_enough",
- }
- }
- requestURL := getMjRequestPath(c.Request.URL.String())
- baseURL := c.GetString("base_url")
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
- mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
- if err != nil {
- return &mjResp.Response
- }
- defer func() {
- if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
- err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
- if err != nil {
- common.SysLog("error consuming token remain quota: " + err.Error())
- }
-
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
- other := service.GenerateMjOtherInfo(priceData)
- model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
- ChannelId: info.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: priceData.Quota,
- Content: logContent,
- TokenId: info.TokenId,
- Group: info.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
- model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
- }
- }()
- midjResponse := &mjResp.Response
- midjourneyTask := &model.Midjourney{
- UserId: info.UserId,
- Code: midjResponse.Code,
- Action: constant.MjActionSwapFace,
- MjId: midjResponse.Result,
- Prompt: "InsightFace",
- PromptEn: "",
- Description: midjResponse.Description,
- State: "",
- SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
- StartTime: time.Now().UnixNano() / int64(time.Millisecond),
- FinishTime: 0,
- ImageUrl: "",
- Status: "",
- Progress: "0%",
- FailReason: "",
- ChannelId: c.GetInt("channel_id"),
- Quota: priceData.Quota,
- }
- err = midjourneyTask.Insert()
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
- }
- c.Writer.WriteHeader(mjResp.StatusCode)
- respBody, err := json.Marshal(midjResponse)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
- }
- _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
- }
- return nil
-}
-
-func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
- taskId := c.Param("id")
- userId := c.GetInt("id")
- originTask := model.GetByMJId(userId, taskId)
- if originTask == nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
- }
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
- }
- if channel.Status != common.ChannelStatusEnabled {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
- }
- c.Set("channel_id", originTask.ChannelId)
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-
- requestURL := getMjRequestPath(c.Request.URL.String())
- fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
- midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
- if err != nil {
- return &midjResponseWithStatus.Response
- }
- midjResponse := &midjResponseWithStatus.Response
- c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
- respBody, err := json.Marshal(midjResponse)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
- }
- service.IOCopyBytesGracefully(c, nil, respBody)
- return nil
-}
-
-func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
- userId := c.GetInt("id")
- var err error
- var respBody []byte
- switch relayMode {
- case relayconstant.RelayModeMidjourneyTaskFetch:
- taskId := c.Param("id")
- originTask := model.GetByMJId(userId, taskId)
- if originTask == nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "task_no_found",
- }
- }
- midjourneyTask := coverMidjourneyTaskDto(c, originTask)
- respBody, err = json.Marshal(midjourneyTask)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_response_body_failed",
- }
- }
- case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
- var condition = struct {
- IDs []string `json:"ids"`
- }{}
- err = c.BindJSON(&condition)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "do_request_failed",
- }
- }
- var tasks []dto.MidjourneyDto
- if len(condition.IDs) != 0 {
- originTasks := model.GetByMJIds(userId, condition.IDs)
- for _, originTask := range originTasks {
- midjourneyTask := coverMidjourneyTaskDto(c, originTask)
- tasks = append(tasks, midjourneyTask)
- }
- }
- if tasks == nil {
- tasks = make([]dto.MidjourneyDto, 0)
- }
- respBody, err = json.Marshal(tasks)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "unmarshal_response_body_failed",
- }
- }
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
-
- _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "copy_response_body_failed",
- }
- }
- return nil
-}
-
-func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
- consumeQuota := true
- var midjRequest dto.MidjourneyRequest
- err := common.UnmarshalBodyReusable(c, &midjRequest)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
- }
-
- relayInfo.InitChannelMeta(c)
-
- if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
- mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
- if mjErr != nil {
- return mjErr
- }
- relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
- }
- if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
- midjRequest.Action = constant.MjActionVideo
- }
-
- if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
- if midjRequest.Prompt == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
- }
- midjRequest.Action = constant.MjActionImagine
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
- midjRequest.Action = constant.MjActionDescribe
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
- midjRequest.Action = constant.MjActionEdits
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
- midjRequest.Action = constant.MjActionShorten
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
- midjRequest.Action = constant.MjActionBlend
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
- midjRequest.Action = constant.MjActionUpload
- } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
- mjId := ""
- if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
- if midjRequest.TaskId == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
- } else if midjRequest.Action == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
- } else if midjRequest.Index == 0 {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
- }
- //action = midjRequest.Action
- mjId = midjRequest.TaskId
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
- if midjRequest.Content == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
- }
- params := service.ConvertSimpleChangeParams(midjRequest.Content)
- if params == nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
- }
- mjId = params.TaskId
- midjRequest.Action = params.Action
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
- //if midjRequest.MaskBase64 == "" {
- // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
- //}
- mjId = midjRequest.TaskId
- midjRequest.Action = constant.MjActionModal
- } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
- midjRequest.Action = constant.MjActionVideo
- if midjRequest.TaskId == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
- } else if midjRequest.Action == "" {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
- }
- mjId = midjRequest.TaskId
- }
-
- originTask := model.GetByMJId(relayInfo.UserId, mjId)
- if originTask == nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
- } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
- if setting.MjActionCheckSuccessEnabled {
- if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
- }
- }
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
- }
- if channel.Status != common.ChannelStatusEnabled {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
- }
- c.Set("base_url", channel.GetBaseURL())
- c.Set("channel_id", originTask.ChannelId)
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
- }
- midjRequest.Prompt = originTask.Prompt
-
- //if channelType == common.ChannelTypeMidjourneyPlus {
- // // plus
- //} else {
- // // 普通版渠道
- //
- //}
- }
-
- if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
- consumeQuota = false
- }
-
- //baseURL := common.ChannelBaseURLs[channelType]
- requestURL := getMjRequestPath(c.Request.URL.String())
-
- baseURL := c.GetString("base_url")
-
- //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
-
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
- modelName := service.CoverActionToModelName(midjRequest.Action)
-
- priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
-
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: err.Error(),
- }
- }
-
- if consumeQuota && userQuota-priceData.Quota < 0 {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "quota_not_enough",
- }
- }
-
- midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
- if err != nil {
- return &midjResponseWithStatus.Response
- }
- midjResponse := &midjResponseWithStatus.Response
-
- defer func() {
- if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
- err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
- if err != nil {
- common.SysLog("error consuming token remain quota: " + err.Error())
- }
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
- other := service.GenerateMjOtherInfo(priceData)
- model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: priceData.Quota,
- Content: logContent,
- TokenId: relayInfo.TokenId,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
- }
- }()
-
- // 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
- //1-提交成功
- // 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
- // 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
- // 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
- // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
- // other: 提交错误,description为错误描述
- midjourneyTask := &model.Midjourney{
- UserId: relayInfo.UserId,
- Code: midjResponse.Code,
- Action: midjRequest.Action,
- MjId: midjResponse.Result,
- Prompt: midjRequest.Prompt,
- PromptEn: "",
- Description: midjResponse.Description,
- State: "",
- SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
- StartTime: 0,
- FinishTime: 0,
- ImageUrl: "",
- Status: "",
- Progress: "0%",
- FailReason: "",
- ChannelId: c.GetInt("channel_id"),
- Quota: priceData.Quota,
- }
- if midjResponse.Code == 3 {
- //无实例账号自动禁用渠道(No available account instance)
- channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
- if err != nil {
- common.SysLog("get_channel_null: " + err.Error())
- }
- if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
- model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
- }
- }
- if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
- //非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
- midjourneyTask.FailReason = midjResponse.Description
- consumeQuota = false
- }
-
- if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
- // 将 properties 转换为一个 map
- properties, ok := midjResponse.Properties.(map[string]interface{})
- if ok {
- imageUrl, ok1 := properties["imageUrl"].(string)
- status, ok2 := properties["status"].(string)
- if ok1 && ok2 {
- midjourneyTask.ImageUrl = imageUrl
- midjourneyTask.Status = status
- if status == "SUCCESS" {
- midjourneyTask.Progress = "100%"
- midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
- midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
- midjResponse.Code = 1
- }
- }
- }
- //修改返回值
- if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
- newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
- responseBody = []byte(newBody)
- }
- }
- if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
- midjourneyTask.Progress = "100%"
- midjourneyTask.Status = "SUCCESS"
- }
- err = midjourneyTask.Insert()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "insert_midjourney_task_failed",
- }
- }
-
- if midjResponse.Code == 22 { //22-排队中,说明任务已存在
- //修改返回值
- newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
- responseBody = []byte(newBody)
- }
- //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
-
- //for k, v := range resp.Header {
- // c.Writer.Header().Set(k, v[0])
- //}
- c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
-
- _, err = io.Copy(c.Writer, bodyReader)
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "copy_response_body_failed",
- }
- }
- err = bodyReader.Close()
- if err != nil {
- return &dto.MidjourneyResponse{
- Code: 4,
- Description: "close_response_body_failed",
- }
- }
- return nil
-}
-
-type taskChangeParams struct {
- ID string
- Action string
- Index int
-}
-
-func getMjRequestPath(path string) string {
- requestURL := path
- if strings.Contains(requestURL, "/mj-") {
- urls := strings.Split(requestURL, "/mj/")
- if len(urls) < 2 {
- return requestURL
- }
- requestURL = "/mj/" + urls[1]
- }
- return requestURL
-}
diff --git a/new-api/relay/relay_adaptor.go b/new-api/relay/relay_adaptor.go
deleted file mode 100644
index 828ec9b181b813471f73cfdc5897a48f988a780c..0000000000000000000000000000000000000000
--- a/new-api/relay/relay_adaptor.go
+++ /dev/null
@@ -1,140 +0,0 @@
-package relay
-
-import (
- "one-api/constant"
- "one-api/relay/channel"
- "one-api/relay/channel/ali"
- "one-api/relay/channel/aws"
- "one-api/relay/channel/baidu"
- "one-api/relay/channel/baidu_v2"
- "one-api/relay/channel/claude"
- "one-api/relay/channel/cloudflare"
- "one-api/relay/channel/cohere"
- "one-api/relay/channel/coze"
- "one-api/relay/channel/deepseek"
- "one-api/relay/channel/dify"
- "one-api/relay/channel/gemini"
- "one-api/relay/channel/jimeng"
- "one-api/relay/channel/jina"
- "one-api/relay/channel/mistral"
- "one-api/relay/channel/mokaai"
- "one-api/relay/channel/moonshot"
- "one-api/relay/channel/ollama"
- "one-api/relay/channel/openai"
- "one-api/relay/channel/palm"
- "one-api/relay/channel/perplexity"
- "one-api/relay/channel/siliconflow"
- taskjimeng "one-api/relay/channel/task/jimeng"
- "one-api/relay/channel/task/kling"
- "one-api/relay/channel/task/suno"
- taskvertex "one-api/relay/channel/task/vertex"
- taskVidu "one-api/relay/channel/task/vidu"
- "one-api/relay/channel/tencent"
- "one-api/relay/channel/vertex"
- "one-api/relay/channel/volcengine"
- "one-api/relay/channel/xai"
- "one-api/relay/channel/xunfei"
- "one-api/relay/channel/zhipu"
- "one-api/relay/channel/zhipu_4v"
- "strconv"
- "one-api/relay/channel/submodel"
- "github.com/gin-gonic/gin"
-)
-
-func GetAdaptor(apiType int) channel.Adaptor {
- switch apiType {
- case constant.APITypeAli:
- return &ali.Adaptor{}
- case constant.APITypeAnthropic:
- return &claude.Adaptor{}
- case constant.APITypeBaidu:
- return &baidu.Adaptor{}
- case constant.APITypeGemini:
- return &gemini.Adaptor{}
- case constant.APITypeOpenAI:
- return &openai.Adaptor{}
- case constant.APITypePaLM:
- return &palm.Adaptor{}
- case constant.APITypeTencent:
- return &tencent.Adaptor{}
- case constant.APITypeXunfei:
- return &xunfei.Adaptor{}
- case constant.APITypeZhipu:
- return &zhipu.Adaptor{}
- case constant.APITypeZhipuV4:
- return &zhipu_4v.Adaptor{}
- case constant.APITypeOllama:
- return &ollama.Adaptor{}
- case constant.APITypePerplexity:
- return &perplexity.Adaptor{}
- case constant.APITypeAws:
- return &aws.Adaptor{}
- case constant.APITypeCohere:
- return &cohere.Adaptor{}
- case constant.APITypeDify:
- return &dify.Adaptor{}
- case constant.APITypeJina:
- return &jina.Adaptor{}
- case constant.APITypeCloudflare:
- return &cloudflare.Adaptor{}
- case constant.APITypeSiliconFlow:
- return &siliconflow.Adaptor{}
- case constant.APITypeVertexAi:
- return &vertex.Adaptor{}
- case constant.APITypeMistral:
- return &mistral.Adaptor{}
- case constant.APITypeDeepSeek:
- return &deepseek.Adaptor{}
- case constant.APITypeMokaAI:
- return &mokaai.Adaptor{}
- case constant.APITypeVolcEngine:
- return &volcengine.Adaptor{}
- case constant.APITypeBaiduV2:
- return &baidu_v2.Adaptor{}
- case constant.APITypeOpenRouter:
- return &openai.Adaptor{}
- case constant.APITypeXinference:
- return &openai.Adaptor{}
- case constant.APITypeXai:
- return &xai.Adaptor{}
- case constant.APITypeCoze:
- return &coze.Adaptor{}
- case constant.APITypeJimeng:
- return &jimeng.Adaptor{}
- case constant.APITypeMoonshot:
- return &moonshot.Adaptor{} // Moonshot uses Claude API
- case constant.APITypeSubmodel:
- return &submodel.Adaptor{}
- }
- return nil
-}
-
-func GetTaskPlatform(c *gin.Context) constant.TaskPlatform {
- channelType := c.GetInt("channel_type")
- if channelType > 0 {
- return constant.TaskPlatform(strconv.Itoa(channelType))
- }
- return constant.TaskPlatform(c.GetString("platform"))
-}
-
-func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
- switch platform {
- //case constant.APITypeAIProxyLibrary:
- // return &aiproxy.Adaptor{}
- case constant.TaskPlatformSuno:
- return &suno.TaskAdaptor{}
- }
- if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil {
- switch channelType {
- case constant.ChannelTypeKling:
- return &kling.TaskAdaptor{}
- case constant.ChannelTypeJimeng:
- return &taskjimeng.TaskAdaptor{}
- case constant.ChannelTypeVertexAi:
- return &taskvertex.TaskAdaptor{}
- case constant.ChannelTypeVidu:
- return &taskVidu.TaskAdaptor{}
- }
- }
- return nil
-}
diff --git a/new-api/relay/relay_task.go b/new-api/relay/relay_task.go
deleted file mode 100644
index 18decc00552d121beded9321111a05bac0a5d9a0..0000000000000000000000000000000000000000
--- a/new-api/relay/relay_task.go
+++ /dev/null
@@ -1,386 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/service"
- "one-api/setting/ratio_setting"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-/*
-Task 任务通过平台、Action 区分任务
-*/
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- info.InitChannelMeta(c)
- // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
- if info.TaskRelayInfo == nil {
- info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
- }
- platform := constant.TaskPlatform(c.GetString("platform"))
- if platform == "" {
- platform = GetTaskPlatform(c)
- }
-
- info.InitChannelMeta(c)
- adaptor := GetTaskAdaptor(platform)
- if adaptor == nil {
- return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
- }
- adaptor.Init(info)
- // get & validate taskRequest 获取并验证文本请求
- taskErr = adaptor.ValidateRequestAndSetAction(c, info)
- if taskErr != nil {
- return
- }
-
- modelName := info.OriginModelName
- if modelName == "" {
- modelName = service.CoverTaskActionToModelName(platform, info.Action)
- }
- modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
- if !success {
- defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
-
- // 预扣
- groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
- var ratio float64
- userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
- if hasUserGroupRatio {
- ratio = modelPrice * userGroupRatio
- } else {
- ratio = modelPrice * groupRatio
- }
- userQuota, err := model.GetUserQuota(info.UserId, false)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
- return
- }
- quota := int(ratio * common.QuotaPerUnit)
- if userQuota-quota < 0 {
- taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
- return
- }
-
- if info.OriginTaskID != "" {
- originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
- if originTask.ChannelId != info.ChannelId {
- channel, err := model.GetChannelById(originTask.ChannelId, true)
- if err != nil {
- taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
- return
- }
- if channel.Status != common.ChannelStatusEnabled {
- return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
- }
- c.Set("base_url", channel.GetBaseURL())
- c.Set("channel_id", originTask.ChannelId)
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-
- info.ChannelBaseUrl = channel.GetBaseURL()
- info.ChannelId = originTask.ChannelId
- }
- }
-
- // build body
- requestBody, err := adaptor.BuildRequestBody(c, info)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
- return
- }
- // do request
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- return
- }
- // handle response
- if resp != nil && resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
- return
- }
-
- defer func() {
- // release quota
- if info.ConsumeQuota && taskErr == nil {
-
- err := service.PostConsumeQuota(info, quota, 0, true)
- if err != nil {
- common.SysLog("error consuming token remain quota: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- gRatio := groupRatio
- if hasUserGroupRatio {
- gRatio = userGroupRatio
- }
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
- other := make(map[string]interface{})
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- if hasUserGroupRatio {
- other["user_group_ratio"] = userGroupRatio
- }
- model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
- ChannelId: info.ChannelId,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: info.TokenId,
- Group: info.UsingGroup,
- Other: other,
- })
- model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
- model.UpdateChannelUsedQuota(info.ChannelId, quota)
- }
- }
- }()
-
- taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
- if taskErr != nil {
- return
- }
- info.ConsumeQuota = true
- // insert task
- task := model.InitTask(platform, info)
- task.TaskID = taskID
- task.Quota = quota
- task.Data = taskData
- task.Action = info.Action
- err = task.Insert()
- if err != nil {
- taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
- return
- }
- return nil
-}
-
-var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
- relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
-}
-
-func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
- respBuilder, ok := fetchRespBuilders[relayMode]
- if !ok {
- taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
- }
-
- respBody, taskErr := respBuilder(c)
- if taskErr != nil {
- return taskErr
- }
- if len(respBody) == 0 {
- respBody = []byte("{\"code\":\"success\",\"data\":null}")
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
- _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- return
- }
- return
-}
-
-func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- userId := c.GetInt("id")
- var condition = struct {
- IDs []any `json:"ids"`
- Action string `json:"action"`
- }{}
- err := c.BindJSON(&condition)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
- return
- }
- var tasks []any
- if len(condition.IDs) > 0 {
- taskModels, err := model.GetByTaskIds(userId, condition.IDs)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
- return
- }
- for _, task := range taskModels {
- tasks = append(tasks, TaskModel2Dto(task))
- }
- } else {
- tasks = make([]any, 0)
- }
- respBody, err = json.Marshal(dto.TaskResponse[[]any]{
- Code: "success",
- Data: tasks,
- })
- return
-}
-
-func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("id")
- userId := c.GetInt("id")
-
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
-
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- return
-}
-
-func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
- taskId := c.Param("task_id")
- if taskId == "" {
- taskId = c.GetString("task_id")
- }
- userId := c.GetInt("id")
-
- originTask, exist, err := model.GetByTaskId(userId, taskId)
- if err != nil {
- taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
- return
- }
- if !exist {
- taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
- return
- }
-
- func() {
- channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
- if err2 != nil {
- return
- }
- if channelModel.Type != constant.ChannelTypeVertexAi {
- return
- }
- baseURL := constant.ChannelBaseURLs[channelModel.Type]
- if channelModel.GetBaseURL() != "" {
- baseURL = channelModel.GetBaseURL()
- }
- adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
- if adaptor == nil {
- return
- }
- resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
- "task_id": originTask.TaskID,
- "action": originTask.Action,
- })
- if err2 != nil || resp == nil {
- return
- }
- defer resp.Body.Close()
- body, err2 := io.ReadAll(resp.Body)
- if err2 != nil {
- return
- }
- ti, err2 := adaptor.ParseTaskResult(body)
- if err2 == nil && ti != nil {
- if ti.Status != "" {
- originTask.Status = model.TaskStatus(ti.Status)
- }
- if ti.Progress != "" {
- originTask.Progress = ti.Progress
- }
- if ti.Url != "" {
- originTask.FailReason = ti.Url
- }
- _ = originTask.Update()
- var raw map[string]any
- _ = json.Unmarshal(body, &raw)
- format := "mp4"
- if respObj, ok := raw["response"].(map[string]any); ok {
- if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
- if v0, ok := vids[0].(map[string]any); ok {
- if mt, ok := v0["mimeType"].(string); ok && mt != "" {
- if strings.Contains(mt, "mp4") {
- format = "mp4"
- } else {
- format = mt
- }
- }
- }
- }
- }
- status := "processing"
- switch originTask.Status {
- case model.TaskStatusSuccess:
- status = "succeeded"
- case model.TaskStatusFailure:
- status = "failed"
- case model.TaskStatusQueued, model.TaskStatusSubmitted:
- status = "queued"
- }
- out := map[string]any{
- "error": nil,
- "format": format,
- "metadata": nil,
- "status": status,
- "task_id": originTask.TaskID,
- "url": originTask.FailReason,
- }
- respBody, _ = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: out,
- })
- }
- }()
-
- if len(respBody) == 0 {
- respBody, err = json.Marshal(dto.TaskResponse[any]{
- Code: "success",
- Data: TaskModel2Dto(originTask),
- })
- }
- return
-}
-
-func TaskModel2Dto(task *model.Task) *dto.TaskDto {
- return &dto.TaskDto{
- TaskID: task.TaskID,
- Action: task.Action,
- Status: string(task.Status),
- FailReason: task.FailReason,
- SubmitTime: task.SubmitTime,
- StartTime: task.StartTime,
- FinishTime: task.FinishTime,
- Progress: task.Progress,
- Data: task.Data,
- }
-}
diff --git a/new-api/relay/rerank_handler.go b/new-api/relay/rerank_handler.go
deleted file mode 100644
index 3c6d3ea30e46cd08a6da0e67d7f07df433ac3644..0000000000000000000000000000000000000000
--- a/new-api/relay/rerank_handler.go
+++ /dev/null
@@ -1,99 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- rerankReq, ok := info.Request.(*dto.RerankRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(rerankReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
-
- var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
-
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- if common.DebugEnabled {
- println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- return nil
-}
diff --git a/new-api/relay/responses_handler.go b/new-api/relay/responses_handler.go
deleted file mode 100644
index ad421349f6c1fe9df881bcdb79a7c19a10db76a8..0000000000000000000000000000000000000000
--- a/new-api/relay/responses_handler.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package relay
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- responsesReq, ok := info.Request.(*dto.OpenAIResponsesRequest)
- if !ok {
- return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
-
- request, err := common.DeepCopy(responsesReq)
- if err != nil {
- return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
- }
-
- err = helper.ModelMappedHelper(c, info, request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
- }
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
- var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- // apply param override
- if len(info.ParamOverride) > 0 {
- jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
-
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, info, requestBody)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- if resp != nil {
- httpResp = resp.(*http.Response)
-
- if httpResp.StatusCode != http.StatusOK {
- newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- }
-
- usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
-
- if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
- } else {
- postConsumeQuota(c, info, usage.(*dto.Usage), "")
- }
- return nil
-}
diff --git a/new-api/relay/websocket.go b/new-api/relay/websocket.go
deleted file mode 100644
index 2612c6f79bc5d85cb5fc88b16162f461bbf28676..0000000000000000000000000000000000000000
--- a/new-api/relay/websocket.go
+++ /dev/null
@@ -1,45 +0,0 @@
-package relay
-
-import (
- "fmt"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/service"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
-)
-
-func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
- info.InitChannelMeta(c)
-
- adaptor := GetAdaptor(info.ApiType)
- if adaptor == nil {
- return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
- }
- adaptor.Init(info)
- //var requestBody io.Reader
- //firstWssRequest, _ := c.Get("first_wss_request")
- //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, info, nil)
- if err != nil {
- return types.NewError(err, types.ErrorCodeDoRequestFailed)
- }
-
- if resp != nil {
- info.TargetWs = resp.(*websocket.Conn)
- defer info.TargetWs.Close()
- }
-
- usage, newAPIError := adaptor.DoResponse(c, nil, info)
- if newAPIError != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(newAPIError, statusCodeMappingStr)
- return newAPIError
- }
- service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
- return nil
-}
diff --git a/new-api/router/api-router.go b/new-api/router/api-router.go
deleted file mode 100644
index c161c94f14b9822af2ed599098bdba9ceedbf241..0000000000000000000000000000000000000000
--- a/new-api/router/api-router.go
+++ /dev/null
@@ -1,252 +0,0 @@
-package router
-
-import (
- "one-api/controller"
- "one-api/middleware"
-
- "github.com/gin-contrib/gzip"
- "github.com/gin-gonic/gin"
-)
-
-func SetApiRouter(router *gin.Engine) {
- apiRouter := router.Group("/api")
- apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
- apiRouter.Use(middleware.GlobalAPIRateLimit())
- {
- apiRouter.GET("/setup", controller.GetSetup)
- apiRouter.POST("/setup", controller.PostSetup)
- apiRouter.GET("/status", controller.GetStatus)
- apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
- apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
- apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
- apiRouter.GET("/notice", controller.GetNotice)
- apiRouter.GET("/about", controller.GetAbout)
- //apiRouter.GET("/midjourney", controller.GetMidjourney)
- apiRouter.GET("/home_page_content", controller.GetHomePageContent)
- apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
- apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
- apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
- apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
- apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
- apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
- apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
- apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
- apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
- apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
- apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
- apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
- apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
- apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
-
- apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
-
- // Universal secure verification routes
- apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
- apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus)
-
- userRoute := apiRouter.Group("/user")
- {
- userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
- userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
- userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
- userRoute.POST("/passkey/login/begin", middleware.CriticalRateLimit(), controller.PasskeyLoginBegin)
- userRoute.POST("/passkey/login/finish", middleware.CriticalRateLimit(), controller.PasskeyLoginFinish)
- //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
- userRoute.GET("/logout", controller.Logout)
- userRoute.GET("/epay/notify", controller.EpayNotify)
- userRoute.GET("/groups", controller.GetUserGroups)
-
- selfRoute := userRoute.Group("/")
- selfRoute.Use(middleware.UserAuth())
- {
- selfRoute.GET("/self/groups", controller.GetUserGroups)
- selfRoute.GET("/self", controller.GetSelf)
- selfRoute.GET("/models", controller.GetUserModels)
- selfRoute.PUT("/self", controller.UpdateSelf)
- selfRoute.DELETE("/self", controller.DeleteSelf)
- selfRoute.GET("/token", controller.GenerateAccessToken)
- selfRoute.GET("/passkey", controller.PasskeyStatus)
- selfRoute.POST("/passkey/register/begin", controller.PasskeyRegisterBegin)
- selfRoute.POST("/passkey/register/finish", controller.PasskeyRegisterFinish)
- selfRoute.POST("/passkey/verify/begin", controller.PasskeyVerifyBegin)
- selfRoute.POST("/passkey/verify/finish", controller.PasskeyVerifyFinish)
- selfRoute.DELETE("/passkey", controller.PasskeyDelete)
- selfRoute.GET("/aff", controller.GetAffCode)
- selfRoute.GET("/topup/info", controller.GetTopUpInfo)
- selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
- selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
- selfRoute.POST("/amount", controller.RequestAmount)
- selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
- selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
- selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
- selfRoute.PUT("/setting", controller.UpdateUserSetting)
-
- // 2FA routes
- selfRoute.GET("/2fa/status", controller.Get2FAStatus)
- selfRoute.POST("/2fa/setup", controller.Setup2FA)
- selfRoute.POST("/2fa/enable", controller.Enable2FA)
- selfRoute.POST("/2fa/disable", controller.Disable2FA)
- selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
- }
-
- adminRoute := userRoute.Group("/")
- adminRoute.Use(middleware.AdminAuth())
- {
- adminRoute.GET("/", controller.GetAllUsers)
- adminRoute.GET("/search", controller.SearchUsers)
- adminRoute.GET("/:id", controller.GetUser)
- adminRoute.POST("/", controller.CreateUser)
- adminRoute.POST("/manage", controller.ManageUser)
- adminRoute.PUT("/", controller.UpdateUser)
- adminRoute.DELETE("/:id", controller.DeleteUser)
- adminRoute.DELETE("/:id/reset_passkey", controller.AdminResetPasskey)
-
- // Admin 2FA routes
- adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
- adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
- }
- }
- optionRoute := apiRouter.Group("/option")
- optionRoute.Use(middleware.RootAuth())
- {
- optionRoute.GET("/", controller.GetOptions)
- optionRoute.PUT("/", controller.UpdateOption)
- optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
- optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
- }
- ratioSyncRoute := apiRouter.Group("/ratio_sync")
- ratioSyncRoute.Use(middleware.RootAuth())
- {
- ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
- ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
- }
- channelRoute := apiRouter.Group("/channel")
- channelRoute.Use(middleware.AdminAuth())
- {
- channelRoute.GET("/", controller.GetAllChannels)
- channelRoute.GET("/search", controller.SearchChannels)
- channelRoute.GET("/models", controller.ChannelListModels)
- channelRoute.GET("/models_enabled", controller.EnabledListModels)
- channelRoute.GET("/:id", controller.GetChannel)
- channelRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), middleware.SecureVerificationRequired(), controller.GetChannelKey)
- channelRoute.GET("/test", controller.TestAllChannels)
- channelRoute.GET("/test/:id", controller.TestChannel)
- channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
- channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
- channelRoute.POST("/", controller.AddChannel)
- channelRoute.PUT("/", controller.UpdateChannel)
- channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
- channelRoute.POST("/tag/disabled", controller.DisableTagChannels)
- channelRoute.POST("/tag/enabled", controller.EnableTagChannels)
- channelRoute.PUT("/tag", controller.EditTagChannels)
- channelRoute.DELETE("/:id", controller.DeleteChannel)
- channelRoute.POST("/batch", controller.DeleteChannelBatch)
- channelRoute.POST("/fix", controller.FixChannelsAbilities)
- channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
- channelRoute.POST("/fetch_models", controller.FetchModels)
- channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
- channelRoute.GET("/tag/models", controller.GetTagModels)
- channelRoute.POST("/copy/:id", controller.CopyChannel)
- channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys)
- }
- tokenRoute := apiRouter.Group("/token")
- tokenRoute.Use(middleware.UserAuth())
- {
- tokenRoute.GET("/", controller.GetAllTokens)
- tokenRoute.GET("/search", controller.SearchTokens)
- tokenRoute.GET("/:id", controller.GetToken)
- tokenRoute.POST("/", controller.AddToken)
- tokenRoute.PUT("/", controller.UpdateToken)
- tokenRoute.DELETE("/:id", controller.DeleteToken)
- tokenRoute.POST("/batch", controller.DeleteTokenBatch)
- }
-
- usageRoute := apiRouter.Group("/usage")
- usageRoute.Use(middleware.CriticalRateLimit())
- {
- tokenUsageRoute := usageRoute.Group("/token")
- tokenUsageRoute.Use(middleware.TokenAuth())
- {
- tokenUsageRoute.GET("/", controller.GetTokenUsage)
- }
- }
-
- redemptionRoute := apiRouter.Group("/redemption")
- redemptionRoute.Use(middleware.AdminAuth())
- {
- redemptionRoute.GET("/", controller.GetAllRedemptions)
- redemptionRoute.GET("/search", controller.SearchRedemptions)
- redemptionRoute.GET("/:id", controller.GetRedemption)
- redemptionRoute.POST("/", controller.AddRedemption)
- redemptionRoute.PUT("/", controller.UpdateRedemption)
- redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
- redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
- }
- logRoute := apiRouter.Group("/log")
- logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
- logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
- logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
- logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
- logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
- logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
- logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
-
- dataRoute := apiRouter.Group("/data")
- dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
- dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
-
- logRoute.Use(middleware.CORS())
- {
- logRoute.GET("/token", controller.GetLogByKey)
- }
- groupRoute := apiRouter.Group("/group")
- groupRoute.Use(middleware.AdminAuth())
- {
- groupRoute.GET("/", controller.GetGroups)
- }
-
- prefillGroupRoute := apiRouter.Group("/prefill_group")
- prefillGroupRoute.Use(middleware.AdminAuth())
- {
- prefillGroupRoute.GET("/", controller.GetPrefillGroups)
- prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
- prefillGroupRoute.PUT("/", controller.UpdatePrefillGroup)
- prefillGroupRoute.DELETE("/:id", controller.DeletePrefillGroup)
- }
-
- mjRoute := apiRouter.Group("/mj")
- mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
- mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
-
- taskRoute := apiRouter.Group("/task")
- {
- taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
- taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
- }
-
- vendorRoute := apiRouter.Group("/vendors")
- vendorRoute.Use(middleware.AdminAuth())
- {
- vendorRoute.GET("/", controller.GetAllVendors)
- vendorRoute.GET("/search", controller.SearchVendors)
- vendorRoute.GET("/:id", controller.GetVendorMeta)
- vendorRoute.POST("/", controller.CreateVendorMeta)
- vendorRoute.PUT("/", controller.UpdateVendorMeta)
- vendorRoute.DELETE("/:id", controller.DeleteVendorMeta)
- }
-
- modelsRoute := apiRouter.Group("/models")
- modelsRoute.Use(middleware.AdminAuth())
- {
- modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview)
- modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels)
- modelsRoute.GET("/missing", controller.GetMissingModels)
- modelsRoute.GET("/", controller.GetAllModelsMeta)
- modelsRoute.GET("/search", controller.SearchModelsMeta)
- modelsRoute.GET("/:id", controller.GetModelMeta)
- modelsRoute.POST("/", controller.CreateModelMeta)
- modelsRoute.PUT("/", controller.UpdateModelMeta)
- modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
- }
- }
-}
diff --git a/new-api/router/dashboard.go b/new-api/router/dashboard.go
deleted file mode 100644
index 48f365340cb43c4496e8432e6302fd839087d2a0..0000000000000000000000000000000000000000
--- a/new-api/router/dashboard.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package router
-
-import (
- "github.com/gin-contrib/gzip"
- "github.com/gin-gonic/gin"
- "one-api/controller"
- "one-api/middleware"
-)
-
-func SetDashboardRouter(router *gin.Engine) {
- apiRouter := router.Group("/")
- apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
- apiRouter.Use(middleware.GlobalAPIRateLimit())
- apiRouter.Use(middleware.CORS())
- apiRouter.Use(middleware.TokenAuth())
- {
- apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
- apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
- apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
- apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
- }
-}
diff --git a/new-api/router/main.go b/new-api/router/main.go
deleted file mode 100644
index f6acdc11b63a1e9cb634fb38ae2f144554dd45a2..0000000000000000000000000000000000000000
--- a/new-api/router/main.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package router
-
-import (
- "embed"
- "fmt"
- "net/http"
- "one-api/common"
- "os"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
- SetApiRouter(router)
- SetDashboardRouter(router)
- SetRelayRouter(router)
- SetVideoRouter(router)
- frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
- if common.IsMasterNode && frontendBaseUrl != "" {
- frontendBaseUrl = ""
- common.SysLog("FRONTEND_BASE_URL is ignored on master node")
- }
- if frontendBaseUrl == "" {
- SetWebRouter(router, buildFS, indexPage)
- } else {
- frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
- router.NoRoute(func(c *gin.Context) {
- c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
- })
- }
-}
diff --git a/new-api/router/relay-router.go b/new-api/router/relay-router.go
deleted file mode 100644
index a1fcd62bab41f9e65a050e2668a5a0f583ed7193..0000000000000000000000000000000000000000
--- a/new-api/router/relay-router.go
+++ /dev/null
@@ -1,205 +0,0 @@
-package router
-
-import (
- "one-api/constant"
- "one-api/controller"
- "one-api/middleware"
- "one-api/relay"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func SetRelayRouter(router *gin.Engine) {
- router.Use(middleware.CORS())
- router.Use(middleware.DecompressRequestMiddleware())
- router.Use(middleware.StatsMiddleware())
- // https://platform.openai.com/docs/api-reference/introduction
- modelsRouter := router.Group("/v1/models")
- modelsRouter.Use(middleware.TokenAuth())
- {
- modelsRouter.GET("", func(c *gin.Context) {
- switch {
- case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
- controller.ListModels(c, constant.ChannelTypeAnthropic)
- case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配
- controller.RetrieveModel(c, constant.ChannelTypeGemini)
- default:
- controller.ListModels(c, constant.ChannelTypeOpenAI)
- }
- })
-
- modelsRouter.GET("/:model", func(c *gin.Context) {
- switch {
- case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
- controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
- default:
- controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
- }
- })
- }
-
- geminiRouter := router.Group("/v1beta/models")
- geminiRouter.Use(middleware.TokenAuth())
- {
- geminiRouter.GET("", func(c *gin.Context) {
- controller.ListModels(c, constant.ChannelTypeGemini)
- })
- }
-
- geminiCompatibleRouter := router.Group("/v1beta/openai/models")
- geminiCompatibleRouter.Use(middleware.TokenAuth())
- {
- geminiCompatibleRouter.GET("", func(c *gin.Context) {
- controller.ListModels(c, constant.ChannelTypeOpenAI)
- })
- }
-
- playgroundRouter := router.Group("/pg")
- playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
- {
- playgroundRouter.POST("/chat/completions", controller.Playground)
- }
- relayV1Router := router.Group("/v1")
- relayV1Router.Use(middleware.TokenAuth())
- relayV1Router.Use(middleware.ModelRequestRateLimit())
- {
- // WebSocket 路由(统一到 Relay)
- wsRouter := relayV1Router.Group("")
- wsRouter.Use(middleware.Distribute())
- wsRouter.GET("/realtime", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIRealtime)
- })
- }
- {
- //http router
- httpRouter := relayV1Router.Group("")
- httpRouter.Use(middleware.Distribute())
-
- // claude related routes
- httpRouter.POST("/messages", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatClaude)
- })
-
- // chat related routes
- httpRouter.POST("/completions", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAI)
- })
- httpRouter.POST("/chat/completions", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAI)
- })
-
- // response related routes
- httpRouter.POST("/responses", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIResponses)
- })
-
- // image related routes
- httpRouter.POST("/edits", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIImage)
- })
- httpRouter.POST("/images/generations", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIImage)
- })
- httpRouter.POST("/images/edits", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIImage)
- })
-
- // embedding related routes
- httpRouter.POST("/embeddings", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatEmbedding)
- })
-
- // audio related routes
- httpRouter.POST("/audio/transcriptions", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIAudio)
- })
- httpRouter.POST("/audio/translations", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIAudio)
- })
- httpRouter.POST("/audio/speech", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAIAudio)
- })
-
- // rerank related routes
- httpRouter.POST("/rerank", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatRerank)
- })
-
- // gemini relay routes
- httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatGemini)
- })
- httpRouter.POST("/models/*path", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatGemini)
- })
-
- // other relay routes
- httpRouter.POST("/moderations", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatOpenAI)
- })
-
- // not implemented
- httpRouter.POST("/images/variations", controller.RelayNotImplemented)
- httpRouter.GET("/files", controller.RelayNotImplemented)
- httpRouter.POST("/files", controller.RelayNotImplemented)
- httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
- httpRouter.GET("/files/:id", controller.RelayNotImplemented)
- httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
- httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
- httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
- httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
- httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
- httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
- httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
- }
-
- relayMjRouter := router.Group("/mj")
- registerMjRouterGroup(relayMjRouter)
-
- relayMjModeRouter := router.Group("/:mode/mj")
- registerMjRouterGroup(relayMjModeRouter)
- //relayMjRouter.Use()
-
- relaySunoRouter := router.Group("/suno")
- relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
- {
- relaySunoRouter.POST("/submit/:action", controller.RelayTask)
- relaySunoRouter.POST("/fetch", controller.RelayTask)
- relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
- }
-
- relayGeminiRouter := router.Group("/v1beta")
- relayGeminiRouter.Use(middleware.TokenAuth())
- relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
- relayGeminiRouter.Use(middleware.Distribute())
- {
- // Gemini API 路径格式: /v1beta/models/{model_name}:{action}
- relayGeminiRouter.POST("/models/*path", func(c *gin.Context) {
- controller.Relay(c, types.RelayFormatGemini)
- })
- }
-}
-
-func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
- relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
- relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
- {
- relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
- relayMjRouter.POST("/notify", controller.RelayMidjourney)
- relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
- relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
- relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
- relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
- relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
- }
-}
diff --git a/new-api/router/video-router.go b/new-api/router/video-router.go
deleted file mode 100644
index 5b22205cfc1f49e6adac474fe7a125054f8b070a..0000000000000000000000000000000000000000
--- a/new-api/router/video-router.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package router
-
-import (
- "one-api/controller"
- "one-api/middleware"
-
- "github.com/gin-gonic/gin"
-)
-
-func SetVideoRouter(router *gin.Engine) {
- videoV1Router := router.Group("/v1")
- videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
- {
- videoV1Router.POST("/video/generations", controller.RelayTask)
- videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
- }
-
- klingV1Router := router.Group("/kling/v1")
- klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
- {
- klingV1Router.POST("/videos/text2video", controller.RelayTask)
- klingV1Router.POST("/videos/image2video", controller.RelayTask)
- klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
- klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
- }
-
- // Jimeng official API routes - direct mapping to official API format
- jimengOfficialGroup := router.Group("jimeng")
- jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
- {
- // Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31
- jimengOfficialGroup.POST("/", controller.RelayTask)
- }
-}
diff --git a/new-api/router/web-router.go b/new-api/router/web-router.go
deleted file mode 100644
index c19d0b83e7ee3682f539f104c1a8e6417b4f8bac..0000000000000000000000000000000000000000
--- a/new-api/router/web-router.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package router
-
-import (
- "embed"
- "github.com/gin-contrib/gzip"
- "github.com/gin-contrib/static"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/controller"
- "one-api/middleware"
- "strings"
-)
-
-func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
- router.Use(gzip.Gzip(gzip.DefaultCompression))
- router.Use(middleware.GlobalWebRateLimit())
- router.Use(middleware.Cache())
- router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist")))
- router.NoRoute(func(c *gin.Context) {
- if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
- controller.RelayNotFound(c)
- return
- }
- c.Header("Cache-Control", "no-cache")
- c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
- })
-}
diff --git a/new-api/service/audio.go b/new-api/service/audio.go
deleted file mode 100644
index 68622e0eab19f956d229f940825ca43fd6df1888..0000000000000000000000000000000000000000
--- a/new-api/service/audio.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package service
-
-import (
- "encoding/base64"
- "fmt"
- "strings"
-)
-
-func parseAudio(audioBase64 string, format string) (duration float64, err error) {
- audioData, err := base64.StdEncoding.DecodeString(audioBase64)
- if err != nil {
- return 0, fmt.Errorf("base64 decode error: %v", err)
- }
-
- var samplesCount int
- var sampleRate int
-
- switch format {
- case "pcm16":
- samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
- sampleRate = 24000 // 24kHz
- case "g711_ulaw", "g711_alaw":
- samplesCount = len(audioData) // 8位 = 1字节每样本
- sampleRate = 8000 // 8kHz
- default:
- samplesCount = len(audioData) // 8位 = 1字节每样本
- sampleRate = 8000 // 8kHz
- }
-
- duration = float64(samplesCount) / float64(sampleRate)
- return duration, nil
-}
-
-func DecodeBase64AudioData(audioBase64 string) (string, error) {
- // 检查并移除 data:audio/xxx;base64, 前缀
- idx := strings.Index(audioBase64, ",")
- if idx != -1 {
- audioBase64 = audioBase64[idx+1:]
- }
-
- // 解码 Base64 数据
- _, err := base64.StdEncoding.DecodeString(audioBase64)
- if err != nil {
- return "", fmt.Errorf("base64 decode error: %v", err)
- }
-
- return audioBase64, nil
-}
diff --git a/new-api/service/channel.go b/new-api/service/channel.go
deleted file mode 100644
index db040f608c31699d036edcb650602dd19630d57c..0000000000000000000000000000000000000000
--- a/new-api/service/channel.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package service
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- "one-api/setting/operation_setting"
- "one-api/types"
- "strings"
-)
-
-func formatNotifyType(channelId int, status int) string {
- return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status)
-}
-
-// disable & notify
-func DisableChannel(channelError types.ChannelError, reason string) {
- common.SysLog(fmt.Sprintf("通道「%s」(#%d)发生错误,准备禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason))
-
- // 检查是否启用自动禁用功能
- if !channelError.AutoBan {
- common.SysLog(fmt.Sprintf("通道「%s」(#%d)未启用自动禁用功能,跳过禁用操作", channelError.ChannelName, channelError.ChannelId))
- return
- }
-
- success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
- if success {
- subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId)
- content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)
- NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
- }
-}
-
-func EnableChannel(channelId int, usingKey string, channelName string) {
- success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
- if success {
- subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
- content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
- NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content)
- }
-}
-
-func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
- if !common.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if types.IsChannelError(err) {
- return true
- }
- if types.IsSkipRetryError(err) {
- return false
- }
- if err.StatusCode == http.StatusUnauthorized {
- return true
- }
- if err.StatusCode == http.StatusForbidden {
- switch channelType {
- case constant.ChannelTypeGemini:
- return true
- }
- }
- oaiErr := err.ToOpenAIError()
- switch oaiErr.Code {
- case "invalid_api_key":
- return true
- case "account_deactivated":
- return true
- case "billing_not_active":
- return true
- case "pre_consume_token_quota_failed":
- return true
- }
- switch oaiErr.Type {
- case "insufficient_quota":
- return true
- case "insufficient_user_quota":
- return true
- // https://docs.anthropic.com/claude/reference/errors
- case "authentication_error":
- return true
- case "permission_error":
- return true
- case "forbidden":
- return true
- }
-
- lowerMessage := strings.ToLower(err.Error())
- search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
- return search
-}
-
-func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
- if !common.AutomaticEnableChannelEnabled {
- return false
- }
- if newAPIError != nil {
- return false
- }
- if status != common.ChannelStatusAutoDisabled {
- return false
- }
- return true
-}
diff --git a/new-api/service/convert.go b/new-api/service/convert.go
deleted file mode 100644
index 4b2bf66dc938bbed7d00020d9d61c573371bda2e..0000000000000000000000000000000000000000
--- a/new-api/service/convert.go
+++ /dev/null
@@ -1,815 +0,0 @@
-package service
-
-import (
- "encoding/json"
- "fmt"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/relay/channel/openrouter"
- relaycommon "one-api/relay/common"
- "strings"
-)
-
-func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
- openAIRequest := dto.GeneralOpenAIRequest{
- Model: claudeRequest.Model,
- MaxTokens: claudeRequest.MaxTokens,
- Temperature: claudeRequest.Temperature,
- TopP: claudeRequest.TopP,
- Stream: claudeRequest.Stream,
- }
-
- isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
-
- if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
- if isOpenRouter {
- reasoning := openrouter.RequestReasoning{
- MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
- }
- reasoningJSON, err := json.Marshal(reasoning)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
- }
- openAIRequest.Reasoning = reasoningJSON
- } else {
- thinkingSuffix := "-thinking"
- if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
- !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
- openAIRequest.Model = openAIRequest.Model + thinkingSuffix
- }
- }
- }
-
- // Convert stop sequences
- if len(claudeRequest.StopSequences) == 1 {
- openAIRequest.Stop = claudeRequest.StopSequences[0]
- } else if len(claudeRequest.StopSequences) > 1 {
- openAIRequest.Stop = claudeRequest.StopSequences
- }
-
- // Convert tools
- tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
- openAITools := make([]dto.ToolCallRequest, 0)
- for _, claudeTool := range tools {
- openAITool := dto.ToolCallRequest{
- Type: "function",
- Function: dto.FunctionRequest{
- Name: claudeTool.Name,
- Description: claudeTool.Description,
- Parameters: claudeTool.InputSchema,
- },
- }
- openAITools = append(openAITools, openAITool)
- }
- openAIRequest.Tools = openAITools
-
- // Convert messages
- openAIMessages := make([]dto.Message, 0)
-
- // Add system message if present
- if claudeRequest.System != nil {
- if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
- openAIMessage := dto.Message{
- Role: "system",
- }
- openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
- openAIMessages = append(openAIMessages, openAIMessage)
- } else {
- systems := claudeRequest.ParseSystem()
- if len(systems) > 0 {
- openAIMessage := dto.Message{
- Role: "system",
- }
- isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
- if isOpenRouterClaude {
- systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
- for _, system := range systems {
- message := dto.MediaContent{
- Type: "text",
- Text: system.GetText(),
- CacheControl: system.CacheControl,
- }
- systemMediaMessages = append(systemMediaMessages, message)
- }
- openAIMessage.SetMediaContent(systemMediaMessages)
- } else {
- systemStr := ""
- for _, system := range systems {
- if system.Text != nil {
- systemStr += *system.Text
- }
- }
- openAIMessage.SetStringContent(systemStr)
- }
- openAIMessages = append(openAIMessages, openAIMessage)
- }
- }
- }
- for _, claudeMessage := range claudeRequest.Messages {
- openAIMessage := dto.Message{
- Role: claudeMessage.Role,
- }
-
- //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
- if claudeMessage.IsStringContent() {
- openAIMessage.SetStringContent(claudeMessage.GetStringContent())
- } else {
- content, err := claudeMessage.ParseContent()
- if err != nil {
- return nil, err
- }
- contents := content
- var toolCalls []dto.ToolCallRequest
- mediaMessages := make([]dto.MediaContent, 0, len(contents))
-
- for _, mediaMsg := range contents {
- switch mediaMsg.Type {
- case "text":
- message := dto.MediaContent{
- Type: "text",
- Text: mediaMsg.GetText(),
- CacheControl: mediaMsg.CacheControl,
- }
- mediaMessages = append(mediaMessages, message)
- case "image":
- // Handle image conversion (base64 to URL or keep as is)
- imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
- //textContent += fmt.Sprintf("[Image: %s]", imageData)
- mediaMessage := dto.MediaContent{
- Type: "image_url",
- ImageUrl: &dto.MessageImageUrl{Url: imageData},
- }
- mediaMessages = append(mediaMessages, mediaMessage)
- case "tool_use":
- toolCall := dto.ToolCallRequest{
- ID: mediaMsg.Id,
- Type: "function",
- Function: dto.FunctionRequest{
- Name: mediaMsg.Name,
- Arguments: toJSONString(mediaMsg.Input),
- },
- }
- toolCalls = append(toolCalls, toolCall)
- case "tool_result":
- // Add tool result as a separate message
- toolName := mediaMsg.Name
- if toolName == "" {
- toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId)
- }
- oaiToolMessage := dto.Message{
- Role: "tool",
- Name: &toolName,
- ToolCallId: mediaMsg.ToolUseId,
- }
- //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
- if mediaMsg.IsStringContent() {
- oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
- } else {
- mediaContents := mediaMsg.ParseMediaContent()
- encodeJson, _ := common.Marshal(mediaContents)
- oaiToolMessage.SetStringContent(string(encodeJson))
- }
- openAIMessages = append(openAIMessages, oaiToolMessage)
- }
- }
-
- if len(toolCalls) > 0 {
- openAIMessage.SetToolCalls(toolCalls)
- }
-
- if len(mediaMessages) > 0 && len(toolCalls) == 0 {
- openAIMessage.SetMediaContent(mediaMessages)
- }
- }
- if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
- openAIMessages = append(openAIMessages, openAIMessage)
- }
- }
-
- openAIRequest.Messages = openAIMessages
-
- return &openAIRequest, nil
-}
-
-func generateStopBlock(index int) *dto.ClaudeResponse {
- return &dto.ClaudeResponse{
- Type: "content_block_stop",
- Index: common.GetPointer[int](index),
- }
-}
-
-func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
- var claudeResponses []*dto.ClaudeResponse
- if info.SendResponseCount == 1 {
- msg := &dto.ClaudeMediaMessage{
- Id: openAIResponse.Id,
- Model: openAIResponse.Model,
- Type: "message",
- Role: "assistant",
- Usage: &dto.ClaudeUsage{
- InputTokens: info.PromptTokens,
- OutputTokens: 0,
- },
- }
- msg.SetContent(make([]any, 0))
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Type: "message_start",
- Message: msg,
- })
- claudeResponses = append(claudeResponses)
- //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- // Type: "ping",
- //})
- if openAIResponse.IsToolCall() {
- info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
- resp := &dto.ClaudeResponse{
- Type: "content_block_start",
- ContentBlock: &dto.ClaudeMediaMessage{
- Id: openAIResponse.GetFirstToolCall().ID,
- Type: "tool_use",
- Name: openAIResponse.GetFirstToolCall().Function.Name,
- Input: map[string]interface{}{},
- },
- }
- resp.SetIndex(0)
- claudeResponses = append(claudeResponses, resp)
- } else {
-
- }
- // 判断首个响应是否存在内容(非标准的 OpenAI 响应)
- if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Index: &info.ClaudeConvertInfo.Index,
- Type: "content_block_start",
- ContentBlock: &dto.ClaudeMediaMessage{
- Type: "text",
- Text: common.GetPointer[string](""),
- },
- })
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Index: &info.ClaudeConvertInfo.Index,
- Type: "content_block_delta",
- Delta: &dto.ClaudeMediaMessage{
- Type: "text_delta",
- Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
- },
- })
- info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
- }
- return claudeResponses
- }
-
- if len(openAIResponse.Choices) == 0 {
- // no choices
- // 可能为非标准的 OpenAI 响应,判断是否已经完成
- if info.Done {
- claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
- oaiUsage := info.ClaudeConvertInfo.Usage
- if oaiUsage != nil {
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Type: "message_delta",
- Usage: &dto.ClaudeUsage{
- InputTokens: oaiUsage.PromptTokens,
- OutputTokens: oaiUsage.CompletionTokens,
- CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
- CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
- },
- Delta: &dto.ClaudeMediaMessage{
- StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
- },
- })
- }
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Type: "message_stop",
- })
- }
- return claudeResponses
- } else {
- chosenChoice := openAIResponse.Choices[0]
- if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
- // should be done
- info.FinishReason = *chosenChoice.FinishReason
- if !info.Done {
- return claudeResponses
- }
- }
- if info.Done {
- claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
- oaiUsage := info.ClaudeConvertInfo.Usage
- if oaiUsage != nil {
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Type: "message_delta",
- Usage: &dto.ClaudeUsage{
- InputTokens: oaiUsage.PromptTokens,
- OutputTokens: oaiUsage.CompletionTokens,
- CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
- CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
- },
- Delta: &dto.ClaudeMediaMessage{
- StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
- },
- })
- }
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Type: "message_stop",
- })
- } else {
- var claudeResponse dto.ClaudeResponse
- var isEmpty bool
- claudeResponse.Type = "content_block_delta"
- if len(chosenChoice.Delta.ToolCalls) > 0 {
- if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
- claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
- info.ClaudeConvertInfo.Index++
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Index: &info.ClaudeConvertInfo.Index,
- Type: "content_block_start",
- ContentBlock: &dto.ClaudeMediaMessage{
- Id: openAIResponse.GetFirstToolCall().ID,
- Type: "tool_use",
- Name: openAIResponse.GetFirstToolCall().Function.Name,
- Input: map[string]interface{}{},
- },
- })
- }
- info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
- // tools delta
- claudeResponse.Delta = &dto.ClaudeMediaMessage{
- Type: "input_json_delta",
- PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
- }
- } else {
- reasoning := chosenChoice.Delta.GetReasoningContent()
- textContent := chosenChoice.Delta.GetContentString()
- if reasoning != "" || textContent != "" {
- if reasoning != "" {
- if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
- //info.ClaudeConvertInfo.Index++
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Index: &info.ClaudeConvertInfo.Index,
- Type: "content_block_start",
- ContentBlock: &dto.ClaudeMediaMessage{
- Type: "thinking",
- Thinking: "",
- },
- })
- }
- info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
- // text delta
- claudeResponse.Delta = &dto.ClaudeMediaMessage{
- Type: "thinking_delta",
- Thinking: reasoning,
- }
- } else {
- if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
- if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
- claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
- info.ClaudeConvertInfo.Index++
- }
- claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
- Index: &info.ClaudeConvertInfo.Index,
- Type: "content_block_start",
- ContentBlock: &dto.ClaudeMediaMessage{
- Type: "text",
- Text: common.GetPointer[string](""),
- },
- })
- }
- info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
- // text delta
- claudeResponse.Delta = &dto.ClaudeMediaMessage{
- Type: "text_delta",
- Text: common.GetPointer[string](textContent),
- }
- }
- } else {
- isEmpty = true
- }
- }
- claudeResponse.Index = &info.ClaudeConvertInfo.Index
- if !isEmpty {
- claudeResponses = append(claudeResponses, &claudeResponse)
- }
- }
- }
-
- return claudeResponses
-}
-
-func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
- var stopReason string
- contents := make([]dto.ClaudeMediaMessage, 0)
- claudeResponse := &dto.ClaudeResponse{
- Id: openAIResponse.Id,
- Type: "message",
- Role: "assistant",
- Model: openAIResponse.Model,
- }
- for _, choice := range openAIResponse.Choices {
- stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
- if choice.FinishReason == "tool_calls" {
- for _, toolUse := range choice.Message.ParseToolCalls() {
- claudeContent := dto.ClaudeMediaMessage{}
- claudeContent.Type = "tool_use"
- claudeContent.Id = toolUse.ID
- claudeContent.Name = toolUse.Function.Name
- var mapParams map[string]interface{}
- if err := common.Unmarshal([]byte(toolUse.Function.Arguments), &mapParams); err == nil {
- claudeContent.Input = mapParams
- } else {
- claudeContent.Input = toolUse.Function.Arguments
- }
- contents = append(contents, claudeContent)
- }
- } else {
- claudeContent := dto.ClaudeMediaMessage{}
- claudeContent.Type = "text"
- claudeContent.SetText(choice.Message.StringContent())
- contents = append(contents, claudeContent)
- }
- }
- claudeResponse.Content = contents
- claudeResponse.StopReason = stopReason
- claudeResponse.Usage = &dto.ClaudeUsage{
- InputTokens: openAIResponse.PromptTokens,
- OutputTokens: openAIResponse.CompletionTokens,
- }
-
- return claudeResponse
-}
-
-func stopReasonOpenAI2Claude(reason string) string {
- switch reason {
- case "stop":
- return "end_turn"
- case "stop_sequence":
- return "stop_sequence"
- case "length":
- fallthrough
- case "max_tokens":
- return "max_tokens"
- case "tool_calls":
- return "tool_use"
- default:
- return reason
- }
-}
-
-func toJSONString(v interface{}) string {
- b, err := json.Marshal(v)
- if err != nil {
- return "{}"
- }
- return string(b)
-}
-
-func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
- openaiRequest := &dto.GeneralOpenAIRequest{
- Model: info.UpstreamModelName,
- Stream: info.IsStream,
- }
-
- // 转换 messages
- var messages []dto.Message
- for _, content := range geminiRequest.Contents {
- message := dto.Message{
- Role: convertGeminiRoleToOpenAI(content.Role),
- }
-
- // 处理 parts
- var mediaContents []dto.MediaContent
- var toolCalls []dto.ToolCallRequest
- for _, part := range content.Parts {
- if part.Text != "" {
- mediaContent := dto.MediaContent{
- Type: "text",
- Text: part.Text,
- }
- mediaContents = append(mediaContents, mediaContent)
- } else if part.InlineData != nil {
- mediaContent := dto.MediaContent{
- Type: "image_url",
- ImageUrl: &dto.MessageImageUrl{
- Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
- Detail: "auto",
- MimeType: part.InlineData.MimeType,
- },
- }
- mediaContents = append(mediaContents, mediaContent)
- } else if part.FileData != nil {
- mediaContent := dto.MediaContent{
- Type: "image_url",
- ImageUrl: &dto.MessageImageUrl{
- Url: part.FileData.FileUri,
- Detail: "auto",
- MimeType: part.FileData.MimeType,
- },
- }
- mediaContents = append(mediaContents, mediaContent)
- } else if part.FunctionCall != nil {
- // 处理 Gemini 的工具调用
- toolCall := dto.ToolCallRequest{
- ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
- Type: "function",
- Function: dto.FunctionRequest{
- Name: part.FunctionCall.FunctionName,
- Arguments: toJSONString(part.FunctionCall.Arguments),
- },
- }
- toolCalls = append(toolCalls, toolCall)
- } else if part.FunctionResponse != nil {
- // 处理 Gemini 的工具响应,创建单独的 tool 消息
- toolMessage := dto.Message{
- Role: "tool",
- ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
- }
- toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
- messages = append(messages, toolMessage)
- }
- }
-
- // 设置消息内容
- if len(toolCalls) > 0 {
- // 如果有工具调用,设置工具调用
- message.SetToolCalls(toolCalls)
- } else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
- // 如果只有一个文本内容,直接设置字符串
- message.Content = mediaContents[0].Text
- } else if len(mediaContents) > 0 {
- // 如果有多个内容或包含媒体,设置为数组
- message.SetMediaContent(mediaContents)
- }
-
- // 只有当消息有内容或工具调用时才添加
- if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
- messages = append(messages, message)
- }
- }
-
- openaiRequest.Messages = messages
-
- if geminiRequest.GenerationConfig.Temperature != nil {
- openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
- }
- if geminiRequest.GenerationConfig.TopP > 0 {
- openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
- }
- if geminiRequest.GenerationConfig.TopK > 0 {
- openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
- }
- if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
- openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
- }
- // gemini stop sequences 最多 5 个,openai stop 最多 4 个
- if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
- openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
- }
- if geminiRequest.GenerationConfig.CandidateCount > 0 {
- openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
- }
-
- // 转换工具调用
- if len(geminiRequest.GetTools()) > 0 {
- var tools []dto.ToolCallRequest
- for _, tool := range geminiRequest.GetTools() {
- if tool.FunctionDeclarations != nil {
- // 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
- functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
- if ok {
- for _, function := range functionDeclarations {
- openAITool := dto.ToolCallRequest{
- Type: "function",
- Function: dto.FunctionRequest{
- Name: function.Name,
- Description: function.Description,
- Parameters: function.Parameters,
- },
- }
- tools = append(tools, openAITool)
- }
- }
- }
- }
- if len(tools) > 0 {
- openaiRequest.Tools = tools
- }
- }
-
- // gemini system instructions
- if geminiRequest.SystemInstructions != nil {
- // 将系统指令作为第一条消息插入
- systemMessage := dto.Message{
- Role: "system",
- Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
- }
- openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
- }
-
- return openaiRequest, nil
-}
-
-func convertGeminiRoleToOpenAI(geminiRole string) string {
- switch geminiRole {
- case "user":
- return "user"
- case "model":
- return "assistant"
- case "function":
- return "function"
- default:
- return "user"
- }
-}
-
-func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
- var texts []string
- for _, part := range parts {
- if part.Text != "" {
- texts = append(texts, part.Text)
- }
- }
- return strings.Join(texts, "\n")
-}
-
-// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
-func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
- geminiResponse := &dto.GeminiChatResponse{
- Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
- PromptFeedback: dto.GeminiChatPromptFeedback{
- SafetyRatings: []dto.GeminiChatSafetyRating{},
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: openAIResponse.PromptTokens,
- CandidatesTokenCount: openAIResponse.CompletionTokens,
- TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
- },
- }
-
- for _, choice := range openAIResponse.Choices {
- candidate := dto.GeminiChatCandidate{
- Index: int64(choice.Index),
- SafetyRatings: []dto.GeminiChatSafetyRating{},
- }
-
- // 设置结束原因
- var finishReason string
- switch choice.FinishReason {
- case "stop":
- finishReason = "STOP"
- case "length":
- finishReason = "MAX_TOKENS"
- case "content_filter":
- finishReason = "SAFETY"
- case "tool_calls":
- finishReason = "STOP"
- default:
- finishReason = "STOP"
- }
- candidate.FinishReason = &finishReason
-
- // 转换消息内容
- content := dto.GeminiChatContent{
- Role: "model",
- Parts: make([]dto.GeminiPart, 0),
- }
-
- // 处理工具调用
- toolCalls := choice.Message.ParseToolCalls()
- if len(toolCalls) > 0 {
- for _, toolCall := range toolCalls {
- // 解析参数
- var args map[string]interface{}
- if toolCall.Function.Arguments != "" {
- if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
- args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
- }
- } else {
- args = make(map[string]interface{})
- }
-
- part := dto.GeminiPart{
- FunctionCall: &dto.FunctionCall{
- FunctionName: toolCall.Function.Name,
- Arguments: args,
- },
- }
- content.Parts = append(content.Parts, part)
- }
- } else {
- // 处理文本内容
- textContent := choice.Message.StringContent()
- if textContent != "" {
- part := dto.GeminiPart{
- Text: textContent,
- }
- content.Parts = append(content.Parts, part)
- }
- }
-
- candidate.Content = content
- geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
- }
-
- return geminiResponse
-}
-
-// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
-func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
- // 检查是否有实际内容或结束标志
- hasContent := false
- hasFinishReason := false
- for _, choice := range openAIResponse.Choices {
- if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
- hasContent = true
- }
- if choice.FinishReason != nil {
- hasFinishReason = true
- }
- }
-
- // 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
- if !hasContent && !hasFinishReason {
- return nil
- }
-
- geminiResponse := &dto.GeminiChatResponse{
- Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
- PromptFeedback: dto.GeminiChatPromptFeedback{
- SafetyRatings: []dto.GeminiChatSafetyRating{},
- },
- UsageMetadata: dto.GeminiUsageMetadata{
- PromptTokenCount: info.PromptTokens,
- CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
- TotalTokenCount: info.PromptTokens,
- },
- }
-
- for _, choice := range openAIResponse.Choices {
- candidate := dto.GeminiChatCandidate{
- Index: int64(choice.Index),
- SafetyRatings: []dto.GeminiChatSafetyRating{},
- }
-
- // 设置结束原因
- if choice.FinishReason != nil {
- var finishReason string
- switch *choice.FinishReason {
- case "stop":
- finishReason = "STOP"
- case "length":
- finishReason = "MAX_TOKENS"
- case "content_filter":
- finishReason = "SAFETY"
- case "tool_calls":
- finishReason = "STOP"
- default:
- finishReason = "STOP"
- }
- candidate.FinishReason = &finishReason
- }
-
- // 转换消息内容
- content := dto.GeminiChatContent{
- Role: "model",
- Parts: make([]dto.GeminiPart, 0),
- }
-
- // 处理工具调用
- if choice.Delta.ToolCalls != nil {
- for _, toolCall := range choice.Delta.ToolCalls {
- // 解析参数
- var args map[string]interface{}
- if toolCall.Function.Arguments != "" {
- if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
- args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
- }
- } else {
- args = make(map[string]interface{})
- }
-
- part := dto.GeminiPart{
- FunctionCall: &dto.FunctionCall{
- FunctionName: toolCall.Function.Name,
- Arguments: args,
- },
- }
- content.Parts = append(content.Parts, part)
- }
- } else {
- // 处理文本内容
- textContent := choice.Delta.GetContentString()
- if textContent != "" {
- part := dto.GeminiPart{
- Text: textContent,
- }
- content.Parts = append(content.Parts, part)
- }
- }
-
- candidate.Content = content
- geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
- }
-
- return geminiResponse
-}
diff --git a/new-api/service/download.go b/new-api/service/download.go
deleted file mode 100644
index 28c1b2c79ff9cc977876508dba9a78f72970852e..0000000000000000000000000000000000000000
--- a/new-api/service/download.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package service
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/setting/system_setting"
- "strings"
-)
-
-// WorkerRequest Worker请求的数据结构
-type WorkerRequest struct {
- URL string `json:"url"`
- Key string `json:"key"`
- Method string `json:"method,omitempty"`
- Headers map[string]string `json:"headers,omitempty"`
- Body json.RawMessage `json:"body,omitempty"`
-}
-
-// DoWorkerRequest 通过Worker发送请求
-func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
- if !system_setting.EnableWorker() {
- return nil, fmt.Errorf("worker not enabled")
- }
- if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
- return nil, fmt.Errorf("only support https url")
- }
-
- // SSRF防护:验证请求URL
- fetchSetting := system_setting.GetFetchSetting()
- if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
- return nil, fmt.Errorf("request reject: %v", err)
- }
-
- workerUrl := system_setting.WorkerUrl
- if !strings.HasSuffix(workerUrl, "/") {
- workerUrl += "/"
- }
-
- // 序列化worker请求数据
- workerPayload, err := json.Marshal(req)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
- }
-
- return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
-}
-
-func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
- if system_setting.EnableWorker() {
- common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
- req := &WorkerRequest{
- URL: originUrl,
- Key: system_setting.WorkerValidKey,
- }
- return DoWorkerRequest(req)
- } else {
- // SSRF防护:验证请求URL(非Worker模式)
- fetchSetting := system_setting.GetFetchSetting()
- if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
- return nil, fmt.Errorf("request reject: %v", err)
- }
-
- common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
- return http.Get(originUrl)
- }
-}
diff --git a/new-api/service/epay.go b/new-api/service/epay.go
deleted file mode 100644
index a25026ac0eafac2aa2110aba1e0861561fdb4a6b..0000000000000000000000000000000000000000
--- a/new-api/service/epay.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package service
-
-import (
- "one-api/setting/operation_setting"
- "one-api/setting/system_setting"
-)
-
-func GetCallbackAddress() string {
- if operation_setting.CustomCallbackAddress == "" {
- return system_setting.ServerAddress
- }
- return operation_setting.CustomCallbackAddress
-}
diff --git a/new-api/service/error.go b/new-api/service/error.go
deleted file mode 100644
index f4b82d39e8e6125979f1d01d262f0838dc33c966..0000000000000000000000000000000000000000
--- a/new-api/service/error.go
+++ /dev/null
@@ -1,155 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/logger"
- "one-api/types"
- "strconv"
- "strings"
-)
-
-func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
- return &dto.MidjourneyResponse{
- Code: code,
- Description: desc,
- }
-}
-
-func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
- return &dto.MidjourneyResponseWithStatusCode{
- StatusCode: statusCode,
- Response: *MidjourneyErrorWrapper(code, desc),
- }
-}
-
-//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
-//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
-// text := err.Error()
-// lowerText := strings.ToLower(text)
-// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
-// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
-// common.SysLog(fmt.Sprintf("error: %s", text))
-// text = "请求上游地址失败"
-// }
-// }
-// openAIError := dto.OpenAIError{
-// Message: text,
-// Type: "new_api_error",
-// Code: code,
-// }
-// return &dto.OpenAIErrorWithStatusCode{
-// Error: openAIError,
-// StatusCode: statusCode,
-// }
-//}
-//
-//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
-// openaiErr := OpenAIErrorWrapper(err, code, statusCode)
-// openaiErr.LocalError = true
-// return openaiErr
-//}
-
-func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
- text := err.Error()
- lowerText := strings.ToLower(text)
- if !strings.HasPrefix(lowerText, "get file base64 from url") {
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
- }
- }
- claudeError := types.ClaudeError{
- Message: text,
- Type: "new_api_error",
- }
- return &dto.ClaudeErrorWithStatusCode{
- Error: claudeError,
- StatusCode: statusCode,
- }
-}
-
-func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
- claudeErr := ClaudeErrorWrapper(err, code, statusCode)
- claudeErr.LocalError = true
- return claudeErr
-}
-
-func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
- newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
-
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return
- }
- CloseResponseBodyGracefully(resp)
- var errResponse dto.GeneralErrorResponse
-
- err = common.Unmarshal(responseBody, &errResponse)
- if err != nil {
- if showBodyWhenFail {
- newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
- } else {
- if common.DebugEnabled {
- logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
- }
- newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
- }
- return
- }
- if errResponse.Error.Message != "" {
- // General format error (OpenAI, Anthropic, Gemini, etc.)
- newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
- } else {
- newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
- }
- return
-}
-
-func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
- if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
- return
- }
- statusCodeMapping := make(map[string]string)
- err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
- if err != nil {
- return
- }
- if newApiErr.StatusCode == http.StatusOK {
- return
- }
- codeStr := strconv.Itoa(newApiErr.StatusCode)
- if _, ok := statusCodeMapping[codeStr]; ok {
- intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
- newApiErr.StatusCode = intCode
- }
-}
-
-func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
- openaiErr := TaskErrorWrapper(err, code, statusCode)
- openaiErr.LocalError = true
- return openaiErr
-}
-
-func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
- text := err.Error()
- lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
- }
- //避免暴露内部错误
- taskError := &dto.TaskError{
- Code: code,
- Message: text,
- StatusCode: statusCode,
- Error: err,
- }
-
- return taskError
-}
diff --git a/new-api/service/file_decoder.go b/new-api/service/file_decoder.go
deleted file mode 100644
index 9be2f0970516b5cb4877c9a907dde1af0cfd97fa..0000000000000000000000000000000000000000
--- a/new-api/service/file_decoder.go
+++ /dev/null
@@ -1,265 +0,0 @@
-package service
-
-import (
- "bytes"
- "encoding/base64"
- "fmt"
- "image"
- _ "image/gif"
- _ "image/jpeg"
- _ "image/png"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/logger"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-// GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
-// 如果获取失败,返回 application/octet-stream
-func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
- response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
- if err != nil {
- common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
- return "", err
- }
- defer response.Body.Close()
-
- if response.StatusCode != 200 {
- logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
- return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
- }
-
- if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
- if i := strings.Index(headerType, ";"); i != -1 {
- headerType = headerType[:i]
- }
- if headerType != "application/octet-stream" {
- return headerType, nil
- }
- }
-
- if cd := response.Header.Get("Content-Disposition"); cd != "" {
- parts := strings.Split(cd, ";")
- for _, part := range parts {
- part = strings.TrimSpace(part)
- if strings.HasPrefix(strings.ToLower(part), "filename=") {
- name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
- if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
- name = name[1 : len(name)-1]
- }
- if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
- ext := strings.ToLower(name[dot+1:])
- if ext != "" {
- mt := GetMimeTypeByExtension(ext)
- if mt != "application/octet-stream" {
- return mt, nil
- }
- }
- }
- break
- }
- }
- }
-
- cleanedURL := url
- if q := strings.Index(cleanedURL, "?"); q != -1 {
- cleanedURL = cleanedURL[:q]
- }
- if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
- last := cleanedURL[slash+1:]
- if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
- ext := strings.ToLower(last[dot+1:])
- if ext != "" {
- mt := GetMimeTypeByExtension(ext)
- if mt != "application/octet-stream" {
- return mt, nil
- }
- }
- }
- }
-
- var readData []byte
- limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
- for _, limit := range limits {
- logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
- if len(readData) < limit {
- need := limit - len(readData)
- tmp := make([]byte, need)
- n, _ := io.ReadFull(response.Body, tmp)
- if n > 0 {
- readData = append(readData, tmp[:n]...)
- }
- }
-
- if len(readData) == 0 {
- continue
- }
-
- sniffed := http.DetectContentType(readData)
- if sniffed != "" && sniffed != "application/octet-stream" {
- return sniffed, nil
- }
-
- if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
- switch strings.ToLower(format) {
- case "jpeg", "jpg":
- return "image/jpeg", nil
- case "png":
- return "image/png", nil
- case "gif":
- return "image/gif", nil
- case "bmp":
- return "image/bmp", nil
- case "tiff":
- return "image/tiff", nil
- default:
- if format != "" {
- return "image/" + strings.ToLower(format), nil
- }
- }
- }
- }
-
- // Fallback
- return "application/octet-stream", nil
-}
-
-func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
- contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
-
- // Check if the file has already been downloaded in this request
- if cachedData, exists := c.Get(contextKey); exists {
- if common.DebugEnabled {
- logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
- }
- return cachedData.(*types.LocalFileData), nil
- }
-
- var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
-
- resp, err := DoDownloadRequest(url, reason...)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
-
- // Always use LimitReader to prevent oversized downloads
- fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
- if err != nil {
- return nil, err
- }
- // Check actual size after reading
- if len(fileBytes) > maxFileSize {
- return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
- }
-
- // Convert to base64
- base64Data := base64.StdEncoding.EncodeToString(fileBytes)
-
- mimeType := resp.Header.Get("Content-Type")
- if len(strings.Split(mimeType, ";")) > 1 {
- // If Content-Type has parameters, take the first part
- mimeType = strings.Split(mimeType, ";")[0]
- }
- if mimeType == "application/octet-stream" {
- logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
- // try to guess the MIME type from the url last segment
- urlParts := strings.Split(url, "/")
- if len(urlParts) > 0 {
- lastSegment := urlParts[len(urlParts)-1]
- if strings.Contains(lastSegment, ".") {
- // Extract the file extension
- filename := strings.Split(lastSegment, ".")
- if len(filename) > 1 {
- ext := strings.ToLower(filename[len(filename)-1])
- // Guess MIME type based on file extension
- mimeType = GetMimeTypeByExtension(ext)
- }
- }
- } else {
- // try to guess the MIME type from the file extension
- fileName := resp.Header.Get("Content-Disposition")
- if fileName != "" {
- // Extract the filename from the Content-Disposition header
- parts := strings.Split(fileName, ";")
- for _, part := range parts {
- if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
- fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
- // Remove quotes if present
- if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
- fileName = fileName[1 : len(fileName)-1]
- }
- // Guess MIME type based on file extension
- if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
- mimeType = GetMimeTypeByExtension(ext)
- }
- break
- }
- }
- }
- }
- }
- data := &types.LocalFileData{
- Base64Data: base64Data,
- MimeType: mimeType,
- Size: int64(len(fileBytes)),
- }
- // Store the file data in the context to avoid re-downloading
- c.Set(contextKey, data)
-
- return data, nil
-}
-
-func GetMimeTypeByExtension(ext string) string {
- // Convert to lowercase for case-insensitive comparison
- ext = strings.ToLower(ext)
- switch ext {
- // Text files
- case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
- return "text/plain"
-
- // Image files
- case "jpg", "jpeg":
- return "image/jpeg"
- case "png":
- return "image/png"
- case "gif":
- return "image/gif"
-
- // Audio files
- case "mp3":
- return "audio/mp3"
- case "wav":
- return "audio/wav"
- case "mpeg":
- return "audio/mpeg"
-
- // Video files
- case "mp4":
- return "video/mp4"
- case "wmv":
- return "video/wmv"
- case "flv":
- return "video/flv"
- case "mov":
- return "video/mov"
- case "mpg":
- return "video/mpg"
- case "avi":
- return "video/avi"
- case "mpegps":
- return "video/mpegps"
-
- // Document files
- case "pdf":
- return "application/pdf"
-
- default:
- return "application/octet-stream" // Default for unknown types
- }
-}
diff --git a/new-api/service/http.go b/new-api/service/http.go
deleted file mode 100644
index f1ec0165feab69fe7b44f8cac9cc4d05e12c758d..0000000000000000000000000000000000000000
--- a/new-api/service/http.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package service
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/logger"
-
- "github.com/gin-gonic/gin"
-)
-
-func CloseResponseBodyGracefully(httpResponse *http.Response) {
- if httpResponse == nil || httpResponse.Body == nil {
- return
- }
- err := httpResponse.Body.Close()
- if err != nil {
- common.SysError("failed to close response body: " + err.Error())
- }
-}
-
-func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
- if c.Writer == nil {
- return
- }
-
- body := io.NopCloser(bytes.NewBuffer(data))
-
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- if src != nil {
- for k, v := range src.Header {
- // avoid setting Content-Length
- if k == "Content-Length" {
- continue
- }
- c.Writer.Header().Set(k, v[0])
- }
- }
-
- // set Content-Length header manually BEFORE calling WriteHeader
- c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
-
- // Write header with status code (this sends the headers)
- if src != nil {
- c.Writer.WriteHeader(src.StatusCode)
- } else {
- c.Writer.WriteHeader(http.StatusOK)
- }
-
- _, err := io.Copy(c.Writer, body)
- if err != nil {
- logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
- }
-}
diff --git a/new-api/service/http_client.go b/new-api/service/http_client.go
deleted file mode 100644
index dd597b9240e10453d4845fba88420f6720dd2d17..0000000000000000000000000000000000000000
--- a/new-api/service/http_client.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package service
-
-import (
- "context"
- "fmt"
- "net"
- "net/http"
- "net/url"
- "one-api/common"
- "sync"
- "time"
-
- "golang.org/x/net/proxy"
-)
-
-var (
- httpClient *http.Client
- proxyClientLock sync.Mutex
- proxyClients = make(map[string]*http.Client)
-)
-
-func InitHttpClient() {
- if common.RelayTimeout == 0 {
- httpClient = &http.Client{}
- } else {
- httpClient = &http.Client{
- Timeout: time.Duration(common.RelayTimeout) * time.Second,
- }
- }
-}
-
-func GetHttpClient() *http.Client {
- return httpClient
-}
-
-// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
-func ResetProxyClientCache() {
- proxyClientLock.Lock()
- defer proxyClientLock.Unlock()
- for _, client := range proxyClients {
- if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
- transport.CloseIdleConnections()
- }
- }
- proxyClients = make(map[string]*http.Client)
-}
-
-// NewProxyHttpClient 创建支持代理的 HTTP 客户端
-func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
- if proxyURL == "" {
- return http.DefaultClient, nil
- }
-
- proxyClientLock.Lock()
- if client, ok := proxyClients[proxyURL]; ok {
- proxyClientLock.Unlock()
- return client, nil
- }
- proxyClientLock.Unlock()
-
- parsedURL, err := url.Parse(proxyURL)
- if err != nil {
- return nil, err
- }
-
- switch parsedURL.Scheme {
- case "http", "https":
- client := &http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyURL(parsedURL),
- },
- }
- client.Timeout = time.Duration(common.RelayTimeout) * time.Second
- proxyClientLock.Lock()
- proxyClients[proxyURL] = client
- proxyClientLock.Unlock()
- return client, nil
-
- case "socks5", "socks5h":
- // 获取认证信息
- var auth *proxy.Auth
- if parsedURL.User != nil {
- auth = &proxy.Auth{
- User: parsedURL.User.Username(),
- Password: "",
- }
- if password, ok := parsedURL.User.Password(); ok {
- auth.Password = password
- }
- }
-
- // 创建 SOCKS5 代理拨号器
- // proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同
- dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
- if err != nil {
- return nil, err
- }
-
- client := &http.Client{
- Transport: &http.Transport{
- DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
- return dialer.Dial(network, addr)
- },
- },
- }
- client.Timeout = time.Duration(common.RelayTimeout) * time.Second
- proxyClientLock.Lock()
- proxyClients[proxyURL] = client
- proxyClientLock.Unlock()
- return client, nil
-
- default:
- return nil, fmt.Errorf("unsupported proxy scheme: %s, must be http, https, socks5 or socks5h", parsedURL.Scheme)
- }
-}
diff --git a/new-api/service/image.go b/new-api/service/image.go
deleted file mode 100644
index 13b4eb93f9af9dc22d1f31b083a32e1d288cd7d3..0000000000000000000000000000000000000000
--- a/new-api/service/image.go
+++ /dev/null
@@ -1,176 +0,0 @@
-package service
-
-import (
- "bytes"
- "encoding/base64"
- "errors"
- "fmt"
- "image"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "strings"
-
- "golang.org/x/image/webp"
-)
-
-func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
- // 去除base64数据的URL前缀(如果有)
- if idx := strings.Index(base64String, ","); idx != -1 {
- base64String = base64String[idx+1:]
- }
-
- if len(base64String) == 0 {
- return image.Config{}, "", "", errors.New("base64 string is empty")
- }
-
- // 将base64字符串解码为字节切片
- decodedData, err := base64.StdEncoding.DecodeString(base64String)
- if err != nil {
- fmt.Println("Error: Failed to decode base64 string")
- return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
- }
-
- // 创建一个bytes.Buffer用于存储解码后的数据
- reader := bytes.NewReader(decodedData)
- config, format, err := getImageConfig(reader)
- return config, format, base64String, err
-}
-
-func DecodeBase64FileData(base64String string) (string, string, error) {
- var mimeType string
- var idx int
- idx = strings.Index(base64String, ",")
- if idx == -1 {
- _, file_type, base64, err := DecodeBase64ImageData(base64String)
- return "image/" + file_type, base64, err
- }
- mimeType = base64String[:idx]
- base64String = base64String[idx+1:]
- idx = strings.Index(mimeType, ";")
- if idx == -1 {
- _, file_type, base64, err := DecodeBase64ImageData(base64String)
- return "image/" + file_type, base64, err
- }
- mimeType = mimeType[:idx]
- idx = strings.Index(mimeType, ":")
- if idx == -1 {
- _, file_type, base64, err := DecodeBase64ImageData(base64String)
- return "image/" + file_type, base64, err
- }
- mimeType = mimeType[idx+1:]
- return mimeType, base64String, nil
-}
-
-// GetImageFromUrl 获取图片的类型和base64编码的数据
-func GetImageFromUrl(url string) (mimeType string, data string, err error) {
- resp, err := DoDownloadRequest(url)
- if err != nil {
- return "", "", fmt.Errorf("failed to download image: %w", err)
- }
- defer resp.Body.Close()
-
- // Check HTTP status code
- if resp.StatusCode != http.StatusOK {
- return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
- }
-
- contentType := resp.Header.Get("Content-Type")
- if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
- return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
- }
- maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
-
- // Check Content-Length if available
- if resp.ContentLength > maxImageSize {
- return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
- }
-
- // Use LimitReader to prevent reading oversized images
- limitReader := io.LimitReader(resp.Body, maxImageSize)
- buffer := &bytes.Buffer{}
-
- written, err := io.Copy(buffer, limitReader)
- if err != nil {
- return "", "", fmt.Errorf("failed to read image data: %w", err)
- }
- if written >= maxImageSize {
- return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
- }
-
- data = base64.StdEncoding.EncodeToString(buffer.Bytes())
- mimeType = contentType
-
- // Handle application/octet-stream type
- if mimeType == "application/octet-stream" {
- _, format, _, err := DecodeBase64ImageData(data)
- if err != nil {
- return "", "", err
- }
- mimeType = "image/" + format
- }
-
- return mimeType, data, nil
-}
-
-func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
- response, err := DoDownloadRequest(imageUrl)
- if err != nil {
- common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
- return image.Config{}, "", err
- }
- defer response.Body.Close()
-
- if response.StatusCode != 200 {
- err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
- return image.Config{}, "", err
- }
-
- mimeType := response.Header.Get("Content-Type")
-
- if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
- return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
- }
-
- var readData []byte
- for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
- common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
-
- // 从response.Body读取更多的数据直到达到当前的限制
- additionalData := make([]byte, limit-int64(len(readData)))
- n, _ := io.ReadFull(response.Body, additionalData)
- readData = append(readData, additionalData[:n]...)
-
- // 使用io.MultiReader组合已经读取的数据和response.Body
- limitReader := io.MultiReader(bytes.NewReader(readData), response.Body)
-
- var config image.Config
- var format string
- config, format, err = getImageConfig(limitReader)
- if err == nil {
- return config, format, nil
- }
- }
-
- return image.Config{}, "", err // 返回最后一个错误
-}
-
-func getImageConfig(reader io.Reader) (image.Config, string, error) {
- // 读取图片的头部信息来获取图片尺寸
- config, format, err := image.DecodeConfig(reader)
- if err != nil {
- err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
- common.SysLog(err.Error())
- config, err = webp.DecodeConfig(reader)
- if err != nil {
- err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
- common.SysLog(err.Error())
- }
- format = "webp"
- }
- if err != nil {
- return image.Config{}, "", err
- }
- return config, format, nil
-}
diff --git a/new-api/service/log_info_generate.go b/new-api/service/log_info_generate.go
deleted file mode 100644
index 56dcd8bb976b7c7685f7db1a0fafe7c68c936ebd..0000000000000000000000000000000000000000
--- a/new-api/service/log_info_generate.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package service
-
-import (
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
- other := make(map[string]interface{})
- other["model_ratio"] = modelRatio
- other["group_ratio"] = groupRatio
- other["completion_ratio"] = completionRatio
- other["cache_tokens"] = cacheTokens
- other["cache_ratio"] = cacheRatio
- other["model_price"] = modelPrice
- other["user_group_ratio"] = userGroupRatio
- other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
- if relayInfo.ReasoningEffort != "" {
- other["reasoning_effort"] = relayInfo.ReasoningEffort
- }
- if relayInfo.IsModelMapped {
- other["is_model_mapped"] = true
- other["upstream_model_name"] = relayInfo.UpstreamModelName
- }
-
- isSystemPromptOverwritten := common.GetContextKeyBool(ctx, constant.ContextKeySystemPromptOverride)
- if isSystemPromptOverwritten {
- other["is_system_prompt_overwritten"] = true
- }
-
- adminInfo := make(map[string]interface{})
- adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
- isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey)
- if isMultiKey {
- adminInfo["is_multi_key"] = true
- adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
- }
- other["admin_info"] = adminInfo
- return other
-}
-
-func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
- info["ws"] = true
- info["audio_input"] = usage.InputTokenDetails.AudioTokens
- info["audio_output"] = usage.OutputTokenDetails.AudioTokens
- info["text_input"] = usage.InputTokenDetails.TextTokens
- info["text_output"] = usage.OutputTokenDetails.TextTokens
- info["audio_ratio"] = audioRatio
- info["audio_completion_ratio"] = audioCompletionRatio
- return info
-}
-
-func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
- info["audio"] = true
- info["audio_input"] = usage.PromptTokensDetails.AudioTokens
- info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
- info["text_input"] = usage.PromptTokensDetails.TextTokens
- info["text_output"] = usage.CompletionTokenDetails.TextTokens
- info["audio_ratio"] = audioRatio
- info["audio_completion_ratio"] = audioCompletionRatio
- return info
-}
-
-func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
- info["claude"] = true
- info["cache_creation_tokens"] = cacheCreationTokens
- info["cache_creation_ratio"] = cacheCreationRatio
- return info
-}
-
-func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
- other := make(map[string]interface{})
- other["model_price"] = priceData.ModelPrice
- other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
- if priceData.GroupRatioInfo.HasSpecialRatio {
- other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
- }
- return other
-}
diff --git a/new-api/service/midjourney.go b/new-api/service/midjourney.go
deleted file mode 100644
index 0e73f339671148c091b79910578701768af1be72..0000000000000000000000000000000000000000
--- a/new-api/service/midjourney.go
+++ /dev/null
@@ -1,258 +0,0 @@
-package service
-
-import (
- "context"
- "encoding/json"
- "io"
- "log"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relayconstant "one-api/relay/constant"
- "one-api/setting"
- "strconv"
- "strings"
- "time"
-
- "github.com/gin-gonic/gin"
-)
-
-func CoverActionToModelName(mjAction string) string {
- modelName := "mj_" + strings.ToLower(mjAction)
- if mjAction == constant.MjActionSwapFace {
- modelName = "swap_face"
- }
- return modelName
-}
-
-func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
- action := ""
- if relayMode == relayconstant.RelayModeMidjourneyAction {
- // plus request
- err := CoverPlusActionToNormalAction(midjRequest)
- if err != nil {
- return "", err, false
- }
- action = midjRequest.Action
- } else {
- switch relayMode {
- case relayconstant.RelayModeMidjourneyImagine:
- action = constant.MjActionImagine
- case relayconstant.RelayModeMidjourneyVideo:
- action = constant.MjActionVideo
- case relayconstant.RelayModeMidjourneyEdits:
- action = constant.MjActionEdits
- case relayconstant.RelayModeMidjourneyDescribe:
- action = constant.MjActionDescribe
- case relayconstant.RelayModeMidjourneyBlend:
- action = constant.MjActionBlend
- case relayconstant.RelayModeMidjourneyShorten:
- action = constant.MjActionShorten
- case relayconstant.RelayModeMidjourneyChange:
- action = midjRequest.Action
- case relayconstant.RelayModeMidjourneyModal:
- action = constant.MjActionModal
- case relayconstant.RelayModeSwapFace:
- action = constant.MjActionSwapFace
- case relayconstant.RelayModeMidjourneyUpload:
- action = constant.MjActionUpload
- case relayconstant.RelayModeMidjourneySimpleChange:
- params := ConvertSimpleChangeParams(midjRequest.Content)
- if params == nil {
- return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
- }
- action = params.Action
- case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
- return "", nil, true
- default:
- return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
- }
- }
- modelName := CoverActionToModelName(action)
- return modelName, nil, true
-}
-
-func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
- // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
- customId := midjRequest.CustomId
- if customId == "" {
- return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
- }
- splits := strings.Split(customId, "::")
- var action string
- if splits[1] == "JOB" {
- action = splits[2]
- } else {
- action = splits[1]
- }
-
- if action == "" {
- return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
- }
- if strings.Contains(action, "upsample") {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionUpscale
- } else if strings.Contains(action, "variation") {
- midjRequest.Index = 1
- if action == "variation" {
- index, err := strconv.Atoi(splits[3])
- if err != nil {
- return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
- }
- midjRequest.Index = index
- midjRequest.Action = constant.MjActionVariation
- } else if action == "low_variation" {
- midjRequest.Action = constant.MjActionLowVariation
- } else if action == "high_variation" {
- midjRequest.Action = constant.MjActionHighVariation
- }
- } else if strings.Contains(action, "pan") {
- midjRequest.Action = constant.MjActionPan
- midjRequest.Index = 1
- } else if strings.Contains(action, "reroll") {
- midjRequest.Action = constant.MjActionReRoll
- midjRequest.Index = 1
- } else if action == "Outpaint" {
- midjRequest.Action = constant.MjActionZoom
- midjRequest.Index = 1
- } else if action == "CustomZoom" {
- midjRequest.Action = constant.MjActionCustomZoom
- midjRequest.Index = 1
- } else if action == "Inpaint" {
- midjRequest.Action = constant.MjActionInPaint
- midjRequest.Index = 1
- } else {
- return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
- }
- return nil
-}
-
-func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
- split := strings.Split(content, " ")
- if len(split) != 2 {
- return nil
- }
-
- action := strings.ToLower(split[1])
- changeParams := &dto.MidjourneyRequest{}
- changeParams.TaskId = split[0]
-
- if action[0] == 'u' {
- changeParams.Action = "UPSCALE"
- } else if action[0] == 'v' {
- changeParams.Action = "VARIATION"
- } else if action == "r" {
- changeParams.Action = "REROLL"
- return changeParams
- } else {
- return nil
- }
-
- index, err := strconv.Atoi(action[1:2])
- if err != nil || index < 1 || index > 4 {
- return nil
- }
- changeParams.Index = index
- return changeParams
-}
-
-func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
- var nullBytes []byte
- //var requestBody io.Reader
- //requestBody = c.Request.Body
- // read request body to json, delete accountFilter and notifyHook
- var mapResult map[string]interface{}
- // if get request, no need to read request body
- if c.Request.Method != "GET" {
- err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
- }
- if !setting.MjAccountFilterEnabled {
- delete(mapResult, "accountFilter")
- }
- if !setting.MjNotifyEnabled {
- delete(mapResult, "notifyHook")
- }
- //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- // make new request with mapResult
- }
- if setting.MjModeClearEnabled {
- if prompt, ok := mapResult["prompt"].(string); ok {
- prompt = strings.Replace(prompt, "--fast", "", -1)
- prompt = strings.Replace(prompt, "--relax", "", -1)
- prompt = strings.Replace(prompt, "--turbo", "", -1)
-
- mapResult["prompt"] = prompt
- }
- }
- reqBody, err := json.Marshal(mapResult)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
- }
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
- }
- ctx, cancel := context.WithTimeout(context.Background(), timeout)
- // 使用带有超时的 context 创建新的请求
- req = req.WithContext(ctx)
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
- if auth != "" {
- auth = strings.TrimPrefix(auth, "Bearer ")
- req.Header.Set("mj-api-secret", auth)
- }
- defer cancel()
- resp, err := GetHttpClient().Do(req)
- if err != nil {
- common.SysLog("do request failed: " + err.Error())
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
- }
- statusCode := resp.StatusCode
- //if statusCode != 200 {
- // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
- //}
- err = req.Body.Close()
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
- }
- err = c.Request.Body.Close()
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
- }
- var midjResponse dto.MidjourneyResponse
- var midjourneyUploadsResponse dto.MidjourneyUploadResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
- }
- CloseResponseBodyGracefully(resp)
- respStr := string(responseBody)
- log.Printf("respStr: %s", respStr)
- if respStr == "" {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
- } else {
- err = json.Unmarshal(responseBody, &midjResponse)
- if err != nil {
- err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
- if err2 != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
- }
- }
- }
- //log.Printf("midjResponse: %v", midjResponse)
- //for k, v := range resp.Header {
- // c.Writer.Header().Set(k, v[0])
- //}
- return &dto.MidjourneyResponseWithStatusCode{
- StatusCode: statusCode,
- Response: midjResponse,
- }, responseBody, nil
-}
diff --git a/new-api/service/notify-limit.go b/new-api/service/notify-limit.go
deleted file mode 100644
index 46129a8dfb91a2a634da4042248c3904b38522cc..0000000000000000000000000000000000000000
--- a/new-api/service/notify-limit.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package service
-
-import (
- "fmt"
- "github.com/bytedance/gopkg/util/gopool"
- "one-api/common"
- "one-api/constant"
- "strconv"
- "sync"
- "time"
-)
-
-// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
-var (
- notifyLimitStore sync.Map
- cleanupOnce sync.Once
-)
-
-type limitCount struct {
- Count int
- Timestamp time.Time
-}
-
-func getDuration() time.Duration {
- minute := constant.NotificationLimitDurationMinute
- return time.Duration(minute) * time.Minute
-}
-
-// startCleanupTask starts a background task to clean up expired entries
-func startCleanupTask() {
- gopool.Go(func() {
- for {
- time.Sleep(time.Hour)
- now := time.Now()
- notifyLimitStore.Range(func(key, value interface{}) bool {
- if limit, ok := value.(limitCount); ok {
- if now.Sub(limit.Timestamp) >= getDuration() {
- notifyLimitStore.Delete(key)
- }
- }
- return true
- })
- }
- })
-}
-
-// CheckNotificationLimit checks if the user has exceeded their notification limit
-// Returns true if the user can send notification, false if limit exceeded
-func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
- if common.RedisEnabled {
- return checkRedisLimit(userId, notifyType)
- }
- return checkMemoryLimit(userId, notifyType)
-}
-
-func checkRedisLimit(userId int, notifyType string) (bool, error) {
- key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
-
- // Get current count
- count, err := common.RedisGet(key)
- if err != nil && err.Error() != "redis: nil" {
- return false, fmt.Errorf("failed to get notification count: %w", err)
- }
-
- // If key doesn't exist, initialize it
- if count == "" {
- err = common.RedisSet(key, "1", getDuration())
- return true, err
- }
-
- currentCount, _ := strconv.Atoi(count)
- limit := constant.NotifyLimitCount
-
- // Check if limit is already reached
- if currentCount >= limit {
- return false, nil
- }
-
- // Only increment if under limit
- err = common.RedisIncr(key, 1)
- if err != nil {
- return false, fmt.Errorf("failed to increment notification count: %w", err)
- }
-
- return true, nil
-}
-
-func checkMemoryLimit(userId int, notifyType string) (bool, error) {
- // Ensure cleanup task is started
- cleanupOnce.Do(startCleanupTask)
-
- key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
- now := time.Now()
-
- // Get current limit count or initialize new one
- var currentLimit limitCount
- if value, ok := notifyLimitStore.Load(key); ok {
- currentLimit = value.(limitCount)
- // Check if the entry has expired
- if now.Sub(currentLimit.Timestamp) >= getDuration() {
- currentLimit = limitCount{Count: 0, Timestamp: now}
- }
- } else {
- currentLimit = limitCount{Count: 0, Timestamp: now}
- }
-
- // Increment count
- currentLimit.Count++
-
- // Check against limits
- limit := constant.NotifyLimitCount
-
- // Store updated count
- notifyLimitStore.Store(key, currentLimit)
-
- return currentLimit.Count <= limit, nil
-}
diff --git a/new-api/service/passkey/service.go b/new-api/service/passkey/service.go
deleted file mode 100644
index a130e73a4fc556666f6916ecbddb86255cd7ca44..0000000000000000000000000000000000000000
--- a/new-api/service/passkey/service.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package passkey
-
-import (
- "errors"
- "fmt"
- "net"
- "net/http"
- "net/url"
- "strings"
- "time"
-
- "one-api/common"
- "one-api/setting/system_setting"
-
- "github.com/go-webauthn/webauthn/protocol"
- webauthn "github.com/go-webauthn/webauthn/webauthn"
-)
-
-const (
- RegistrationSessionKey = "passkey_registration_session"
- LoginSessionKey = "passkey_login_session"
- VerifySessionKey = "passkey_verify_session"
-)
-
-// BuildWebAuthn constructs a WebAuthn instance using the current passkey settings and request context.
-func BuildWebAuthn(r *http.Request) (*webauthn.WebAuthn, error) {
- settings := system_setting.GetPasskeySettings()
- if settings == nil {
- return nil, errors.New("未找到 Passkey 设置")
- }
-
- displayName := strings.TrimSpace(settings.RPDisplayName)
- if displayName == "" {
- displayName = common.SystemName
- }
-
- origins, err := resolveOrigins(r, settings)
- if err != nil {
- return nil, err
- }
-
- rpID, err := resolveRPID(r, settings, origins)
- if err != nil {
- return nil, err
- }
-
- selection := protocol.AuthenticatorSelection{
- ResidentKey: protocol.ResidentKeyRequirementRequired,
- RequireResidentKey: protocol.ResidentKeyRequired(),
- UserVerification: protocol.UserVerificationRequirement(settings.UserVerification),
- }
- if selection.UserVerification == "" {
- selection.UserVerification = protocol.VerificationPreferred
- }
- if attachment := strings.TrimSpace(settings.AttachmentPreference); attachment != "" {
- selection.AuthenticatorAttachment = protocol.AuthenticatorAttachment(attachment)
- }
-
- config := &webauthn.Config{
- RPID: rpID,
- RPDisplayName: displayName,
- RPOrigins: origins,
- AuthenticatorSelection: selection,
- Debug: common.DebugEnabled,
- Timeouts: webauthn.TimeoutsConfig{
- Login: webauthn.TimeoutConfig{
- Enforce: true,
- Timeout: 2 * time.Minute,
- TimeoutUVD: 2 * time.Minute,
- },
- Registration: webauthn.TimeoutConfig{
- Enforce: true,
- Timeout: 2 * time.Minute,
- TimeoutUVD: 2 * time.Minute,
- },
- },
- }
-
- return webauthn.New(config)
-}
-
-func resolveOrigins(r *http.Request, settings *system_setting.PasskeySettings) ([]string, error) {
- originsStr := strings.TrimSpace(settings.Origins)
- if originsStr != "" {
- originList := strings.Split(originsStr, ",")
- origins := make([]string, 0, len(originList))
- for _, origin := range originList {
- trimmed := strings.TrimSpace(origin)
- if trimmed == "" {
- continue
- }
- if !settings.AllowInsecureOrigin && strings.HasPrefix(strings.ToLower(trimmed), "http://") {
- return nil, fmt.Errorf("Passkey 不允许使用不安全的 Origin: %s", trimmed)
- }
- origins = append(origins, trimmed)
- }
- if len(origins) == 0 {
- // 如果配置了Origins但过滤后为空,使用自动推导
- goto autoDetect
- }
- return origins, nil
- }
-
-autoDetect:
- scheme := detectScheme(r)
- if scheme == "http" && !settings.AllowInsecureOrigin && r.Host != "localhost" && r.Host != "127.0.0.1" && !strings.HasPrefix(r.Host, "127.0.0.1:") && !strings.HasPrefix(r.Host, "localhost:") {
- return nil, fmt.Errorf("Passkey 仅支持 HTTPS,当前访问: %s://%s,请在 Passkey 设置中允许不安全 Origin 或配置 HTTPS", scheme, r.Host)
- }
- // 优先使用请求的完整Host(包含端口)
- host := r.Host
-
- // 如果无法从请求获取Host,尝试从ServerAddress获取
- if host == "" && system_setting.ServerAddress != "" {
- if parsed, err := url.Parse(system_setting.ServerAddress); err == nil && parsed.Host != "" {
- host = parsed.Host
- if scheme == "" && parsed.Scheme != "" {
- scheme = parsed.Scheme
- }
- }
- }
- if host == "" {
- return nil, fmt.Errorf("无法确定 Passkey 的 Origin,请在系统设置或 Passkey 设置中指定。当前 Host: '%s', ServerAddress: '%s'", r.Host, system_setting.ServerAddress)
- }
- if scheme == "" {
- scheme = "https"
- }
- origin := fmt.Sprintf("%s://%s", scheme, host)
- return []string{origin}, nil
-}
-
-func resolveRPID(r *http.Request, settings *system_setting.PasskeySettings, origins []string) (string, error) {
- rpID := strings.TrimSpace(settings.RPID)
- if rpID != "" {
- return hostWithoutPort(rpID), nil
- }
- if len(origins) == 0 {
- return "", errors.New("Passkey 未配置 Origin,无法推导 RPID")
- }
- parsed, err := url.Parse(origins[0])
- if err != nil {
- return "", fmt.Errorf("无法解析 Passkey Origin: %w", err)
- }
- return hostWithoutPort(parsed.Host), nil
-}
-
-func hostWithoutPort(host string) string {
- host = strings.TrimSpace(host)
- if host == "" {
- return ""
- }
- if strings.Contains(host, ":") {
- if host, _, err := net.SplitHostPort(host); err == nil {
- return host
- }
- }
- return host
-}
-
-func detectScheme(r *http.Request) string {
- if r == nil {
- return ""
- }
- if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
- parts := strings.Split(proto, ",")
- return strings.ToLower(strings.TrimSpace(parts[0]))
- }
- if r.TLS != nil {
- return "https"
- }
- if r.URL != nil && r.URL.Scheme != "" {
- return strings.ToLower(r.URL.Scheme)
- }
- if r.Header.Get("X-Forwarded-Protocol") != "" {
- return strings.ToLower(strings.TrimSpace(r.Header.Get("X-Forwarded-Protocol")))
- }
- return "http"
-}
diff --git a/new-api/service/passkey/session.go b/new-api/service/passkey/session.go
deleted file mode 100644
index c8f96fd3dc17fafdee0d07d9dc40fca500ef3290..0000000000000000000000000000000000000000
--- a/new-api/service/passkey/session.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package passkey
-
-import (
- "encoding/json"
- "errors"
-
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- webauthn "github.com/go-webauthn/webauthn/webauthn"
-)
-
-var errSessionNotFound = errors.New("Passkey 会话不存在或已过期")
-
-func SaveSessionData(c *gin.Context, key string, data *webauthn.SessionData) error {
- session := sessions.Default(c)
- if data == nil {
- session.Delete(key)
- return session.Save()
- }
- payload, err := json.Marshal(data)
- if err != nil {
- return err
- }
- session.Set(key, string(payload))
- return session.Save()
-}
-
-func PopSessionData(c *gin.Context, key string) (*webauthn.SessionData, error) {
- session := sessions.Default(c)
- raw := session.Get(key)
- if raw == nil {
- return nil, errSessionNotFound
- }
- session.Delete(key)
- _ = session.Save()
- var data webauthn.SessionData
- switch value := raw.(type) {
- case string:
- if err := json.Unmarshal([]byte(value), &data); err != nil {
- return nil, err
- }
- case []byte:
- if err := json.Unmarshal(value, &data); err != nil {
- return nil, err
- }
- default:
- return nil, errors.New("Passkey 会话格式无效")
- }
- return &data, nil
-}
diff --git a/new-api/service/passkey/user.go b/new-api/service/passkey/user.go
deleted file mode 100644
index 64484fecf3790d1c61e4156bdb80a8cee11ca9ec..0000000000000000000000000000000000000000
--- a/new-api/service/passkey/user.go
+++ /dev/null
@@ -1,71 +0,0 @@
-package passkey
-
-import (
- "fmt"
- "strconv"
- "strings"
-
- "one-api/model"
-
- webauthn "github.com/go-webauthn/webauthn/webauthn"
-)
-
-type WebAuthnUser struct {
- user *model.User
- credential *model.PasskeyCredential
-}
-
-func NewWebAuthnUser(user *model.User, credential *model.PasskeyCredential) *WebAuthnUser {
- return &WebAuthnUser{user: user, credential: credential}
-}
-
-func (u *WebAuthnUser) WebAuthnID() []byte {
- if u == nil || u.user == nil {
- return nil
- }
- return []byte(strconv.Itoa(u.user.Id))
-}
-
-func (u *WebAuthnUser) WebAuthnName() string {
- if u == nil || u.user == nil {
- return ""
- }
- name := strings.TrimSpace(u.user.Username)
- if name == "" {
- return fmt.Sprintf("user-%d", u.user.Id)
- }
- return name
-}
-
-func (u *WebAuthnUser) WebAuthnDisplayName() string {
- if u == nil || u.user == nil {
- return ""
- }
- display := strings.TrimSpace(u.user.DisplayName)
- if display != "" {
- return display
- }
- return u.WebAuthnName()
-}
-
-func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
- if u == nil || u.credential == nil {
- return nil
- }
- cred := u.credential.ToWebAuthnCredential()
- return []webauthn.Credential{cred}
-}
-
-func (u *WebAuthnUser) ModelUser() *model.User {
- if u == nil {
- return nil
- }
- return u.user
-}
-
-func (u *WebAuthnUser) PasskeyCredential() *model.PasskeyCredential {
- if u == nil {
- return nil
- }
- return u.credential
-}
diff --git a/new-api/service/pre_consume_quota.go b/new-api/service/pre_consume_quota.go
deleted file mode 100644
index 60f77fd3d2e1ce2d0634ed0473585baa28c8265d..0000000000000000000000000000000000000000
--- a/new-api/service/pre_consume_quota.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package service
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/logger"
- "one-api/model"
- relaycommon "one-api/relay/common"
- "one-api/types"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
-)
-
-func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
- if relayInfo.FinalPreConsumedQuota != 0 {
- logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
- gopool.Go(func() {
- relayInfoCopy := *relayInfo
-
- err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
- if err != nil {
- common.SysLog("error return pre-consumed quota: " + err.Error())
- }
- })
- }
-}
-
-// PreConsumeQuota checks if the user has enough quota to pre-consume.
-// It returns the pre-consumed quota if successful, or an error if not.
-func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
- }
- if userQuota <= 0 {
- return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- if userQuota-preConsumedQuota < 0 {
- return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
-
- trustQuota := common.GetTrustQuota()
-
- relayInfo.UserQuota = userQuota
- if userQuota > trustQuota {
- // 用户额度充足,判断令牌额度是否充足
- if !relayInfo.TokenUnlimited {
- // 非无限令牌,判断令牌额度是否充足
- tokenQuota := c.GetInt("token_quota")
- if tokenQuota > trustQuota {
- // 令牌额度充足,信任令牌
- preConsumedQuota = 0
- logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
- }
- } else {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
- }
- }
-
- if preConsumedQuota > 0 {
- err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
- if err != nil {
- return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
- }
- err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
- if err != nil {
- return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
- }
- logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
- }
- relayInfo.FinalPreConsumedQuota = preConsumedQuota
- return nil
-}
diff --git a/new-api/service/quota.go b/new-api/service/quota.go
deleted file mode 100644
index dbf7061b9284120fa1a05d0fc0117b7fccbc9abe..0000000000000000000000000000000000000000
--- a/new-api/service/quota.go
+++ /dev/null
@@ -1,564 +0,0 @@
-package service
-
-import (
- "errors"
- "fmt"
- "log"
- "math"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/logger"
- "one-api/model"
- relaycommon "one-api/relay/common"
- "one-api/setting/ratio_setting"
- "one-api/setting/system_setting"
- "one-api/types"
- "strings"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
-
- "github.com/gin-gonic/gin"
- "github.com/shopspring/decimal"
-)
-
-type TokenDetails struct {
- TextTokens int
- AudioTokens int
-}
-
-type QuotaInfo struct {
- InputDetails TokenDetails
- OutputDetails TokenDetails
- ModelName string
- UsePrice bool
- ModelPrice float64
- ModelRatio float64
- GroupRatio float64
-}
-
-func hasCustomModelRatio(modelName string, currentRatio float64) bool {
- defaultRatio, exists := ratio_setting.GetDefaultModelRatioMap()[modelName]
- if !exists {
- return true
- }
- return currentRatio != defaultRatio
-}
-
-func calculateAudioQuota(info QuotaInfo) int {
- if info.UsePrice {
- modelPrice := decimal.NewFromFloat(info.ModelPrice)
- quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
- groupRatio := decimal.NewFromFloat(info.GroupRatio)
-
- quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
- return int(quota.IntPart())
- }
-
- completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
- audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
- audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
-
- groupRatio := decimal.NewFromFloat(info.GroupRatio)
- modelRatio := decimal.NewFromFloat(info.ModelRatio)
- ratio := groupRatio.Mul(modelRatio)
-
- inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
- outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
- inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
- outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
-
- quota := decimal.Zero
- quota = quota.Add(inputTextTokens)
- quota = quota.Add(outputTextTokens.Mul(completionRatio))
- quota = quota.Add(inputAudioTokens.Mul(audioRatio))
- quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
-
- quota = quota.Mul(ratio)
-
- // If ratio is not zero and quota is less than or equal to zero, set quota to 1
- if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
- quota = decimal.NewFromInt(1)
- }
-
- return int(quota.Round(0).IntPart())
-}
-
-func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
- if relayInfo.UsePrice {
- return nil
- }
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return err
- }
-
- token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
- if err != nil {
- return err
- }
-
- modelName := relayInfo.OriginModelName
- textInputTokens := usage.InputTokenDetails.TextTokens
- textOutTokens := usage.OutputTokenDetails.TextTokens
- audioInputTokens := usage.InputTokenDetails.AudioTokens
- audioOutTokens := usage.OutputTokenDetails.AudioTokens
- groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
- modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
-
- autoGroup, exists := ctx.Get("auto_group")
- if exists {
- groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
- log.Printf("final group ratio: %f", groupRatio)
- relayInfo.UsingGroup = autoGroup.(string)
- }
-
- actualGroupRatio := groupRatio
- userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
- if ok {
- actualGroupRatio = userGroupRatio
- }
-
- quotaInfo := QuotaInfo{
- InputDetails: TokenDetails{
- TextTokens: textInputTokens,
- AudioTokens: audioInputTokens,
- },
- OutputDetails: TokenDetails{
- TextTokens: textOutTokens,
- AudioTokens: audioOutTokens,
- },
- ModelName: modelName,
- UsePrice: relayInfo.UsePrice,
- ModelRatio: modelRatio,
- GroupRatio: actualGroupRatio,
- }
-
- quota := calculateAudioQuota(quotaInfo)
-
- if userQuota < quota {
- return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
- }
-
- if !token.UnlimitedQuota && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
- }
-
- err = PostConsumeQuota(relayInfo, quota, 0, false)
- if err != nil {
- return err
- }
- logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
- return nil
-}
-
-func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, extraContent string) {
-
- useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
- textInputTokens := usage.InputTokenDetails.TextTokens
- textOutTokens := usage.OutputTokenDetails.TextTokens
-
- audioInputTokens := usage.InputTokenDetails.AudioTokens
- audioOutTokens := usage.OutputTokenDetails.AudioTokens
-
- tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
- audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
-
- modelRatio := relayInfo.PriceData.ModelRatio
- groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
- modelPrice := relayInfo.PriceData.ModelPrice
- usePrice := relayInfo.PriceData.UsePrice
-
- quotaInfo := QuotaInfo{
- InputDetails: TokenDetails{
- TextTokens: textInputTokens,
- AudioTokens: audioInputTokens,
- },
- OutputDetails: TokenDetails{
- TextTokens: textOutTokens,
- AudioTokens: audioOutTokens,
- },
- ModelName: modelName,
- UsePrice: usePrice,
- ModelRatio: modelRatio,
- GroupRatio: groupRatio,
- }
-
- quota := calculateAudioQuota(quotaInfo)
-
- totalTokens := usage.TotalTokens
- var logContent string
- if !usePrice {
- logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
- modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
- } else {
- logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
- }
-
- // record all the consume log even if quota is 0
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- logContent += fmt.Sprintf("(可能是上游超时)")
- logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
-
- logModel := modelName
- if extraContent != "" {
- logContent += ", " + extraContent
- }
- other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
- model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- PromptTokens: usage.InputTokens,
- CompletionTokens: usage.OutputTokens,
- ModelName: logModel,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: relayInfo.TokenId,
- UseTimeSeconds: int(useTimeSeconds),
- IsStream: relayInfo.IsStream,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
-}
-
-func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
-
- useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
- promptTokens := usage.PromptTokens
- completionTokens := usage.CompletionTokens
- modelName := relayInfo.OriginModelName
-
- tokenName := ctx.GetString("token_name")
- completionRatio := relayInfo.PriceData.CompletionRatio
- modelRatio := relayInfo.PriceData.ModelRatio
- groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
- modelPrice := relayInfo.PriceData.ModelPrice
- cacheRatio := relayInfo.PriceData.CacheRatio
- cacheTokens := usage.PromptTokensDetails.CachedTokens
-
- cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
- cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
-
- if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
- promptTokens -= cacheTokens
- isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
- if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
- maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
- if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
- cacheCreationTokens = maybeCacheCreationTokens
- }
- }
- promptTokens -= cacheCreationTokens
- }
-
- calculateQuota := 0.0
- if !relayInfo.PriceData.UsePrice {
- calculateQuota = float64(promptTokens)
- calculateQuota += float64(cacheTokens) * cacheRatio
- calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
- calculateQuota += float64(completionTokens) * completionRatio
- calculateQuota = calculateQuota * groupRatio * modelRatio
- } else {
- calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
- }
-
- if modelRatio != 0 && calculateQuota <= 0 {
- calculateQuota = 1
- }
-
- quota := int(calculateQuota)
-
- totalTokens := promptTokens + completionTokens
-
- var logContent string
- // record all the consume log even if quota is 0
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- logContent += fmt.Sprintf("(可能是上游出错)")
- logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
-
- quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
- if quotaDelta > 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- } else if quotaDelta < 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(-quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- }
-
- if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
- if err != nil {
- logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- }
-
- other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
- model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- PromptTokens: promptTokens,
- CompletionTokens: completionTokens,
- ModelName: modelName,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: relayInfo.TokenId,
- UseTimeSeconds: int(useTimeSeconds),
- IsStream: relayInfo.IsStream,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
-
-}
-
-func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
- if priceData.CacheCreationRatio == 1 {
- return 0
- }
- quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
- promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
- promptCacheReadPrice := quotaPrice * priceData.CacheRatio
- completionPrice := quotaPrice * priceData.CompletionRatio
-
- cost, _ := usage.Cost.(float64)
- totalPromptTokens := float64(usage.PromptTokens)
- completionTokens := float64(usage.CompletionTokens)
- promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
-
- return int(math.Round((cost -
- totalPromptTokens*quotaPrice +
- promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
- completionTokens*completionPrice) /
- (promptCacheCreatePrice - quotaPrice)))
-}
-
-func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
-
- useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
- textInputTokens := usage.PromptTokensDetails.TextTokens
- textOutTokens := usage.CompletionTokenDetails.TextTokens
-
- audioInputTokens := usage.PromptTokensDetails.AudioTokens
- audioOutTokens := usage.CompletionTokenDetails.AudioTokens
-
- tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
- audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
-
- modelRatio := relayInfo.PriceData.ModelRatio
- groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
- modelPrice := relayInfo.PriceData.ModelPrice
- usePrice := relayInfo.PriceData.UsePrice
-
- quotaInfo := QuotaInfo{
- InputDetails: TokenDetails{
- TextTokens: textInputTokens,
- AudioTokens: audioInputTokens,
- },
- OutputDetails: TokenDetails{
- TextTokens: textOutTokens,
- AudioTokens: audioOutTokens,
- },
- ModelName: relayInfo.OriginModelName,
- UsePrice: usePrice,
- ModelRatio: modelRatio,
- GroupRatio: groupRatio,
- }
-
- quota := calculateAudioQuota(quotaInfo)
-
- totalTokens := usage.TotalTokens
- var logContent string
- if !usePrice {
- logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
- modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
- } else {
- logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
- }
-
- // record all the consume log even if quota is 0
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- logContent += fmt.Sprintf("(可能是上游超时)")
- logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
-
- quotaDelta := quota - relayInfo.FinalPreConsumedQuota
-
- if quotaDelta > 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- } else if quotaDelta < 0 {
- logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
- logger.FormatQuota(-quotaDelta),
- logger.FormatQuota(quota),
- logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
- ))
- }
-
- if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
- if err != nil {
- logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- }
-
- logModel := relayInfo.OriginModelName
- if extraContent != "" {
- logContent += ", " + extraContent
- }
- other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
- model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
- ChannelId: relayInfo.ChannelId,
- PromptTokens: usage.PromptTokens,
- CompletionTokens: usage.CompletionTokens,
- ModelName: logModel,
- TokenName: tokenName,
- Quota: quota,
- Content: logContent,
- TokenId: relayInfo.TokenId,
- UseTimeSeconds: int(useTimeSeconds),
- IsStream: relayInfo.IsStream,
- Group: relayInfo.UsingGroup,
- Other: other,
- })
-}
-
-func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- if relayInfo.IsPlayground {
- return nil
- }
- //if relayInfo.TokenUnlimited {
- // return nil
- //}
- token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
- if err != nil {
- return err
- }
- if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
- }
- err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- if err != nil {
- return err
- }
- return nil
-}
-
-func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
-
- if quota > 0 {
- err = model.DecreaseUserQuota(relayInfo.UserId, quota)
- } else {
- err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
- }
- if err != nil {
- return err
- }
-
- if !relayInfo.IsPlayground {
- if quota > 0 {
- err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- } else {
- err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
- }
- if err != nil {
- return err
- }
- }
-
- if sendEmail {
- if (quota + preConsumedQuota) != 0 {
- checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
- }
- }
-
- return nil
-}
-
-func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
- gopool.Go(func() {
- userSetting := relayInfo.UserSetting
- threshold := common.QuotaRemindThreshold
- if userSetting.QuotaWarningThreshold != 0 {
- threshold = int(userSetting.QuotaWarningThreshold)
- }
-
- //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
- quotaTooLow := false
- consumeQuota := quota + preConsumedQuota
- if relayInfo.UserQuota-consumeQuota < threshold {
- quotaTooLow = true
- }
- if quotaTooLow {
- prompt := "您的额度即将用尽"
- topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
-
- // 根据通知方式生成不同的内容格式
- var content string
- var values []interface{}
-
- notifyType := userSetting.NotifyType
- if notifyType == "" {
- notifyType = dto.NotifyTypeEmail
- }
-
- if notifyType == dto.NotifyTypeBark {
- // Bark推送使用简短文本,不支持HTML
- content = "{{value}},剩余额度:{{value}},请及时充值"
- values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)}
- } else {
- // 默认内容格式,适用于Email和Webhook
- content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
- values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}
- }
-
- err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values))
- if err != nil {
- common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
- }
- }
- })
-}
diff --git a/new-api/service/sensitive.go b/new-api/service/sensitive.go
deleted file mode 100644
index f01d477f163513a7edecd0901245101e6b1f3ddf..0000000000000000000000000000000000000000
--- a/new-api/service/sensitive.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package service
-
-import (
- "errors"
- "one-api/dto"
- "one-api/setting"
- "strings"
-)
-
-func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
- if len(messages) == 0 {
- return nil, nil
- }
-
- for _, message := range messages {
- arrayContent := message.ParseContent()
- for _, m := range arrayContent {
- if m.Type == "image_url" {
- // TODO: check image url
- continue
- }
- // 检查 text 是否为空
- if m.Text == "" {
- continue
- }
- if ok, words := SensitiveWordContains(m.Text); ok {
- return words, errors.New("sensitive words detected")
- }
- }
- }
- return nil, nil
-}
-
-func CheckSensitiveText(text string) (bool, []string) {
- return SensitiveWordContains(text)
-}
-
-// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
-func SensitiveWordContains(text string) (bool, []string) {
- if len(setting.SensitiveWords) == 0 {
- return false, nil
- }
- if len(text) == 0 {
- return false, nil
- }
- checkText := strings.ToLower(text)
- return AcSearch(checkText, setting.SensitiveWords, true)
-}
-
-// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
-func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
- if len(setting.SensitiveWords) == 0 {
- return false, nil, text
- }
- checkText := strings.ToLower(text)
- m := getOrBuildAC(setting.SensitiveWords)
- hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
- if len(hits) > 0 {
- words := make([]string, 0, len(hits))
- var builder strings.Builder
- builder.Grow(len(text))
- lastPos := 0
-
- for _, hit := range hits {
- pos := hit.Pos
- word := string(hit.Word)
- builder.WriteString(text[lastPos:pos])
- builder.WriteString("**###**")
- lastPos = pos + len(word)
- words = append(words, word)
- }
- builder.WriteString(text[lastPos:])
- return true, words, builder.String()
- }
- return false, nil, text
-}
diff --git a/new-api/service/str.go b/new-api/service/str.go
deleted file mode 100644
index e0496a1ac7ab32537ea64c30a16c1008d9aef369..0000000000000000000000000000000000000000
--- a/new-api/service/str.go
+++ /dev/null
@@ -1,152 +0,0 @@
-package service
-
-import (
- "bytes"
- "fmt"
- "hash/fnv"
- "sort"
- "strings"
- "sync"
-
- goahocorasick "github.com/anknown/ahocorasick"
-)
-
-func SundaySearch(text string, pattern string) bool {
- // 计算偏移表
- offset := make(map[rune]int)
- for i, c := range pattern {
- offset[c] = len(pattern) - i
- }
-
- // 文本串长度和模式串长度
- n, m := len(text), len(pattern)
-
- // 主循环,i表示当前对齐的文本串位置
- for i := 0; i <= n-m; {
- // 检查子串
- j := 0
- for j < m && text[i+j] == pattern[j] {
- j++
- }
- // 如果完全匹配,返回匹配位置
- if j == m {
- return true
- }
-
- // 如果还有剩余字符,则检查下一位字符在偏移表中的值
- if i+m < n {
- next := rune(text[i+m])
- if val, ok := offset[next]; ok {
- i += val // 存在于偏移表中,进行跳跃
- } else {
- i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度
- }
- } else {
- break
- }
- }
- return false // 如果没有找到匹配,返回-1
-}
-
-func RemoveDuplicate(s []string) []string {
- result := make([]string, 0, len(s))
- temp := map[string]struct{}{}
- for _, item := range s {
- if _, ok := temp[item]; !ok {
- temp[item] = struct{}{}
- result = append(result, item)
- }
- }
- return result
-}
-
-func InitAc(dict []string) *goahocorasick.Machine {
- m := new(goahocorasick.Machine)
- runes := readRunes(dict)
- if err := m.Build(runes); err != nil {
- fmt.Println(err)
- return nil
- }
- return m
-}
-
-var acCache sync.Map
-
-func acKey(dict []string) string {
- if len(dict) == 0 {
- return ""
- }
- normalized := make([]string, 0, len(dict))
- for _, w := range dict {
- w = strings.ToLower(strings.TrimSpace(w))
- if w != "" {
- normalized = append(normalized, w)
- }
- }
- if len(normalized) == 0 {
- return ""
- }
- sort.Strings(normalized)
- hasher := fnv.New64a()
- for _, w := range normalized {
- hasher.Write([]byte{0})
- hasher.Write([]byte(w))
- }
- return fmt.Sprintf("%x", hasher.Sum64())
-}
-
-func getOrBuildAC(dict []string) *goahocorasick.Machine {
- key := acKey(dict)
- if key == "" {
- return nil
- }
- if v, ok := acCache.Load(key); ok {
- if m, ok2 := v.(*goahocorasick.Machine); ok2 {
- return m
- }
- }
- m := InitAc(dict)
- if m == nil {
- return nil
- }
- if actual, loaded := acCache.LoadOrStore(key, m); loaded {
- if cached, ok := actual.(*goahocorasick.Machine); ok {
- return cached
- }
- }
- return m
-}
-
-func readRunes(dict []string) [][]rune {
- var runes [][]rune
-
- for _, word := range dict {
- word = strings.ToLower(word)
- l := bytes.TrimSpace([]byte(word))
- runes = append(runes, bytes.Runes(l))
- }
-
- return runes
-}
-
-func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
- if len(dict) == 0 {
- return false, nil
- }
- if len(findText) == 0 {
- return false, nil
- }
- m := getOrBuildAC(dict)
- if m == nil {
- return false, nil
- }
- hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
- if len(hits) > 0 {
- words := make([]string, 0)
- for _, hit := range hits {
- words = append(words, string(hit.Word))
- }
- return true, words
- }
- return false, nil
-}
diff --git a/new-api/service/task.go b/new-api/service/task.go
deleted file mode 100644
index 11e4f9c495ea1146ab3ec90a113a792dbe66d81a..0000000000000000000000000000000000000000
--- a/new-api/service/task.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package service
-
-import (
- "one-api/constant"
- "strings"
-)
-
-func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
- return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
-}
diff --git a/new-api/service/token_counter.go b/new-api/service/token_counter.go
deleted file mode 100644
index 710ef4ff4d6cae0d2e0476ca5519e66f444939f7..0000000000000000000000000000000000000000
--- a/new-api/service/token_counter.go
+++ /dev/null
@@ -1,602 +0,0 @@
-package service
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "image"
- _ "image/gif"
- _ "image/jpeg"
- _ "image/png"
- "log"
- "math"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/types"
- "strings"
- "sync"
- "unicode/utf8"
-
- "github.com/gin-gonic/gin"
- "github.com/tiktoken-go/tokenizer"
- "github.com/tiktoken-go/tokenizer/codec"
-)
-
-// tokenEncoderMap won't grow after initialization
-var defaultTokenEncoder tokenizer.Codec
-
-// tokenEncoderMap is used to store token encoders for different models
-var tokenEncoderMap = make(map[string]tokenizer.Codec)
-
-// tokenEncoderMutex protects tokenEncoderMap for concurrent access
-var tokenEncoderMutex sync.RWMutex
-
-func InitTokenEncoders() {
- common.SysLog("initializing token encoders")
- defaultTokenEncoder = codec.NewCl100kBase()
- common.SysLog("token encoders initialized")
-}
-
-func getTokenEncoder(model string) tokenizer.Codec {
- // First, try to get the encoder from cache with read lock
- tokenEncoderMutex.RLock()
- if encoder, exists := tokenEncoderMap[model]; exists {
- tokenEncoderMutex.RUnlock()
- return encoder
- }
- tokenEncoderMutex.RUnlock()
-
- // If not in cache, create new encoder with write lock
- tokenEncoderMutex.Lock()
- defer tokenEncoderMutex.Unlock()
-
- // Double-check if another goroutine already created the encoder
- if encoder, exists := tokenEncoderMap[model]; exists {
- return encoder
- }
-
- // Create new encoder
- modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
- if err != nil {
- // Cache the default encoder for this model to avoid repeated failures
- tokenEncoderMap[model] = defaultTokenEncoder
- return defaultTokenEncoder
- }
-
- // Cache the new encoder
- tokenEncoderMap[model] = modelCodec
- return modelCodec
-}
-
-func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
- if text == "" {
- return 0
- }
- tkm, _ := tokenEncoder.Count(text)
- return tkm
-}
-
-func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
- if fileMeta == nil {
- return 0, fmt.Errorf("image_url_is_nil")
- }
-
- // Defaults for 4o/4.1/4.5 family unless overridden below
- baseTokens := 85
- tileTokens := 170
-
- // Model classification
- lowerModel := strings.ToLower(model)
-
- // Special cases from existing behavior
- if strings.HasPrefix(lowerModel, "glm-4") {
- return 1047, nil
- }
-
- // Patch-based models (32x32 patches, capped at 1536, with multiplier)
- isPatchBased := false
- multiplier := 1.0
- switch {
- case strings.Contains(lowerModel, "gpt-4.1-mini"):
- isPatchBased = true
- multiplier = 1.62
- case strings.Contains(lowerModel, "gpt-4.1-nano"):
- isPatchBased = true
- multiplier = 2.46
- case strings.HasPrefix(lowerModel, "o4-mini"):
- isPatchBased = true
- multiplier = 1.72
- case strings.HasPrefix(lowerModel, "gpt-5-mini"):
- isPatchBased = true
- multiplier = 1.62
- case strings.HasPrefix(lowerModel, "gpt-5-nano"):
- isPatchBased = true
- multiplier = 2.46
- }
-
- // Tile-based model tokens and bases per doc
- if !isPatchBased {
- if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
- baseTokens = 2833
- tileTokens = 5667
- } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
- baseTokens = 70
- tileTokens = 140
- } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
- baseTokens = 75
- tileTokens = 150
- } else if strings.Contains(lowerModel, "computer-use-preview") {
- baseTokens = 65
- tileTokens = 129
- } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
- baseTokens = 85
- tileTokens = 170
- }
- }
-
- // Respect existing feature flags/short-circuits
- if fileMeta.Detail == "low" && !isPatchBased {
- return baseTokens, nil
- }
- if !constant.GetMediaTokenNotStream && !stream {
- return 3 * baseTokens, nil
- }
- // Normalize detail
- if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
- fileMeta.Detail = "high"
- }
- // Whether to count image tokens at all
- if !constant.GetMediaToken {
- return 3 * baseTokens, nil
- }
-
- // Decode image to get dimensions
- var config image.Config
- var err error
- var format string
- var b64str string
-
- if fileMeta.ParsedData != nil {
- config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
- } else {
- if strings.HasPrefix(fileMeta.OriginData, "http") {
- config, format, err = DecodeUrlImageData(fileMeta.OriginData)
- } else {
- common.SysLog(fmt.Sprintf("decoding image"))
- config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
- }
- fileMeta.MimeType = format
- }
-
- if err != nil {
- return 0, err
- }
-
- if config.Width == 0 || config.Height == 0 {
- // not an image
- if format != "" && b64str != "" {
- // file type
- return 3 * baseTokens, nil
- }
- return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.OriginData))
- }
-
- width := config.Width
- height := config.Height
- log.Printf("format: %s, width: %d, height: %d", format, width, height)
-
- if isPatchBased {
- // 32x32 patch-based calculation with 1536 cap and model multiplier
- ceilDiv := func(a, b int) int { return (a + b - 1) / b }
- rawPatchesW := ceilDiv(width, 32)
- rawPatchesH := ceilDiv(height, 32)
- rawPatches := rawPatchesW * rawPatchesH
- if rawPatches > 1536 {
- // scale down
- area := float64(width * height)
- r := math.Sqrt(float64(32*32*1536) / area)
- wScaled := float64(width) * r
- hScaled := float64(height) * r
- // adjust to fit whole number of patches after scaling
- adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
- adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
- adj := math.Min(adjW, adjH)
- if !math.IsNaN(adj) && adj > 0 {
- r = r * adj
- }
- wScaled = float64(width) * r
- hScaled = float64(height) * r
- patchesW := math.Ceil(wScaled / 32.0)
- patchesH := math.Ceil(hScaled / 32.0)
- imageTokens := int(patchesW * patchesH)
- if imageTokens > 1536 {
- imageTokens = 1536
- }
- return int(math.Round(float64(imageTokens) * multiplier)), nil
- }
- // below cap
- imageTokens := rawPatches
- return int(math.Round(float64(imageTokens) * multiplier)), nil
- }
-
- // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
- // Step 1: fit within 2048x2048 square
- maxSide := math.Max(float64(width), float64(height))
- fitScale := 1.0
- if maxSide > 2048 {
- fitScale = maxSide / 2048.0
- }
- fitW := int(math.Round(float64(width) / fitScale))
- fitH := int(math.Round(float64(height) / fitScale))
-
- // Step 2: scale so that shortest side is exactly 768
- minSide := math.Min(float64(fitW), float64(fitH))
- if minSide == 0 {
- return baseTokens, nil
- }
- shortScale := 768.0 / minSide
- finalW := int(math.Round(float64(fitW) * shortScale))
- finalH := int(math.Round(float64(fitH) * shortScale))
-
- // Count 512px tiles
- tilesW := (finalW + 512 - 1) / 512
- tilesH := (finalH + 512 - 1) / 512
- tiles := tilesW * tilesH
-
- if common.DebugEnabled {
- log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
- }
-
- return tiles*tileTokens + baseTokens, nil
-}
-
-func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
- if !constant.GetMediaToken {
- return 0, nil
- }
- if !constant.GetMediaTokenNotStream && !info.IsStream {
- return 0, nil
- }
- if info.RelayFormat == types.RelayFormatOpenAIRealtime {
- return 0, nil
- }
- if meta == nil {
- return 0, errors.New("token count meta is nil")
- }
-
- model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
- tkm := 0
-
- if meta.TokenType == types.TokenTypeTextNumber {
- tkm += utf8.RuneCountInString(meta.CombineText)
- } else {
- tkm += CountTextToken(meta.CombineText, model)
- }
-
- if info.RelayFormat == types.RelayFormatOpenAI {
- tkm += meta.ToolsCount * 8
- tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
- tkm += meta.NameCount * 3
- tkm += 3
- }
-
- shouldFetchFiles := true
-
- if info.RelayFormat == types.RelayFormatGemini {
- shouldFetchFiles = false
- }
-
- if shouldFetchFiles {
- for _, file := range meta.Files {
- if strings.HasPrefix(file.OriginData, "http") {
- mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
- if err != nil {
- return 0, fmt.Errorf("error getting file base64 from url: %v", err)
- }
- if strings.HasPrefix(mineType, "image/") {
- file.FileType = types.FileTypeImage
- } else if strings.HasPrefix(mineType, "video/") {
- file.FileType = types.FileTypeVideo
- } else if strings.HasPrefix(mineType, "audio/") {
- file.FileType = types.FileTypeAudio
- } else {
- file.FileType = types.FileTypeFile
- }
- file.MimeType = mineType
- } else if strings.HasPrefix(file.OriginData, "data:") {
- // get mime type from base64 header
- parts := strings.SplitN(file.OriginData, ",", 2)
- if len(parts) >= 1 {
- header := parts[0]
- // Extract mime type from "data:mime/type;base64" format
- if strings.Contains(header, ":") && strings.Contains(header, ";") {
- mimeStart := strings.Index(header, ":") + 1
- mimeEnd := strings.Index(header, ";")
- if mimeStart < mimeEnd {
- mineType := header[mimeStart:mimeEnd]
- if strings.HasPrefix(mineType, "image/") {
- file.FileType = types.FileTypeImage
- } else if strings.HasPrefix(mineType, "video/") {
- file.FileType = types.FileTypeVideo
- } else if strings.HasPrefix(mineType, "audio/") {
- file.FileType = types.FileTypeAudio
- } else {
- file.FileType = types.FileTypeFile
- }
- file.MimeType = mineType
- }
- }
- }
- }
- }
- }
-
- for i, file := range meta.Files {
- switch file.FileType {
- case types.FileTypeImage:
- if info.RelayFormat == types.RelayFormatGemini {
- tkm += 256
- } else {
- token, err := getImageToken(file, model, info.IsStream)
- if err != nil {
- return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
- }
- tkm += token
- }
- case types.FileTypeAudio:
- tkm += 256
- case types.FileTypeVideo:
- tkm += 4096 * 2
- case types.FileTypeFile:
- tkm += 4096
- default:
- tkm += 4096 // Default case for unknown file types
- }
- }
-
- common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
- return tkm, nil
-}
-
-func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
- tkm := 0
-
- // Count tokens in messages
- msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
- if err != nil {
- return 0, err
- }
- tkm += msgTokens
-
- // Count tokens in system message
- if request.System != "" {
- systemTokens := CountTokenInput(request.System, model)
- tkm += systemTokens
- }
-
- if request.Tools != nil {
- // check is array
- if tools, ok := request.Tools.([]any); ok {
- if len(tools) > 0 {
- parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
- if err1 != nil {
- return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
- }
- toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
- if err2 != nil {
- return 0, fmt.Errorf("tools: %v", err)
- }
- tkm += toolTokens
- }
- } else {
- return 0, errors.New("tools: Input should be a valid list")
- }
- }
-
- return tkm, nil
-}
-
-func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
- tokenEncoder := getTokenEncoder(model)
- tokenNum := 0
-
- for _, message := range messages {
- // Count tokens for role
- tokenNum += getTokenNum(tokenEncoder, message.Role)
- if message.IsStringContent() {
- tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
- } else {
- content, err := message.ParseContent()
- if err != nil {
- return 0, err
- }
- for _, mediaMessage := range content {
- switch mediaMessage.Type {
- case "text":
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
- case "image":
- //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
- //if err != nil {
- // return 0, err
- //}
- tokenNum += 1000
- case "tool_use":
- if mediaMessage.Input != nil {
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
- inputJSON, _ := json.Marshal(mediaMessage.Input)
- tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
- }
- case "tool_result":
- if mediaMessage.Content != nil {
- contentJSON, _ := json.Marshal(mediaMessage.Content)
- tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
- }
- }
- }
- }
- }
-
- // Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
- tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
-
- return tokenNum, nil
-}
-
-func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
- tokenEncoder := getTokenEncoder(model)
- tokenNum := 0
-
- for _, tool := range tools {
- tokenNum += getTokenNum(tokenEncoder, tool.Name)
- tokenNum += getTokenNum(tokenEncoder, tool.Description)
-
- schemaJSON, err := json.Marshal(tool.InputSchema)
- if err != nil {
- return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
- }
- tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
- }
-
- // Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
- tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
-
- return tokenNum, nil
-}
-
-func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
- audioToken := 0
- textToken := 0
- switch request.Type {
- case dto.RealtimeEventTypeSessionUpdate:
- if request.Session != nil {
- msgTokens := CountTextToken(request.Session.Instructions, model)
- textToken += msgTokens
- }
- case dto.RealtimeEventResponseAudioDelta:
- // count audio token
- atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
- if err != nil {
- return 0, 0, fmt.Errorf("error counting audio token: %v", err)
- }
- audioToken += atk
- case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
- // count text token
- tkm := CountTextToken(request.Delta, model)
- textToken += tkm
- case dto.RealtimeEventInputAudioBufferAppend:
- // count audio token
- atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
- if err != nil {
- return 0, 0, fmt.Errorf("error counting audio token: %v", err)
- }
- audioToken += atk
- case dto.RealtimeEventConversationItemCreated:
- if request.Item != nil {
- switch request.Item.Type {
- case "message":
- for _, content := range request.Item.Content {
- if content.Type == "input_text" {
- tokens := CountTextToken(content.Text, model)
- textToken += tokens
- }
- }
- }
- }
- case dto.RealtimeEventTypeResponseDone:
- // count tools token
- if !info.IsFirstRequest {
- if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
- for _, tool := range info.RealtimeTools {
- toolTokens := CountTokenInput(tool, model)
- textToken += 8
- textToken += toolTokens
- }
- }
- }
- }
- return textToken, audioToken, nil
-}
-
-func CountTokenInput(input any, model string) int {
- switch v := input.(type) {
- case string:
- return CountTextToken(v, model)
- case []string:
- text := ""
- for _, s := range v {
- text += s
- }
- return CountTextToken(text, model)
- case []interface{}:
- text := ""
- for _, item := range v {
- text += fmt.Sprintf("%v", item)
- }
- return CountTextToken(text, model)
- }
- return CountTokenInput(fmt.Sprintf("%v", input), model)
-}
-
-func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
- tokens := 0
- for _, message := range messages {
- tkm := CountTokenInput(message.Delta.GetContentString(), model)
- tokens += tkm
- if message.Delta.ToolCalls != nil {
- for _, tool := range message.Delta.ToolCalls {
- tkm := CountTokenInput(tool.Function.Name, model)
- tokens += tkm
- tkm = CountTokenInput(tool.Function.Arguments, model)
- tokens += tkm
- }
- }
- }
- return tokens
-}
-
-func CountTTSToken(text string, model string) int {
- if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text)
- } else {
- return CountTextToken(text, model)
- }
-}
-
-func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
- if audioBase64 == "" {
- return 0, nil
- }
- duration, err := parseAudio(audioBase64, audioFormat)
- if err != nil {
- return 0, err
- }
- return int(duration / 60 * 100 / 0.06), nil
-}
-
-func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
- if audioBase64 == "" {
- return 0, nil
- }
- duration, err := parseAudio(audioBase64, audioFormat)
- if err != nil {
- return 0, err
- }
- return int(duration / 60 * 200 / 0.24), nil
-}
-
-//func CountAudioToken(sec float64, audioType string) {
-// if audioType == "input" {
-//
-// }
-//}
-
-// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTextToken(text string, model string) int {
- if text == "" {
- return 0
- }
- tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text)
-}
diff --git a/new-api/service/usage_helpr.go b/new-api/service/usage_helpr.go
deleted file mode 100644
index c232d2b38919364b11d46ed6c821d4f43930918c..0000000000000000000000000000000000000000
--- a/new-api/service/usage_helpr.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package service
-
-import (
- "one-api/dto"
-)
-
-//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
-// switch relayMode {
-// case constant.RelayModeChatCompletions:
-// return CountTokenMessages(textRequest.Messages, textRequest.Model)
-// case constant.RelayModeCompletions:
-// return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
-// case constant.RelayModeModerations:
-// return CountTokenInput(textRequest.Input, textRequest.Model), nil
-// }
-// return 0, errors.New("unknown relay mode")
-//}
-
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
- usage := &dto.Usage{}
- usage.PromptTokens = promptTokens
- ctkm := CountTextToken(responseText, modeName)
- usage.CompletionTokens = ctkm
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return usage
-}
-
-func ValidUsage(usage *dto.Usage) bool {
- return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
-}
diff --git a/new-api/service/user_notify.go b/new-api/service/user_notify.go
deleted file mode 100644
index 169df3d4794fc703f5978381ee035ad66a85f983..0000000000000000000000000000000000000000
--- a/new-api/service/user_notify.go
+++ /dev/null
@@ -1,146 +0,0 @@
-package service
-
-import (
- "fmt"
- "net/http"
- "net/url"
- "one-api/common"
- "one-api/dto"
- "one-api/model"
- "one-api/setting/system_setting"
- "strings"
-)
-
-func NotifyRootUser(t string, subject string, content string) {
- user := model.GetRootUser().ToBaseUser()
- err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error()))
- }
-}
-
-func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
- notifyType := userSetting.NotifyType
- if notifyType == "" {
- notifyType = dto.NotifyTypeEmail
- }
-
- // Check notification limit
- canSend, err := CheckNotificationLimit(userId, data.Type)
- if err != nil {
- common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
- return err
- }
- if !canSend {
- return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
- }
-
- switch notifyType {
- case dto.NotifyTypeEmail:
- // check setting email
- userEmail = userSetting.NotificationEmail
- if userEmail == "" {
- common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
- return nil
- }
- return sendEmailNotify(userEmail, data)
- case dto.NotifyTypeWebhook:
- webhookURLStr := userSetting.WebhookUrl
- if webhookURLStr == "" {
- common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
- return nil
- }
-
- // 获取 webhook secret
- webhookSecret := userSetting.WebhookSecret
- return SendWebhookNotify(webhookURLStr, webhookSecret, data)
- case dto.NotifyTypeBark:
- barkURL := userSetting.BarkUrl
- if barkURL == "" {
- common.SysLog(fmt.Sprintf("user %d has no bark url, skip sending bark", userId))
- return nil
- }
- return sendBarkNotify(barkURL, data)
- }
- return nil
-}
-
-func sendEmailNotify(userEmail string, data dto.Notify) error {
- // make email content
- content := data.Content
- // 处理占位符
- for _, value := range data.Values {
- content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
- }
- return common.SendEmail(data.Title, userEmail, content)
-}
-
-func sendBarkNotify(barkURL string, data dto.Notify) error {
- // 处理占位符
- content := data.Content
- for _, value := range data.Values {
- content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
- }
-
- // 替换模板变量
- finalURL := strings.ReplaceAll(barkURL, "{{title}}", url.QueryEscape(data.Title))
- finalURL = strings.ReplaceAll(finalURL, "{{content}}", url.QueryEscape(content))
-
- // 发送GET请求到Bark
- var req *http.Request
- var resp *http.Response
- var err error
-
- if system_setting.EnableWorker() {
- // 使用worker发送请求
- workerReq := &WorkerRequest{
- URL: finalURL,
- Key: system_setting.WorkerValidKey,
- Method: http.MethodGet,
- Headers: map[string]string{
- "User-Agent": "OneAPI-Bark-Notify/1.0",
- },
- }
-
- resp, err = DoWorkerRequest(workerReq)
- if err != nil {
- return fmt.Errorf("failed to send bark request through worker: %v", err)
- }
- defer resp.Body.Close()
-
- // 检查响应状态
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
- }
- } else {
- // SSRF防护:验证Bark URL(非Worker模式)
- fetchSetting := system_setting.GetFetchSetting()
- if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
- return fmt.Errorf("request reject: %v", err)
- }
-
- // 直接发送请求
- req, err = http.NewRequest(http.MethodGet, finalURL, nil)
- if err != nil {
- return fmt.Errorf("failed to create bark request: %v", err)
- }
-
- // 设置User-Agent
- req.Header.Set("User-Agent", "OneAPI-Bark-Notify/1.0")
-
- // 发送请求
- client := GetHttpClient()
- resp, err = client.Do(req)
- if err != nil {
- return fmt.Errorf("failed to send bark request: %v", err)
- }
- defer resp.Body.Close()
-
- // 检查响应状态
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
- }
- }
-
- return nil
-}
diff --git a/new-api/service/webhook.go b/new-api/service/webhook.go
deleted file mode 100644
index 263f35d0cc0d97daddc549f7534e2e6dd0001429..0000000000000000000000000000000000000000
--- a/new-api/service/webhook.go
+++ /dev/null
@@ -1,125 +0,0 @@
-package service
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/setting/system_setting"
- "time"
-)
-
-// WebhookPayload webhook 通知的负载数据
-type WebhookPayload struct {
- Type string `json:"type"`
- Title string `json:"title"`
- Content string `json:"content"`
- Values []interface{} `json:"values,omitempty"`
- Timestamp int64 `json:"timestamp"`
-}
-
-// generateSignature 生成 webhook 签名
-func generateSignature(secret string, payload []byte) string {
- h := hmac.New(sha256.New, []byte(secret))
- h.Write(payload)
- return hex.EncodeToString(h.Sum(nil))
-}
-
-// SendWebhookNotify 发送 webhook 通知
-func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
- // 处理占位符
- content := data.Content
- for _, value := range data.Values {
- content = fmt.Sprintf(content, value)
- }
-
- // 构建 webhook 负载
- payload := WebhookPayload{
- Type: data.Type,
- Title: data.Title,
- Content: content,
- Values: data.Values,
- Timestamp: time.Now().Unix(),
- }
-
- // 序列化负载
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- return fmt.Errorf("failed to marshal webhook payload: %v", err)
- }
-
- // 创建 HTTP 请求
- var req *http.Request
- var resp *http.Response
-
- if system_setting.EnableWorker() {
- // 构建worker请求数据
- workerReq := &WorkerRequest{
- URL: webhookURL,
- Key: system_setting.WorkerValidKey,
- Method: http.MethodPost,
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- Body: payloadBytes,
- }
-
- // 如果有secret,添加签名到headers
- if secret != "" {
- signature := generateSignature(secret, payloadBytes)
- workerReq.Headers["X-Webhook-Signature"] = signature
- workerReq.Headers["Authorization"] = "Bearer " + secret
- }
-
- resp, err = DoWorkerRequest(workerReq)
- if err != nil {
- return fmt.Errorf("failed to send webhook request through worker: %v", err)
- }
- defer resp.Body.Close()
-
- // 检查响应状态
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
- }
- } else {
- // SSRF防护:验证Webhook URL(非Worker模式)
- fetchSetting := system_setting.GetFetchSetting()
- if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
- return fmt.Errorf("request reject: %v", err)
- }
-
- req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
- if err != nil {
- return fmt.Errorf("failed to create webhook request: %v", err)
- }
-
- // 设置请求头
- req.Header.Set("Content-Type", "application/json")
-
- // 如果有 secret,生成签名
- if secret != "" {
- signature := generateSignature(secret, payloadBytes)
- req.Header.Set("X-Webhook-Signature", signature)
- }
-
- // 发送请求
- client := GetHttpClient()
- resp, err = client.Do(req)
- if err != nil {
- return fmt.Errorf("failed to send webhook request: %v", err)
- }
- defer resp.Body.Close()
-
- // 检查响应状态
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
- }
- }
-
- return nil
-}
diff --git a/new-api/setting/auto_group.go b/new-api/setting/auto_group.go
deleted file mode 100644
index 1bb3ef0324b0582d9667b4c583fd82b882a7acc0..0000000000000000000000000000000000000000
--- a/new-api/setting/auto_group.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package setting
-
-import "encoding/json"
-
-var AutoGroups = []string{
- "default",
-}
-
-var DefaultUseAutoGroup = false
-
-func ContainsAutoGroup(group string) bool {
- for _, autoGroup := range AutoGroups {
- if autoGroup == group {
- return true
- }
- }
- return false
-}
-
-func UpdateAutoGroupsByJsonString(jsonString string) error {
- AutoGroups = make([]string, 0)
- return json.Unmarshal([]byte(jsonString), &AutoGroups)
-}
-
-func AutoGroups2JsonString() string {
- jsonBytes, err := json.Marshal(AutoGroups)
- if err != nil {
- return "[]"
- }
- return string(jsonBytes)
-}
diff --git a/new-api/setting/chat.go b/new-api/setting/chat.go
deleted file mode 100644
index 05c9eb25fc66eeaaa1664ea172eab16d356f3305..0000000000000000000000000000000000000000
--- a/new-api/setting/chat.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package setting
-
-import (
- "encoding/json"
- "one-api/common"
-)
-
-var Chats = []map[string]string{
- //{
- // "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
- //},
- {
- "Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
- },
- {
- "流畅阅读": "fluentread",
- },
- {
- "Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
- },
- {
- "AI as Workspace": "https://aiaw.app/set-provider?provider={\"type\":\"openai\",\"settings\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\",\"compatibility\":\"strict\"}}",
- },
- {
- "AMA 问天": "ama://set-api-key?server={address}&key={key}",
- },
- {
- "OpenCat": "opencat://team/join?domain={address}&token={key}",
- },
-}
-
-func UpdateChatsByJsonString(jsonString string) error {
- Chats = make([]map[string]string, 0)
- return json.Unmarshal([]byte(jsonString), &Chats)
-}
-
-func Chats2JsonString() string {
- jsonBytes, err := json.Marshal(Chats)
- if err != nil {
- common.SysLog("error marshalling chats: " + err.Error())
- return "[]"
- }
- return string(jsonBytes)
-}
diff --git a/new-api/setting/config/config.go b/new-api/setting/config/config.go
deleted file mode 100644
index 286242f86dd851ac2852167f2094fae90e55ca8a..0000000000000000000000000000000000000000
--- a/new-api/setting/config/config.go
+++ /dev/null
@@ -1,259 +0,0 @@
-package config
-
-import (
- "encoding/json"
- "one-api/common"
- "reflect"
- "strconv"
- "strings"
- "sync"
-)
-
-// ConfigManager 统一管理所有配置
-type ConfigManager struct {
- configs map[string]interface{}
- mutex sync.RWMutex
-}
-
-var GlobalConfig = NewConfigManager()
-
-func NewConfigManager() *ConfigManager {
- return &ConfigManager{
- configs: make(map[string]interface{}),
- }
-}
-
-// Register 注册一个配置模块
-func (cm *ConfigManager) Register(name string, config interface{}) {
- cm.mutex.Lock()
- defer cm.mutex.Unlock()
- cm.configs[name] = config
-}
-
-// Get 获取指定配置模块
-func (cm *ConfigManager) Get(name string) interface{} {
- cm.mutex.RLock()
- defer cm.mutex.RUnlock()
- return cm.configs[name]
-}
-
-// LoadFromDB 从数据库加载配置
-func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
- cm.mutex.Lock()
- defer cm.mutex.Unlock()
-
- for name, config := range cm.configs {
- prefix := name + "."
- configMap := make(map[string]string)
-
- // 收集属于此配置的所有选项
- for key, value := range options {
- if strings.HasPrefix(key, prefix) {
- configKey := strings.TrimPrefix(key, prefix)
- configMap[configKey] = value
- }
- }
-
- // 如果找到配置项,则更新配置
- if len(configMap) > 0 {
- if err := updateConfigFromMap(config, configMap); err != nil {
- common.SysError("failed to update config " + name + ": " + err.Error())
- continue
- }
- }
- }
-
- return nil
-}
-
-// SaveToDB 将配置保存到数据库
-func (cm *ConfigManager) SaveToDB(updateFunc func(key, value string) error) error {
- cm.mutex.RLock()
- defer cm.mutex.RUnlock()
-
- for name, config := range cm.configs {
- configMap, err := configToMap(config)
- if err != nil {
- return err
- }
-
- for key, value := range configMap {
- dbKey := name + "." + key
- if err := updateFunc(dbKey, value); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-// 辅助函数:将配置对象转换为map
-func configToMap(config interface{}) (map[string]string, error) {
- result := make(map[string]string)
-
- val := reflect.ValueOf(config)
- if val.Kind() == reflect.Ptr {
- val = val.Elem()
- }
-
- if val.Kind() != reflect.Struct {
- return nil, nil
- }
-
- typ := val.Type()
- for i := 0; i < val.NumField(); i++ {
- field := val.Field(i)
- fieldType := typ.Field(i)
-
- // 跳过未导出字段
- if !fieldType.IsExported() {
- continue
- }
-
- // 获取json标签作为键名
- key := fieldType.Tag.Get("json")
- if key == "" || key == "-" {
- key = fieldType.Name
- }
-
- // 处理不同类型的字段
- var strValue string
- switch field.Kind() {
- case reflect.String:
- strValue = field.String()
- case reflect.Bool:
- strValue = strconv.FormatBool(field.Bool())
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- strValue = strconv.FormatInt(field.Int(), 10)
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- strValue = strconv.FormatUint(field.Uint(), 10)
- case reflect.Float32, reflect.Float64:
- strValue = strconv.FormatFloat(field.Float(), 'f', -1, 64)
- case reflect.Map, reflect.Slice, reflect.Struct:
- // 复杂类型使用JSON序列化
- bytes, err := json.Marshal(field.Interface())
- if err != nil {
- return nil, err
- }
- strValue = string(bytes)
- default:
- // 跳过不支持的类型
- continue
- }
-
- result[key] = strValue
- }
-
- return result, nil
-}
-
-// 辅助函数:从map更新配置对象
-func updateConfigFromMap(config interface{}, configMap map[string]string) error {
- val := reflect.ValueOf(config)
- if val.Kind() != reflect.Ptr {
- return nil
- }
- val = val.Elem()
-
- if val.Kind() != reflect.Struct {
- return nil
- }
-
- typ := val.Type()
- for i := 0; i < val.NumField(); i++ {
- field := val.Field(i)
- fieldType := typ.Field(i)
-
- // 跳过未导出字段
- if !fieldType.IsExported() {
- continue
- }
-
- // 获取json标签作为键名
- key := fieldType.Tag.Get("json")
- if key == "" || key == "-" {
- key = fieldType.Name
- }
-
- // 检查map中是否有对应的值
- strValue, ok := configMap[key]
- if !ok {
- continue
- }
-
- // 根据字段类型设置值
- if !field.CanSet() {
- continue
- }
-
- switch field.Kind() {
- case reflect.String:
- field.SetString(strValue)
- case reflect.Bool:
- boolValue, err := strconv.ParseBool(strValue)
- if err != nil {
- continue
- }
- field.SetBool(boolValue)
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- intValue, err := strconv.ParseInt(strValue, 10, 64)
- if err != nil {
- continue
- }
- field.SetInt(intValue)
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- uintValue, err := strconv.ParseUint(strValue, 10, 64)
- if err != nil {
- continue
- }
- field.SetUint(uintValue)
- case reflect.Float32, reflect.Float64:
- floatValue, err := strconv.ParseFloat(strValue, 64)
- if err != nil {
- continue
- }
- field.SetFloat(floatValue)
- case reflect.Map, reflect.Slice, reflect.Struct:
- // 复杂类型使用JSON反序列化
- err := json.Unmarshal([]byte(strValue), field.Addr().Interface())
- if err != nil {
- continue
- }
- }
- }
-
- return nil
-}
-
-// ConfigToMap 将配置对象转换为map(导出函数)
-func ConfigToMap(config interface{}) (map[string]string, error) {
- return configToMap(config)
-}
-
-// UpdateConfigFromMap 从map更新配置对象(导出函数)
-func UpdateConfigFromMap(config interface{}, configMap map[string]string) error {
- return updateConfigFromMap(config, configMap)
-}
-
-// ExportAllConfigs 导出所有已注册的配置为扁平结构
-func (cm *ConfigManager) ExportAllConfigs() map[string]string {
- cm.mutex.RLock()
- defer cm.mutex.RUnlock()
-
- result := make(map[string]string)
-
- for name, cfg := range cm.configs {
- configMap, err := ConfigToMap(cfg)
- if err != nil {
- continue
- }
-
- // 使用 "模块名.配置项" 的格式添加到结果中
- for key, value := range configMap {
- result[name+"."+key] = value
- }
- }
-
- return result
-}
diff --git a/new-api/setting/console_setting/config.go b/new-api/setting/console_setting/config.go
deleted file mode 100644
index 5ca069e5514dba66f16786832d9c73c76ae6eeb7..0000000000000000000000000000000000000000
--- a/new-api/setting/console_setting/config.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package console_setting
-
-import "one-api/setting/config"
-
-type ConsoleSetting struct {
- ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
- UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
- Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
- FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
- ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
- UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
- AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
- FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
-}
-
-// 默认配置
-var defaultConsoleSetting = ConsoleSetting{
- ApiInfo: "",
- UptimeKumaGroups: "",
- Announcements: "",
- FAQ: "",
- ApiInfoEnabled: true,
- UptimeKumaEnabled: true,
- AnnouncementsEnabled: true,
- FAQEnabled: true,
-}
-
-// 全局实例
-var consoleSetting = defaultConsoleSetting
-
-func init() {
- // 注册到全局配置管理器,键名为 console_setting
- config.GlobalConfig.Register("console_setting", &consoleSetting)
-}
-
-// GetConsoleSetting 获取 ConsoleSetting 配置实例
-func GetConsoleSetting() *ConsoleSetting {
- return &consoleSetting
-}
diff --git a/new-api/setting/console_setting/validation.go b/new-api/setting/console_setting/validation.go
deleted file mode 100644
index 19d0e21f1100a2d84a4e1daed7dac0c5908e4762..0000000000000000000000000000000000000000
--- a/new-api/setting/console_setting/validation.go
+++ /dev/null
@@ -1,304 +0,0 @@
-package console_setting
-
-import (
- "encoding/json"
- "fmt"
- "net/url"
- "regexp"
- "sort"
- "strings"
- "time"
-)
-
-var (
- urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
- dangerousChars = []string{"
-