timigogo commited on
Commit
a3284c8
·
verified ·
1 Parent(s): e061846

Upload 397 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.dockerignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .github
2
+ .git
3
+ *.md
4
+ .vscode
5
+ .gitignore
6
+ Makefile
7
+ docs
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ web/bun.lockb filter=lfs diff=lfs merge=lfs -text
37
+ web/public/ratio.png filter=lfs diff=lfs merge=lfs -text
BT.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 密钥为环境变量SESSION_SECRET
2
+
3
+ ![8285bba413e770fe9620f1bf9b40d44e](https://github.com/user-attachments/assets/7a6fc03e-c457-45e4-b8f9-184508fc26b0)
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM oven/bun:latest AS builder
2
+
3
+ WORKDIR /build
4
+ COPY web/package.json .
5
+ RUN bun install
6
+ COPY ./web .
7
+ COPY ./VERSION .
8
+ RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
9
+
10
+ FROM golang:alpine AS builder2
11
+
12
+ ENV GO111MODULE=on \
13
+ CGO_ENABLED=0 \
14
+ GOOS=linux
15
+
16
+ WORKDIR /build
17
+
18
+ ADD go.mod go.sum ./
19
+ RUN go mod download
20
+
21
+ COPY . .
22
+ COPY --from=builder /build/dist ./web/dist
23
+ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
24
+
25
+ FROM alpine
26
+
27
+ RUN apk update \
28
+ && apk upgrade \
29
+ && apk add --no-cache ca-certificates tzdata ffmpeg \
30
+ && update-ca-certificates
31
+
32
+ COPY --from=builder2 /build/one-api /
33
+ EXPOSE 3000
34
+ WORKDIR /data
35
+ ENTRYPOINT ["/one-api"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Midjourney.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Midjourney Proxy API文档
2
+
3
+ **简介**:Midjourney Proxy API文档
4
+
5
+ ## 接口列表
6
+ 支持的接口如下:
7
+ + [x] /mj/submit/imagine
8
+ + [x] /mj/submit/change
9
+ + [x] /mj/submit/blend
10
+ + [x] /mj/submit/describe
11
+ + [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
12
+ + [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
13
+ + [x] /task/list-by-condition
14
+ + [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
15
+ + [x] /mj/submit/modal
16
+ + [x] /mj/submit/shorten
17
+ + [x] /mj/task/{id}/image-seed
18
+ + [x] /mj/insight-face/swap (InsightFace)
19
+
20
+ ## 模型列表
21
+
22
+ ### midjourney-proxy支持
23
+
24
+ - mj_imagine (绘图)
25
+ - mj_variation (变换)
26
+ - mj_reroll (重绘)
27
+ - mj_blend (混合)
28
+ - mj_upscale (放大)
29
+ - mj_describe (图生文)
30
+
31
+ ### 仅midjourney-proxy-plus支持
32
+
33
+ - mj_zoom (比例变焦)
34
+ - mj_shorten (提示词缩短)
35
+ - mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
36
+ - mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
37
+ - mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
38
+ - mj_high_variation (强变换)
39
+ - mj_low_variation (弱变换)
40
+ - mj_pan (平移)
41
+ - swap_face (换脸)
42
+
43
+ ## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
44
+ ```json
45
+ {
46
+ "mj_imagine": 0.1,
47
+ "mj_variation": 0.1,
48
+ "mj_reroll": 0.1,
49
+ "mj_blend": 0.1,
50
+ "mj_modal": 0.1,
51
+ "mj_zoom": 0.1,
52
+ "mj_shorten": 0.1,
53
+ "mj_high_variation": 0.1,
54
+ "mj_low_variation": 0.1,
55
+ "mj_pan": 0.1,
56
+ "mj_inpaint": 0,
57
+ "mj_custom_zoom": 0,
58
+ "mj_describe": 0.05,
59
+ "mj_upscale": 0.05,
60
+ "swap_face": 0.05
61
+ }
62
+ ```
63
+ 其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
64
+
65
+ ## 渠道设置
66
+
67
+ ### 对接 midjourney-proxy(plus)
68
+
69
+ 1.
70
+
71
+ 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
72
+
73
+ 2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
74
+ ,模型请参考上方模型列表
75
+ 3. **代理**填写midjourney-proxy部署的地址,例如:http://localhost:8080
76
+ 4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
77
+
78
+ ### 对接上游new api
79
+
80
+ 1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
81
+ 2. **代理**填写上游new api的地址,例如:http://localhost:3000
82
+ 3. 密钥填写上游new api的密钥
README.en.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ ![new-api](/web/public/logo.png)
4
+
5
+ # New API
6
+
7
+ 🍥 Next Generation LLM Gateway and AI Asset Management System
8
+
9
+ <a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
10
+
11
+ <p align="center">
12
+ <a href="https://raw.githubusercontent.com/Calcium-Ion/new-api/main/LICENSE">
13
+ <img src="https://img.shields.io/github/license/Calcium-Ion/new-api?color=brightgreen" alt="license">
14
+ </a>
15
+ <a href="https://github.com/Calcium-Ion/new-api/releases/latest">
16
+ <img src="https://img.shields.io/github/v/release/Calcium-Ion/new-api?color=brightgreen&include_prereleases" alt="release">
17
+ </a>
18
+ <a href="https://github.com/users/Calcium-Ion/packages/container/package/new-api">
19
+ <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
20
+ </a>
21
+ <a href="https://hub.docker.com/r/CalciumIon/new-api">
22
+ <img src="https://img.shields.io/badge/docker-dockerHub-blue" alt="docker">
23
+ </a>
24
+ <a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
25
+ <img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
26
+ </a>
27
+ </p>
28
+ </div>
29
+
30
+ ## 📝 Project Description
31
+
32
+ > [!NOTE]
33
+ > This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
34
+
35
+ > [!IMPORTANT]
36
+ > - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and relevant laws and regulations. Not to be used for illegal purposes.
37
+ > - This project is for personal learning only. Stability is not guaranteed, and no technical support is provided.
38
+
39
+ ## ✨ Key Features
40
+
41
+ 1. 🎨 New UI interface (some interfaces pending update)
42
+ 2. 🌍 Multi-language support (work in progress)
43
+ 3. 🎨 Added [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface support, [Integration Guide](Midjourney.md)
44
+ 4. 💰 Online recharge support, configurable in system settings:
45
+ - [x] EasyPay
46
+ 5. 🔍 Query usage quota by key:
47
+ - Works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
48
+ 6. 📑 Configurable items per page in pagination
49
+ 7. 🔄 Compatible with original One API database (one-api.db)
50
+ 8. 💵 Support per-request model pricing, configurable in System Settings - Operation Settings
51
+ 9. ⚖️ Support channel **weighted random** selection
52
+ 10. 📈 Data dashboard (console)
53
+ 11. 🔒 Configurable model access per token
54
+ 12. 🤖 Telegram authorization login support:
55
+ 1. System Settings - Configure Login Registration - Allow Telegram Login
56
+ 2. Send /setdomain command to [@Botfather](https://t.me/botfather)
57
+ 3. Select your bot, then enter http(s)://your-website/login
58
+ 4. Telegram Bot name is the bot username without @
59
+ 13. 🎵 Added [Suno API](https://github.com/Suno-API/Suno-API) interface support, [Integration Guide](Suno.md)
60
+ 14. 🔄 Support for Rerank models, compatible with Cohere and Jina, can integrate with Dify, [Integration Guide](Rerank.md)
61
+ 15. ⚡ **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - Support for OpenAI's Realtime API, including Azure channels
62
+ 16. 🧠 Support for setting reasoning effort through model name suffix:
63
+ - Add suffix `-high` to set high reasoning effort (e.g., `o3-mini-high`)
64
+ - Add suffix `-medium` to set medium reasoning effort
65
+ - Add suffix `-low` to set low reasoning effort
66
+ 17. 🔄 Thinking to content option `thinking_to_content` in `Channel->Edit->Channel Extra Settings`, default is `false`, when `true`, the `reasoning_content` of the thinking content will be converted to `<think>` tags and concatenated to the content returned.
67
+ 18. 🔄 Model rate limit, support setting total request limit and successful request limit in `System Settings->Rate Limit Settings`
68
+ 19. 💰 Cache billing support, when enabled can charge a configurable ratio for cache hits:
69
+ 1. Set `Prompt Cache Ratio` in `System Settings -> Operation Settings`
70
+ 2. Set `Prompt Cache Ratio` in channel settings, range 0-1 (e.g., 0.5 means 50% charge on cache hits)
71
+ 3. Supported channels:
72
+ - [x] OpenAI
73
+ - [x] Azure
74
+ - [x] DeepSeek
75
+ - [ ] Claude
76
+
77
+ ## Model Support
78
+ This version additionally supports:
79
+ 1. Third-party model **gpts** (gpt-4-gizmo-*)
80
+ 2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [Integration Guide](Midjourney.md)
81
+ 3. Custom channels with full API URL support
82
+ 4. [Suno API](https://github.com/Suno-API/Suno-API) interface, [Integration Guide](Suno.md)
83
+ 5. Rerank models, supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [Integration Guide](Rerank.md)
84
+ 6. Dify
85
+
86
+ You can add custom models gpt-4-gizmo-* in channels. These are third-party models and cannot be called with official OpenAI keys.
87
+
88
+ ## Additional Configurations Beyond One API
89
+ - `GENERATE_DEFAULT_TOKEN`: Generate initial token for new users, default `false`
90
+ - `STREAMING_TIMEOUT`: Set streaming response timeout, default 60 seconds
91
+ - `DIFY_DEBUG`: Output workflow and node info to client for Dify channel, default `true`
92
+ - `FORCE_STREAM_OPTION`: Override client stream_options parameter, default `true`
93
+ - `GET_MEDIA_TOKEN`: Calculate image tokens, default `true`
94
+ - `GET_MEDIA_TOKEN_NOT_STREAM`: Calculate image tokens in non-stream mode, default `true`
95
+ - `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
96
+ - `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
97
+ - `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
98
+ - `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
99
+ - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
100
+ - `CRYPTO_SECRET`: Encryption key for encrypting database content
101
+ - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
102
+ - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
103
+ - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
104
+
105
+ ## Deployment
106
+
107
+ > [!TIP]
108
+ > Latest Docker image: `calciumion/new-api:latest`
109
+ > Default account: root, password: 123456
110
+
111
+ ### Multi-Server Deployment
112
+ - Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
113
+ - If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
114
+
115
+ ### Requirements
116
+ - Local database (default): SQLite (Docker deployment must mount `/data` directory)
117
+ - Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
118
+
119
+ ### Deployment with BT Panel
120
+ Install BT Panel (**version 9.2.0** or above) from [BT Panel Official Website](https://www.bt.cn/new/download.html), choose the stable version script to download and install.
121
+ After installation, log in to BT Panel and click Docker in the menu bar. First-time access will prompt to install Docker service. Click Install Now and follow the prompts to complete installation.
122
+ After installation, find **New-API** in the app store, click install, configure basic options to complete installation.
123
+ [Pictorial Guide](BT.md)
124
+
125
+ ### Docker Deployment
126
+
127
+ ### Using Docker Compose (Recommended)
128
+ ```shell
129
+ # Clone project
130
+ git clone https://github.com/Calcium-Ion/new-api.git
131
+ cd new-api
132
+ # Edit docker-compose.yml as needed
133
+ # nano docker-compose.yml
134
+ # vim docker-compose.yml
135
+ # Start
136
+ docker-compose up -d
137
+ ```
138
+
139
+ #### Update Version
140
+ ```shell
141
+ docker-compose pull
142
+ docker-compose up -d
143
+ ```
144
+
145
+ ### Direct Docker Image Usage
146
+ ```shell
147
+ # SQLite deployment:
148
+ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
149
+
150
+ # MySQL deployment (add -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"), modify database connection parameters as needed
151
+ # Example:
152
+ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
153
+ ```
154
+
155
+ #### Update Version
156
+ ```shell
157
+ # Pull the latest image
158
+ docker pull calciumion/new-api:latest
159
+ # Stop and remove the old container
160
+ docker stop new-api
161
+ docker rm new-api
162
+ # Run the new container with the same parameters as before
163
+ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
164
+ ```
165
+
166
+ Alternatively, you can use Watchtower for automatic updates (not recommended, may cause database incompatibility):
167
+ ```shell
168
+ docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
169
+ ```
170
+
171
+ ## Channel Retry
172
+ Channel retry is implemented, configurable in `Settings->Operation Settings->General Settings`. **Cache recommended**.
173
+ If retry is enabled, the system will automatically use the next priority channel for the same request after a failed request.
174
+
175
+ ### Cache Configuration
176
+ 1. `REDIS_CONN_STRING`: Use Redis as cache
177
+ + Example: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
178
+ 2. `MEMORY_CACHE_ENABLED`: Enable memory cache, default `false`
179
+ + Example: `MEMORY_CACHE_ENABLED=true`
180
+
181
+ ### Why Some Errors Don't Retry
182
+ Error codes 400, 504, 524 won't retry
183
+ ### To Enable Retry for 400
184
+ In `Channel->Edit`, set `Status Code Override` to:
185
+ ```json
186
+ {
187
+ "400": "500"
188
+ }
189
+ ```
190
+
191
+ ## Integration Guides
192
+ - [Midjourney Integration](Midjourney.md)
193
+ - [Suno Integration](Suno.md)
194
+
195
+ ## Related Projects
196
+ - [One API](https://github.com/songquanpeng/one-api): Original project
197
+ - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
198
+ - [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-gen AI B/C solution
199
+ - [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota by key
200
+
201
+ ## 🌟 Star History
202
+
203
+ [![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)
README.md CHANGED
@@ -1,10 +1,190 @@
1
- ---
2
- title: New Api
3
- emoji: ⚡
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="right">
2
+ <strong>中文</strong> | <a href="./README.en.md">English</a>
3
+ </p>
4
+ <div align="center">
5
+
6
+ ![new-api](/web/public/logo.png)
7
+
8
+ # New API
9
+
10
+ 🍥新一代大模型网关与AI资产管理系统
11
+
12
+ <a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
13
+
14
+ <p align="center">
15
+ <a href="https://raw.githubusercontent.com/Calcium-Ion/new-api/main/LICENSE">
16
+ <img src="https://img.shields.io/github/license/Calcium-Ion/new-api?color=brightgreen" alt="license">
17
+ </a>
18
+ <a href="https://github.com/Calcium-Ion/new-api/releases/latest">
19
+ <img src="https://img.shields.io/github/v/release/Calcium-Ion/new-api?color=brightgreen&include_prereleases" alt="release">
20
+ </a>
21
+ <a href="https://github.com/users/Calcium-Ion/packages/container/package/new-api">
22
+ <img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
23
+ </a>
24
+ <a href="https://hub.docker.com/r/CalciumIon/new-api">
25
+ <img src="https://img.shields.io/badge/docker-dockerHub-blue" alt="docker">
26
+ </a>
27
+ <a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
28
+ <img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
29
+ </a>
30
+ </p>
31
+ </div>
32
+
33
+ ## 📝 项目说明
34
+
35
+ > [!NOTE]
36
+ > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
37
+
38
+ > [!IMPORTANT]
39
+ > - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
40
+ > - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
41
+ > - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
42
+
43
+ ## 📚 文档
44
+
45
+ 详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
46
+
47
+ ## ✨ 主要特性
48
+
49
+ New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
50
+
51
+ 1. 🎨 全新的UI界面
52
+ 2. 🌍 多语言支持
53
+ 3. 💰 支持在线充值功能(易支付)
54
+ 4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
55
+ 5. 🔄 兼容原版One API的数据库
56
+ 6. 💵 支持模型按次数收费
57
+ 7. ⚖️ 支持渠道加权随机
58
+ 8. 📈 数据看板(控制台)
59
+ 9. 🔒 令牌分组、模型限制
60
+ 10. 🤖 支持更多授权登陆方式(LinuxDO,Telegram、OIDC)
61
+ 11. 🔄 支持Rerank模型(Cohere和Jina),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
62
+ 12. ⚡ 支持OpenAI Realtime API(包括Azure渠道),[接口文档](https://docs.newapi.pro/api/openai-realtime)
63
+ 13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
64
+ 14. 支持使用路由/chat2link进入聊天界面
65
+ 15. 🧠 支持通过模型名称后缀设置 reasoning effort:
66
+ 1. OpenAI o系列模型
67
+ - 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
68
+ - 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
69
+ - 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
70
+ 2. Claude 思考模型
71
+ - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
72
+ 16. 🔄 思考转内容功能
73
+ 17. 🔄 针对用户的模型限流功能
74
+ 18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
75
+ 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
76
+ 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
77
+ 3. 支持的渠道:
78
+ - [x] OpenAI
79
+ - [x] Azure
80
+ - [x] DeepSeek
81
+ - [x] Claude
82
+
83
+ ## 模型支持
84
+
85
+ 此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api):
86
+
87
+ 1. 第三方模型 **gpts** (gpt-4-gizmo-*)
88
+ 2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接口文档](https://docs.newapi.pro/api/midjourney-proxy-image)
89
+ 3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
90
+ 4. 自定义渠道,支持填入完整调用地址
91
+ 5. Rerank模型([Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)),[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
92
+ 6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
93
+ 7. Dify,当前仅支持chatflow
94
+
95
+ ## 环境变量配置
96
+
97
+ 详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
98
+
99
+ - `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
100
+ - `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
101
+ - `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
102
+ - `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
103
+ - `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
104
+ - `GET_MEDIA_TOKEN_NOT_STREAM`:非流情况下是否统计图片token,默认 `true`
105
+ - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认 `true`
106
+ - `COHERE_SAFETY_SETTING`:Cohere模型安全设置,可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
107
+ - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16`
108
+ - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20`
109
+ - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
110
+ - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview`
111
+ - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
112
+ - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
113
+
114
+ ## 部署
115
+
116
+ 详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation):
117
+
118
+ > [!TIP]
119
+ > 最新版Docker镜像:`calciumion/new-api:latest`
120
+
121
+ ### 多机部署注意事项
122
+ - 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
123
+ - 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取
124
+
125
+ ### 部署要求
126
+ - 本地数据库(默认):SQLite(Docker部署必须挂载`/data`目录)
127
+ - 远程数据库:MySQL版本 >= 5.7.8,PgSQL版本 >= 9.6
128
+
129
+ ### 部署方式
130
+
131
+ #### 使用宝塔面板Docker功能部署
132
+ 安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
133
+ [图文教程](BT.md)
134
+
135
+ #### 使用Docker Compose部署(推荐)
136
+ ```shell
137
+ # 下载项目
138
+ git clone https://github.com/Calcium-Ion/new-api.git
139
+ cd new-api
140
+ # 按需编辑docker-compose.yml
141
+ # 启动
142
+ docker-compose up -d
143
+ ```
144
+
145
+ #### 直接使用Docker镜像
146
+ ```shell
147
+ # 使用SQLite
148
+ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
149
+
150
+ # 使用MySQL
151
+ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
152
+ ```
153
+
154
+ ## 渠道重试与缓存
155
+ 渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
156
+
157
+ ### 缓存设置方法
158
+ 1. `REDIS_CONN_STRING`:设置Redis作为缓存
159
+ 2. `MEMORY_CACHE_ENABLED`:启用内存缓存(设置了Redis则无需手动设置)
160
+
161
+ ## 接口文档
162
+
163
+ 详细接口文档请参考[接口文档](https://docs.newapi.pro/api):
164
+
165
+ - [聊天接口(Chat)](https://docs.newapi.pro/api/openai-chat)
166
+ - [图像接口(Image)](https://docs.newapi.pro/api/openai-image)
167
+ - [重排序接口(Rerank)](https://docs.newapi.pro/api/jinaai-rerank)
168
+ - [实时对话接口(Realtime)](https://docs.newapi.pro/api/openai-realtime)
169
+ - [Claude聊天接口(messages)](https://docs.newapi.pro/api/anthropic-chat)
170
+
171
+ ## 相关项目
172
+ - [One API](https://github.com/songquanpeng/one-api):原版项目
173
+ - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
174
+ - [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代AI一站式B/C端解决方案
175
+ - [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
176
+
177
+ 其他基于New API的项目:
178
+ - [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
179
+ - [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
180
+
181
+ ## 帮助支持
182
+
183
+ 如有问题,请参考[帮助支持](https://docs.newapi.pro/support):
184
+ - [社区交流](https://docs.newapi.pro/support/community-interaction)
185
+ - [反馈问题](https://docs.newapi.pro/support/feedback-issues)
186
+ - [常见问题](https://docs.newapi.pro/support/faq)
187
+
188
+ ## 🌟 Star History
189
+
190
+ [![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)
Rerank.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rerank API文档
2
+
3
+ **简介**:Rerank API文档
4
+
5
+ ## 接入Dify
6
+ 模型供应商选择Jina,按要求填写模型信息即可接入Dify。
7
+
8
+ ## 请求方式
9
+
10
+ Post: /v1/rerank
11
+
12
+ Request:
13
+
14
+ ```json
15
+ {
16
+ "model": "jina-reranker-v2-base-multilingual",
17
+ "query": "What is the capital of the United States?",
18
+ "top_n": 3,
19
+ "documents": [
20
+ "Carson City is the capital city of the American state of Nevada.",
21
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
22
+ "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
23
+ "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
24
+ "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
25
+ ]
26
+ }
27
+ ```
28
+
29
+ Response:
30
+
31
+ ```json
32
+ {
33
+ "results": [
34
+ {
35
+ "document": {
36
+ "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
37
+ },
38
+ "index": 2,
39
+ "relevance_score": 0.9999702
40
+ },
41
+ {
42
+ "document": {
43
+ "text": "Carson City is the capital city of the American state of Nevada."
44
+ },
45
+ "index": 0,
46
+ "relevance_score": 0.67800725
47
+ },
48
+ {
49
+ "document": {
50
+ "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
51
+ },
52
+ "index": 3,
53
+ "relevance_score": 0.02800752
54
+ }
55
+ ],
56
+ "usage": {
57
+ "prompt_tokens": 158,
58
+ "completion_tokens": 0,
59
+ "total_tokens": 158
60
+ }
61
+ }
62
+ ```
Suno.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Suno API文档
2
+
3
+ **简介**:Suno API文档
4
+
5
+ ## 接口列表
6
+ 支持的接口如下:
7
+ + [x] /suno/submit/music
8
+ + [x] /suno/submit/lyrics
9
+ + [x] /suno/fetch
10
+ + [x] /suno/fetch/:id
11
+
12
+ ## 模型列表
13
+
14
+ ### Suno API支持
15
+
16
+ - suno_music (自定义模式、灵感模式、续写)
17
+ - suno_lyrics (生成歌词)
18
+
19
+
20
+ ## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
21
+ ```json
22
+ {
23
+ "suno_music": 0.3,
24
+ "suno_lyrics": 0.01
25
+ }
26
+ ```
27
+
28
+ ## 渠道设置
29
+
30
+ ### 对接 Suno API
31
+
32
+ 1.
33
+ 部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
34
+
35
+ 2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
36
+ ,模型请参考上方模型列表
37
+ 3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
38
+ 4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
39
+
40
+ ### 对接上游new api
41
+
42
+ 1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
43
+ 2. **代理**填写上游new api的地址,例如:http://localhost:3000
44
+ 3. 密钥填写上游new api的密钥
VERSION ADDED
File without changes
bin/migration_v0.2-v0.3.sql ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ UPDATE users
2
+ SET quota = quota + (
3
+ SELECT SUM(remain_quota)
4
+ FROM tokens
5
+ WHERE tokens.user_id = users.id
6
+ )
bin/migration_v0.3-v0.4.sql ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ INSERT INTO abilities (`group`, model, channel_id, enabled)
2
+ SELECT c.`group`, m.model, c.id, 1
3
+ FROM channels c
4
+ CROSS JOIN (
5
+ SELECT 'gpt-3.5-turbo' AS model UNION ALL
6
+ SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
7
+ SELECT 'gpt-4' AS model UNION ALL
8
+ SELECT 'gpt-4-0314' AS model
9
+ ) AS m
10
+ WHERE c.status = 1
11
+ AND NOT EXISTS (
12
+ SELECT 1
13
+ FROM abilities a
14
+ WHERE a.`group` = c.`group`
15
+ AND a.model = m.model
16
+ AND a.channel_id = c.id
17
+ );
bin/time_test.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ if [ $# -lt 3 ]; then
4
+ echo "Usage: time_test.sh <domain> <key> <count> [<model>]"
5
+ exit 1
6
+ fi
7
+
8
+ domain=$1
9
+ key=$2
10
+ count=$3
11
+ model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo
12
+
13
+ total_time=0
14
+ times=()
15
+
16
+ for ((i=1; i<=count; i++)); do
17
+ result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
18
+ https://"$domain"/v1/chat/completions \
19
+ -H "Content-Type: application/json" \
20
+ -H "Authorization: Bearer $key" \
21
+ -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
22
+ http_code=$(echo "$result" | awk '{print $1}')
23
+ time=$(echo "$result" | awk '{print $2}')
24
+ echo "HTTP status code: $http_code, Time taken: $time"
25
+ total_time=$(bc <<< "$total_time + $time")
26
+ times+=("$time")
27
+ done
28
+
29
+ average_time=$(echo "scale=4; $total_time / $count" | bc)
30
+
31
+ sum_of_squares=0
32
+ for time in "${times[@]}"; do
33
+ difference=$(echo "scale=4; $time - $average_time" | bc)
34
+ square=$(echo "scale=4; $difference * $difference" | bc)
35
+ sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
36
+ done
37
+
38
+ standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)
39
+
40
+ echo "Average time: $average_time±$standard_deviation"
common/constants.go ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "os"
5
+ "strconv"
6
+ "sync"
7
+ "time"
8
+
9
+ "github.com/google/uuid"
10
+ )
11
+
12
+ var StartTime = time.Now().Unix() // unit: second
13
+ var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
14
+ var SystemName = "New API"
15
+ var Footer = ""
16
+ var Logo = ""
17
+ var TopUpLink = ""
18
+
19
+ // var ChatLink = ""
20
+ // var ChatLink2 = ""
21
+ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
22
+ var DisplayInCurrencyEnabled = true
23
+ var DisplayTokenStatEnabled = true
24
+ var DrawingEnabled = true
25
+ var TaskEnabled = true
26
+ var DataExportEnabled = true
27
+ var DataExportInterval = 5 // unit: minute
28
+ var DataExportDefaultTime = "hour" // unit: minute
29
+ var DefaultCollapseSidebar = false // default value of collapse sidebar
30
+
31
+ // Any options with "Secret", "Token" in its key won't be return by GetOptions
32
+
33
+ var SessionSecret = uuid.New().String()
34
+ var CryptoSecret = uuid.New().String()
35
+
36
+ var OptionMap map[string]string
37
+ var OptionMapRWMutex sync.RWMutex
38
+
39
+ var ItemsPerPage = 10
40
+ var MaxRecentItems = 100
41
+
42
+ var PasswordLoginEnabled = true
43
+ var PasswordRegisterEnabled = true
44
+ var EmailVerificationEnabled = false
45
+ var GitHubOAuthEnabled = false
46
+ var LinuxDOOAuthEnabled = false
47
+ var WeChatAuthEnabled = false
48
+ var TelegramOAuthEnabled = false
49
+ var TurnstileCheckEnabled = false
50
+ var RegisterEnabled = true
51
+
52
+ var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
53
+ var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
54
+ var EmailDomainWhitelist = []string{
55
+ "gmail.com",
56
+ "163.com",
57
+ "126.com",
58
+ "qq.com",
59
+ "outlook.com",
60
+ "hotmail.com",
61
+ "icloud.com",
62
+ "yahoo.com",
63
+ "foxmail.com",
64
+ }
65
+
66
+ var DebugEnabled = os.Getenv("DEBUG") == "true"
67
+ var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
68
+
69
+ var LogConsumeEnabled = true
70
+
71
+ var SMTPServer = ""
72
+ var SMTPPort = 587
73
+ var SMTPSSLEnabled = false
74
+ var SMTPAccount = ""
75
+ var SMTPFrom = ""
76
+ var SMTPToken = ""
77
+
78
+ var GitHubClientId = ""
79
+ var GitHubClientSecret = ""
80
+ var LinuxDOClientId = ""
81
+ var LinuxDOClientSecret = ""
82
+
83
+ var WeChatServerAddress = ""
84
+ var WeChatServerToken = ""
85
+ var WeChatAccountQRCodeImageURL = ""
86
+
87
+ var TurnstileSiteKey = ""
88
+ var TurnstileSecretKey = ""
89
+
90
+ var TelegramBotToken = ""
91
+ var TelegramBotName = ""
92
+
93
+ var QuotaForNewUser = 0
94
+ var QuotaForInviter = 0
95
+ var QuotaForInvitee = 0
96
+ var ChannelDisableThreshold = 5.0
97
+ var AutomaticDisableChannelEnabled = false
98
+ var AutomaticEnableChannelEnabled = false
99
+ var QuotaRemindThreshold = 1000
100
+ var PreConsumedQuota = 500
101
+
102
+ var RetryTimes = 0
103
+
104
+ //var RootUserEmail = ""
105
+
106
+ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
107
+
108
+ var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
109
+ var RequestInterval = time.Duration(requestInterval) * time.Second
110
+
111
+ var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
112
+
113
+ var BatchUpdateEnabled = false
114
+ var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
115
+
116
+ var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
117
+
118
+ var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
119
+
120
+ // https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
121
+ var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
122
+
123
+ const (
124
+ RequestIdKey = "X-Oneapi-Request-Id"
125
+ )
126
+
127
+ const (
128
+ RoleGuestUser = 0
129
+ RoleCommonUser = 1
130
+ RoleAdminUser = 10
131
+ RoleRootUser = 100
132
+ )
133
+
134
+ func IsValidateRole(role int) bool {
135
+ return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
136
+ }
137
+
138
+ var (
139
+ FileUploadPermission = RoleGuestUser
140
+ FileDownloadPermission = RoleGuestUser
141
+ ImageUploadPermission = RoleGuestUser
142
+ ImageDownloadPermission = RoleGuestUser
143
+ )
144
+
145
+ // All duration's unit is seconds
146
+ // Shouldn't larger then RateLimitKeyExpirationDuration
147
+ var (
148
+ GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
149
+ GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
150
+ GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
151
+
152
+ GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
153
+ GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
154
+ GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
155
+
156
+ UploadRateLimitNum = 10
157
+ UploadRateLimitDuration int64 = 60
158
+
159
+ DownloadRateLimitNum = 10
160
+ DownloadRateLimitDuration int64 = 60
161
+
162
+ CriticalRateLimitNum = 20
163
+ CriticalRateLimitDuration int64 = 20 * 60
164
+ )
165
+
166
+ var RateLimitKeyExpirationDuration = 20 * time.Minute
167
+
168
+ const (
169
+ UserStatusEnabled = 1 // don't use 0, 0 is the default value!
170
+ UserStatusDisabled = 2 // also don't use 0
171
+ )
172
+
173
+ const (
174
+ TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
175
+ TokenStatusDisabled = 2 // also don't use 0
176
+ TokenStatusExpired = 3
177
+ TokenStatusExhausted = 4
178
+ )
179
+
180
+ const (
181
+ RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
182
+ RedemptionCodeStatusDisabled = 2 // also don't use 0
183
+ RedemptionCodeStatusUsed = 3 // also don't use 0
184
+ )
185
+
186
+ const (
187
+ ChannelStatusUnknown = 0
188
+ ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
189
+ ChannelStatusManuallyDisabled = 2 // also don't use 0
190
+ ChannelStatusAutoDisabled = 3
191
+ )
192
+
193
+ const (
194
+ ChannelTypeUnknown = 0
195
+ ChannelTypeOpenAI = 1
196
+ ChannelTypeMidjourney = 2
197
+ ChannelTypeAzure = 3
198
+ ChannelTypeOllama = 4
199
+ ChannelTypeMidjourneyPlus = 5
200
+ ChannelTypeOpenAIMax = 6
201
+ ChannelTypeOhMyGPT = 7
202
+ ChannelTypeCustom = 8
203
+ ChannelTypeAILS = 9
204
+ ChannelTypeAIProxy = 10
205
+ ChannelTypePaLM = 11
206
+ ChannelTypeAPI2GPT = 12
207
+ ChannelTypeAIGC2D = 13
208
+ ChannelTypeAnthropic = 14
209
+ ChannelTypeBaidu = 15
210
+ ChannelTypeZhipu = 16
211
+ ChannelTypeAli = 17
212
+ ChannelTypeXunfei = 18
213
+ ChannelType360 = 19
214
+ ChannelTypeOpenRouter = 20
215
+ ChannelTypeAIProxyLibrary = 21
216
+ ChannelTypeFastGPT = 22
217
+ ChannelTypeTencent = 23
218
+ ChannelTypeGemini = 24
219
+ ChannelTypeMoonshot = 25
220
+ ChannelTypeZhipu_v4 = 26
221
+ ChannelTypePerplexity = 27
222
+ ChannelTypeLingYiWanWu = 31
223
+ ChannelTypeAws = 33
224
+ ChannelTypeCohere = 34
225
+ ChannelTypeMiniMax = 35
226
+ ChannelTypeSunoAPI = 36
227
+ ChannelTypeDify = 37
228
+ ChannelTypeJina = 38
229
+ ChannelCloudflare = 39
230
+ ChannelTypeSiliconFlow = 40
231
+ ChannelTypeVertexAi = 41
232
+ ChannelTypeMistral = 42
233
+ ChannelTypeDeepSeek = 43
234
+ ChannelTypeMokaAI = 44
235
+ ChannelTypeVolcEngine = 45
236
+ ChannelTypeBaiduV2 = 46
237
+ ChannelTypeXinference = 47
238
+ ChannelTypeDummy // this one is only for count, do not add any channel after this
239
+
240
+ )
241
+
242
+ var ChannelBaseURLs = []string{
243
+ "", // 0
244
+ "https://api.openai.com", // 1
245
+ "https://oa.api2d.net", // 2
246
+ "", // 3
247
+ "http://localhost:11434", // 4
248
+ "https://api.openai-sb.com", // 5
249
+ "https://api.openaimax.com", // 6
250
+ "https://api.ohmygpt.com", // 7
251
+ "", // 8
252
+ "https://api.caipacity.com", // 9
253
+ "https://api.aiproxy.io", // 10
254
+ "", // 11
255
+ "https://api.api2gpt.com", // 12
256
+ "https://api.aigc2d.com", // 13
257
+ "https://api.anthropic.com", // 14
258
+ "https://aip.baidubce.com", // 15
259
+ "https://open.bigmodel.cn", // 16
260
+ "https://dashscope.aliyuncs.com", // 17
261
+ "", // 18
262
+ "https://api.360.cn", // 19
263
+ "https://openrouter.ai/api", // 20
264
+ "https://api.aiproxy.io", // 21
265
+ "https://fastgpt.run/api/openapi", // 22
266
+ "https://hunyuan.tencentcloudapi.com", //23
267
+ "https://generativelanguage.googleapis.com", //24
268
+ "https://api.moonshot.cn", //25
269
+ "https://open.bigmodel.cn", //26
270
+ "https://api.perplexity.ai", //27
271
+ "", //28
272
+ "", //29
273
+ "", //30
274
+ "https://api.lingyiwanwu.com", //31
275
+ "", //32
276
+ "", //33
277
+ "https://api.cohere.ai", //34
278
+ "https://api.minimax.chat", //35
279
+ "", //36
280
+ "https://api.dify.ai", //37
281
+ "https://api.jina.ai", //38
282
+ "https://api.cloudflare.com", //39
283
+ "https://api.siliconflow.cn", //40
284
+ "", //41
285
+ "https://api.mistral.ai", //42
286
+ "https://api.deepseek.com", //43
287
+ "https://api.moka.ai", //44
288
+ "https://ark.cn-beijing.volces.com", //45
289
+ "https://qianfan.baidubce.com", //46
290
+ "", //47
291
+ }
common/crypto.go ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "crypto/hmac"
5
+ "crypto/sha256"
6
+ "encoding/hex"
7
+ "golang.org/x/crypto/bcrypt"
8
+ )
9
+
10
+ func GenerateHMACWithKey(key []byte, data string) string {
11
+ h := hmac.New(sha256.New, key)
12
+ h.Write([]byte(data))
13
+ return hex.EncodeToString(h.Sum(nil))
14
+ }
15
+
16
+ func GenerateHMAC(data string) string {
17
+ h := hmac.New(sha256.New, []byte(CryptoSecret))
18
+ h.Write([]byte(data))
19
+ return hex.EncodeToString(h.Sum(nil))
20
+ }
21
+
22
+ func Password2Hash(password string) (string, error) {
23
+ passwordBytes := []byte(password)
24
+ hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
25
+ return string(hashedPassword), err
26
+ }
27
+
28
+ func ValidatePasswordAndHash(password string, hash string) bool {
29
+ err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
30
+ return err == nil
31
+ }
common/custom-event.go ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
2
+ // Use of this source code is governed by a MIT style
3
+ // license that can be found in the LICENSE file.
4
+
5
+ package common
6
+
7
+ import (
8
+ "fmt"
9
+ "io"
10
+ "net/http"
11
+ "strings"
12
+ )
13
+
14
+ type stringWriter interface {
15
+ io.Writer
16
+ writeString(string) (int, error)
17
+ }
18
+
19
+ type stringWrapper struct {
20
+ io.Writer
21
+ }
22
+
23
+ func (w stringWrapper) writeString(str string) (int, error) {
24
+ return w.Writer.Write([]byte(str))
25
+ }
26
+
27
+ func checkWriter(writer io.Writer) stringWriter {
28
+ if w, ok := writer.(stringWriter); ok {
29
+ return w
30
+ } else {
31
+ return stringWrapper{writer}
32
+ }
33
+ }
34
+
35
+ // Server-Sent Events
36
+ // W3C Working Draft 29 October 2009
37
+ // http://www.w3.org/TR/2009/WD-eventsource-20091029/
38
+
39
+ var contentType = []string{"text/event-stream"}
40
+ var noCache = []string{"no-cache"}
41
+
42
+ var fieldReplacer = strings.NewReplacer(
43
+ "\n", "\\n",
44
+ "\r", "\\r")
45
+
46
+ var dataReplacer = strings.NewReplacer(
47
+ "\n", "\n",
48
+ "\r", "\\r")
49
+
50
+ type CustomEvent struct {
51
+ Event string
52
+ Id string
53
+ Retry uint
54
+ Data interface{}
55
+ }
56
+
57
+ func encode(writer io.Writer, event CustomEvent) error {
58
+ w := checkWriter(writer)
59
+ return writeData(w, event.Data)
60
+ }
61
+
62
+ func writeData(w stringWriter, data interface{}) error {
63
+ dataReplacer.WriteString(w, fmt.Sprint(data))
64
+ if strings.HasPrefix(data.(string), "data") {
65
+ w.writeString("\n\n")
66
+ }
67
+ return nil
68
+ }
69
+
70
+ func (r CustomEvent) Render(w http.ResponseWriter) error {
71
+ r.WriteContentType(w)
72
+ return encode(w, r)
73
+ }
74
+
75
+ func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
76
+ header := w.Header()
77
+ header["Content-Type"] = contentType
78
+
79
+ if _, exist := header["Cache-Control"]; !exist {
80
+ header["Cache-Control"] = noCache
81
+ }
82
+ }
common/database.go ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ var UsingSQLite = false
4
+ var UsingPostgreSQL = false
5
+ var UsingMySQL = false
6
+ var UsingClickHouse = false
7
+
8
+ var SQLitePath = "one-api.db?_busy_timeout=5000"
common/email-outlook-auth.go ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "errors"
5
+ "net/smtp"
6
+ "strings"
7
+ )
8
+
9
+ type outlookAuth struct {
10
+ username, password string
11
+ }
12
+
13
+ func LoginAuth(username, password string) smtp.Auth {
14
+ return &outlookAuth{username, password}
15
+ }
16
+
17
+ func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
18
+ return "LOGIN", []byte{}, nil
19
+ }
20
+
21
+ func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
22
+ if more {
23
+ switch string(fromServer) {
24
+ case "Username:":
25
+ return []byte(a.username), nil
26
+ case "Password:":
27
+ return []byte(a.password), nil
28
+ default:
29
+ return nil, errors.New("unknown fromServer")
30
+ }
31
+ }
32
+ return nil, nil
33
+ }
34
+
35
+ func isOutlookServer(server string) bool {
36
+ // 兼容多地区的outlook邮箱和ofb邮箱
37
+ // 其实应该加一个Option来区分是否用LOGIN的方式登录
38
+ // 先临时兼容一下
39
+ return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
40
+ }
common/email.go ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "crypto/tls"
5
+ "encoding/base64"
6
+ "fmt"
7
+ "net/smtp"
8
+ "strings"
9
+ "time"
10
+ )
11
+
12
+ func generateMessageID() (string, error) {
13
+ split := strings.Split(SMTPFrom, "@")
14
+ if len(split) < 2 {
15
+ return "", fmt.Errorf("invalid SMTP account")
16
+ }
17
+ domain := strings.Split(SMTPFrom, "@")[1]
18
+ return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
19
+ }
20
+
21
+ func SendEmail(subject string, receiver string, content string) error {
22
+ if SMTPFrom == "" { // for compatibility
23
+ SMTPFrom = SMTPAccount
24
+ }
25
+ id, err2 := generateMessageID()
26
+ if err2 != nil {
27
+ return err2
28
+ }
29
+ if SMTPServer == "" && SMTPAccount == "" {
30
+ return fmt.Errorf("SMTP 服务器未配置")
31
+ }
32
+ encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
33
+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
34
+ "From: %s<%s>\r\n"+
35
+ "Subject: %s\r\n"+
36
+ "Date: %s\r\n"+
37
+ "Message-ID: %s\r\n"+ // 添加 Message-ID 头
38
+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
39
+ receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
40
+ auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
41
+ addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
42
+ to := strings.Split(receiver, ";")
43
+ var err error
44
+ if SMTPPort == 465 || SMTPSSLEnabled {
45
+ tlsConfig := &tls.Config{
46
+ InsecureSkipVerify: true,
47
+ ServerName: SMTPServer,
48
+ }
49
+ conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
50
+ if err != nil {
51
+ return err
52
+ }
53
+ client, err := smtp.NewClient(conn, SMTPServer)
54
+ if err != nil {
55
+ return err
56
+ }
57
+ defer client.Close()
58
+ if err = client.Auth(auth); err != nil {
59
+ return err
60
+ }
61
+ if err = client.Mail(SMTPFrom); err != nil {
62
+ return err
63
+ }
64
+ receiverEmails := strings.Split(receiver, ";")
65
+ for _, receiver := range receiverEmails {
66
+ if err = client.Rcpt(receiver); err != nil {
67
+ return err
68
+ }
69
+ }
70
+ w, err := client.Data()
71
+ if err != nil {
72
+ return err
73
+ }
74
+ _, err = w.Write(mail)
75
+ if err != nil {
76
+ return err
77
+ }
78
+ err = w.Close()
79
+ if err != nil {
80
+ return err
81
+ }
82
+ } else if isOutlookServer(SMTPAccount) || SMTPServer == "smtp.azurecomm.net" {
83
+ auth = LoginAuth(SMTPAccount, SMTPToken)
84
+ err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
85
+ } else {
86
+ err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
87
+ }
88
+ return err
89
+ }
common/embed-file-system.go ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "embed"
5
+ "github.com/gin-contrib/static"
6
+ "io/fs"
7
+ "net/http"
8
+ )
9
+
10
+ // Credit: https://github.com/gin-contrib/static/issues/19
11
+
12
+ type embedFileSystem struct {
13
+ http.FileSystem
14
+ }
15
+
16
+ func (e embedFileSystem) Exists(prefix string, path string) bool {
17
+ _, err := e.Open(path)
18
+ if err != nil {
19
+ return false
20
+ }
21
+ return true
22
+ }
23
+
24
+ func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
25
+ efs, err := fs.Sub(fsEmbed, targetPath)
26
+ if err != nil {
27
+ panic(err)
28
+ }
29
+ return embedFileSystem{
30
+ FileSystem: http.FS(efs),
31
+ }
32
+ }
common/env.go ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "fmt"
5
+ "os"
6
+ "strconv"
7
+ )
8
+
9
+ func GetEnvOrDefault(env string, defaultValue int) int {
10
+ if env == "" || os.Getenv(env) == "" {
11
+ return defaultValue
12
+ }
13
+ num, err := strconv.Atoi(os.Getenv(env))
14
+ if err != nil {
15
+ SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
16
+ return defaultValue
17
+ }
18
+ return num
19
+ }
20
+
21
+ func GetEnvOrDefaultString(env string, defaultValue string) string {
22
+ if env == "" || os.Getenv(env) == "" {
23
+ return defaultValue
24
+ }
25
+ return os.Getenv(env)
26
+ }
27
+
28
+ func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
29
+ if env == "" || os.Getenv(env) == "" {
30
+ return defaultValue
31
+ }
32
+ b, err := strconv.ParseBool(os.Getenv(env))
33
+ if err != nil {
34
+ SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
35
+ return defaultValue
36
+ }
37
+ return b
38
+ }
common/gin.go ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ "github.com/gin-gonic/gin"
7
+ "io"
8
+ "strings"
9
+ )
10
+
11
+ const KeyRequestBody = "key_request_body"
12
+
13
+ func GetRequestBody(c *gin.Context) ([]byte, error) {
14
+ requestBody, _ := c.Get(KeyRequestBody)
15
+ if requestBody != nil {
16
+ return requestBody.([]byte), nil
17
+ }
18
+ requestBody, err := io.ReadAll(c.Request.Body)
19
+ if err != nil {
20
+ return nil, err
21
+ }
22
+ _ = c.Request.Body.Close()
23
+ c.Set(KeyRequestBody, requestBody)
24
+ return requestBody.([]byte), nil
25
+ }
26
+
27
+ func UnmarshalBodyReusable(c *gin.Context, v any) error {
28
+ requestBody, err := GetRequestBody(c)
29
+ if err != nil {
30
+ return err
31
+ }
32
+ contentType := c.Request.Header.Get("Content-Type")
33
+ if strings.HasPrefix(contentType, "application/json") {
34
+ err = json.Unmarshal(requestBody, &v)
35
+ } else {
36
+ // skip for now
37
+ // TODO: someday non json request have variant model, we will need to implementation this
38
+ }
39
+ if err != nil {
40
+ return err
41
+ }
42
+ // Reset request body
43
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
44
+ return nil
45
+ }
common/go-channel.go ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "time"
5
+ )
6
+
7
+ func SafeSendBool(ch chan bool, value bool) (closed bool) {
8
+ defer func() {
9
+ // Recover from panic if one occured. A panic would mean the channel was closed.
10
+ if recover() != nil {
11
+ closed = true
12
+ }
13
+ }()
14
+
15
+ // This will panic if the channel is closed.
16
+ ch <- value
17
+
18
+ // If the code reaches here, then the channel was not closed.
19
+ return false
20
+ }
21
+
22
+ func SafeSendString(ch chan string, value string) (closed bool) {
23
+ defer func() {
24
+ // Recover from panic if one occured. A panic would mean the channel was closed.
25
+ if recover() != nil {
26
+ closed = true
27
+ }
28
+ }()
29
+
30
+ // This will panic if the channel is closed.
31
+ ch <- value
32
+
33
+ // If the code reaches here, then the channel was not closed.
34
+ return false
35
+ }
36
+
37
+ // SafeSendStringTimeout send, return true, else return false
38
+ func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
39
+ defer func() {
40
+ // Recover from panic if one occured. A panic would mean the channel was closed.
41
+ if recover() != nil {
42
+ closed = false
43
+ }
44
+ }()
45
+
46
+ // This will panic if the channel is closed.
47
+ select {
48
+ case ch <- value:
49
+ return true
50
+ case <-time.After(time.Duration(timeout) * time.Second):
51
+ return false
52
+ }
53
+ }
common/gopool.go ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "context"
5
+ "fmt"
6
+ "github.com/bytedance/gopkg/util/gopool"
7
+ "math"
8
+ )
9
+
10
+ var relayGoPool gopool.Pool
11
+
12
+ func init() {
13
+ relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
14
+ relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
15
+ if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
16
+ SafeSendBool(stopChan, true)
17
+ }
18
+ SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
19
+ })
20
+ }
21
+
22
+ func RelayCtxGo(ctx context.Context, f func()) {
23
+ relayGoPool.CtxGo(ctx, f)
24
+ }
common/init.go ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "flag"
5
+ "fmt"
6
+ "log"
7
+ "os"
8
+ "path/filepath"
9
+ )
10
+
11
+ var (
12
+ Port = flag.Int("port", 3000, "the listening port")
13
+ PrintVersion = flag.Bool("version", false, "print version and exit")
14
+ PrintHelp = flag.Bool("help", false, "print help and exit")
15
+ LogDir = flag.String("log-dir", "./logs", "specify the log directory")
16
+ )
17
+
18
+ func printHelp() {
19
+ fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
20
+ fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
21
+ fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
22
+ fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
23
+ }
24
+
25
+ func LoadEnv() {
26
+ flag.Parse()
27
+
28
+ if *PrintVersion {
29
+ fmt.Println(Version)
30
+ os.Exit(0)
31
+ }
32
+
33
+ if *PrintHelp {
34
+ printHelp()
35
+ os.Exit(0)
36
+ }
37
+
38
+ if os.Getenv("SESSION_SECRET") != "" {
39
+ ss := os.Getenv("SESSION_SECRET")
40
+ if ss == "random_string" {
41
+ log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
42
+ log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
43
+ log.Fatal("Please set SESSION_SECRET to a random string.")
44
+ } else {
45
+ SessionSecret = ss
46
+ }
47
+ }
48
+ if os.Getenv("CRYPTO_SECRET") != "" {
49
+ CryptoSecret = os.Getenv("CRYPTO_SECRET")
50
+ } else {
51
+ CryptoSecret = SessionSecret
52
+ }
53
+ if os.Getenv("SQLITE_PATH") != "" {
54
+ SQLitePath = os.Getenv("SQLITE_PATH")
55
+ }
56
+ if *LogDir != "" {
57
+ var err error
58
+ *LogDir, err = filepath.Abs(*LogDir)
59
+ if err != nil {
60
+ log.Fatal(err)
61
+ }
62
+ if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
63
+ err = os.Mkdir(*LogDir, 0777)
64
+ if err != nil {
65
+ log.Fatal(err)
66
+ }
67
+ }
68
+ }
69
+ }
common/json.go ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ )
7
+
8
+ func DecodeJson(data []byte, v any) error {
9
+ return json.NewDecoder(bytes.NewReader(data)).Decode(v)
10
+ }
11
+
12
+ func DecodeJsonStr(data string, v any) error {
13
+ return DecodeJson(StringToByteSlice(data), v)
14
+ }
common/logger.go ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "fmt"
7
+ "github.com/bytedance/gopkg/util/gopool"
8
+ "github.com/gin-gonic/gin"
9
+ "io"
10
+ "log"
11
+ "os"
12
+ "path/filepath"
13
+ "sync"
14
+ "time"
15
+ )
16
+
17
+ const (
18
+ loggerINFO = "INFO"
19
+ loggerWarn = "WARN"
20
+ loggerError = "ERR"
21
+ )
22
+
23
+ const maxLogCount = 1000000
24
+
25
+ var logCount int
26
+ var setupLogLock sync.Mutex
27
+ var setupLogWorking bool
28
+
29
+ func SetupLogger() {
30
+ if *LogDir != "" {
31
+ ok := setupLogLock.TryLock()
32
+ if !ok {
33
+ log.Println("setup log is already working")
34
+ return
35
+ }
36
+ defer func() {
37
+ setupLogLock.Unlock()
38
+ setupLogWorking = false
39
+ }()
40
+ logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
41
+ fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
42
+ if err != nil {
43
+ log.Fatal("failed to open log file")
44
+ }
45
+ gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
46
+ gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
47
+ }
48
+ }
49
+
50
+ func SysLog(s string) {
51
+ t := time.Now()
52
+ _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
53
+ }
54
+
55
+ func SysError(s string) {
56
+ t := time.Now()
57
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
58
+ }
59
+
60
+ func LogInfo(ctx context.Context, msg string) {
61
+ logHelper(ctx, loggerINFO, msg)
62
+ }
63
+
64
+ func LogWarn(ctx context.Context, msg string) {
65
+ logHelper(ctx, loggerWarn, msg)
66
+ }
67
+
68
+ func LogError(ctx context.Context, msg string) {
69
+ logHelper(ctx, loggerError, msg)
70
+ }
71
+
72
+ func logHelper(ctx context.Context, level string, msg string) {
73
+ writer := gin.DefaultErrorWriter
74
+ if level == loggerINFO {
75
+ writer = gin.DefaultWriter
76
+ }
77
+ id := ctx.Value(RequestIdKey)
78
+ now := time.Now()
79
+ _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
80
+ logCount++ // we don't need accurate count, so no lock here
81
+ if logCount > maxLogCount && !setupLogWorking {
82
+ logCount = 0
83
+ setupLogWorking = true
84
+ gopool.Go(func() {
85
+ SetupLogger()
86
+ })
87
+ }
88
+ }
89
+
90
+ func FatalLog(v ...any) {
91
+ t := time.Now()
92
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
93
+ os.Exit(1)
94
+ }
95
+
96
+ func LogQuota(quota int) string {
97
+ if DisplayInCurrencyEnabled {
98
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
99
+ } else {
100
+ return fmt.Sprintf("%d 点额度", quota)
101
+ }
102
+ }
103
+
104
+ func FormatQuota(quota int) string {
105
+ if DisplayInCurrencyEnabled {
106
+ return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
107
+ } else {
108
+ return fmt.Sprintf("%d", quota)
109
+ }
110
+ }
111
+
112
+ // LogJson 仅供测试使用 only for test
113
+ func LogJson(ctx context.Context, msg string, obj any) {
114
+ jsonStr, err := json.Marshal(obj)
115
+ if err != nil {
116
+ LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
117
+ return
118
+ }
119
+ LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
120
+ }
common/pprof.go ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "fmt"
5
+ "github.com/shirou/gopsutil/cpu"
6
+ "os"
7
+ "runtime/pprof"
8
+ "time"
9
+ )
10
+
11
+ // Monitor 定时监控cpu使用率,超过阈值输出pprof文件
12
+ func Monitor() {
13
+ for {
14
+ percent, err := cpu.Percent(time.Second, false)
15
+ if err != nil {
16
+ panic(err)
17
+ }
18
+ if percent[0] > 80 {
19
+ fmt.Println("cpu usage too high")
20
+ // write pprof file
21
+ if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
22
+ err := os.Mkdir("./pprof", os.ModePerm)
23
+ if err != nil {
24
+ SysLog("创建pprof文件夹失败 " + err.Error())
25
+ continue
26
+ }
27
+ }
28
+ f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
29
+ if err != nil {
30
+ SysLog("创建pprof文件失败 " + err.Error())
31
+ continue
32
+ }
33
+ err = pprof.StartCPUProfile(f)
34
+ if err != nil {
35
+ SysLog("启动pprof失败 " + err.Error())
36
+ continue
37
+ }
38
+ time.Sleep(10 * time.Second) // profile for 30 seconds
39
+ pprof.StopCPUProfile()
40
+ f.Close()
41
+ }
42
+ time.Sleep(30 * time.Second)
43
+ }
44
+ }
common/rate-limit.go ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "sync"
5
+ "time"
6
+ )
7
+
8
+ type InMemoryRateLimiter struct {
9
+ store map[string]*[]int64
10
+ mutex sync.Mutex
11
+ expirationDuration time.Duration
12
+ }
13
+
14
+ func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
15
+ if l.store == nil {
16
+ l.mutex.Lock()
17
+ if l.store == nil {
18
+ l.store = make(map[string]*[]int64)
19
+ l.expirationDuration = expirationDuration
20
+ if expirationDuration > 0 {
21
+ go l.clearExpiredItems()
22
+ }
23
+ }
24
+ l.mutex.Unlock()
25
+ }
26
+ }
27
+
28
+ func (l *InMemoryRateLimiter) clearExpiredItems() {
29
+ for {
30
+ time.Sleep(l.expirationDuration)
31
+ l.mutex.Lock()
32
+ now := time.Now().Unix()
33
+ for key := range l.store {
34
+ queue := l.store[key]
35
+ size := len(*queue)
36
+ if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
37
+ delete(l.store, key)
38
+ }
39
+ }
40
+ l.mutex.Unlock()
41
+ }
42
+ }
43
+
44
+ // Request parameter duration's unit is seconds
45
+ func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
46
+ l.mutex.Lock()
47
+ defer l.mutex.Unlock()
48
+ // [old <-- new]
49
+ queue, ok := l.store[key]
50
+ now := time.Now().Unix()
51
+ if ok {
52
+ if len(*queue) < maxRequestNum {
53
+ *queue = append(*queue, now)
54
+ return true
55
+ } else {
56
+ if now-(*queue)[0] >= duration {
57
+ *queue = (*queue)[1:]
58
+ *queue = append(*queue, now)
59
+ return true
60
+ } else {
61
+ return false
62
+ }
63
+ }
64
+ } else {
65
+ s := make([]int64, 0, maxRequestNum)
66
+ l.store[key] = &s
67
+ *(l.store[key]) = append(*(l.store[key]), now)
68
+ }
69
+ return true
70
+ }
common/redis.go ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "context"
5
+ "errors"
6
+ "fmt"
7
+ "os"
8
+ "reflect"
9
+ "strconv"
10
+ "time"
11
+
12
+ "github.com/go-redis/redis/v8"
13
+ "gorm.io/gorm"
14
+ )
15
+
16
+ var RDB *redis.Client
17
+ var RedisEnabled = true
18
+
19
+ // InitRedisClient This function is called after init()
20
+ func InitRedisClient() (err error) {
21
+ if os.Getenv("REDIS_CONN_STRING") == "" {
22
+ RedisEnabled = false
23
+ SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
24
+ return nil
25
+ }
26
+ if os.Getenv("SYNC_FREQUENCY") == "" {
27
+ SysLog("SYNC_FREQUENCY not set, use default value 60")
28
+ SyncFrequency = 60
29
+ }
30
+ SysLog("Redis is enabled")
31
+ opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
32
+ if err != nil {
33
+ FatalLog("failed to parse Redis connection string: " + err.Error())
34
+ }
35
+ opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
36
+ RDB = redis.NewClient(opt)
37
+
38
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
39
+ defer cancel()
40
+
41
+ _, err = RDB.Ping(ctx).Result()
42
+ if err != nil {
43
+ FatalLog("Redis ping test failed: " + err.Error())
44
+ }
45
+ if DebugEnabled {
46
+ SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
47
+ SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
48
+ }
49
+ return err
50
+ }
51
+
52
+ func ParseRedisOption() *redis.Options {
53
+ opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
54
+ if err != nil {
55
+ FatalLog("failed to parse Redis connection string: " + err.Error())
56
+ }
57
+ return opt
58
+ }
59
+
60
+ func RedisSet(key string, value string, expiration time.Duration) error {
61
+ if DebugEnabled {
62
+ SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
63
+ }
64
+ ctx := context.Background()
65
+ return RDB.Set(ctx, key, value, expiration).Err()
66
+ }
67
+
68
+ func RedisGet(key string) (string, error) {
69
+ if DebugEnabled {
70
+ SysLog(fmt.Sprintf("Redis GET: key=%s", key))
71
+ }
72
+ ctx := context.Background()
73
+ val, err := RDB.Get(ctx, key).Result()
74
+ return val, err
75
+ }
76
+
77
+ //func RedisExpire(key string, expiration time.Duration) error {
78
+ // ctx := context.Background()
79
+ // return RDB.Expire(ctx, key, expiration).Err()
80
+ //}
81
+ //
82
+ //func RedisGetEx(key string, expiration time.Duration) (string, error) {
83
+ // ctx := context.Background()
84
+ // return RDB.GetSet(ctx, key, expiration).Result()
85
+ //}
86
+
87
+ func RedisDel(key string) error {
88
+ if DebugEnabled {
89
+ SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
90
+ }
91
+ ctx := context.Background()
92
+ return RDB.Del(ctx, key).Err()
93
+ }
94
+
95
+ func RedisHDelObj(key string) error {
96
+ if DebugEnabled {
97
+ SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
98
+ }
99
+ ctx := context.Background()
100
+ return RDB.HDel(ctx, key).Err()
101
+ }
102
+
103
+ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
104
+ if DebugEnabled {
105
+ SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
106
+ }
107
+ ctx := context.Background()
108
+
109
+ data := make(map[string]interface{})
110
+
111
+ // 使用反射遍历结构体字段
112
+ v := reflect.ValueOf(obj).Elem()
113
+ t := v.Type()
114
+ for i := 0; i < v.NumField(); i++ {
115
+ field := t.Field(i)
116
+ value := v.Field(i)
117
+
118
+ // Skip DeletedAt field
119
+ if field.Type.String() == "gorm.DeletedAt" {
120
+ continue
121
+ }
122
+
123
+ // 处理指针类型
124
+ if value.Kind() == reflect.Ptr {
125
+ if value.IsNil() {
126
+ data[field.Name] = ""
127
+ continue
128
+ }
129
+ value = value.Elem()
130
+ }
131
+
132
+ // 处理布尔类型
133
+ if value.Kind() == reflect.Bool {
134
+ data[field.Name] = strconv.FormatBool(value.Bool())
135
+ continue
136
+ }
137
+
138
+ // 其他类型直接转换为字符串
139
+ data[field.Name] = fmt.Sprintf("%v", value.Interface())
140
+ }
141
+
142
+ txn := RDB.TxPipeline()
143
+ txn.HSet(ctx, key, data)
144
+ txn.Expire(ctx, key, expiration)
145
+
146
+ _, err := txn.Exec(ctx)
147
+ if err != nil {
148
+ return fmt.Errorf("failed to execute transaction: %w", err)
149
+ }
150
+ return nil
151
+ }
152
+
153
+ func RedisHGetObj(key string, obj interface{}) error {
154
+ if DebugEnabled {
155
+ SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
156
+ }
157
+ ctx := context.Background()
158
+
159
+ result, err := RDB.HGetAll(ctx, key).Result()
160
+ if err != nil {
161
+ return fmt.Errorf("failed to load hash from Redis: %w", err)
162
+ }
163
+
164
+ if len(result) == 0 {
165
+ return fmt.Errorf("key %s not found in Redis", key)
166
+ }
167
+
168
+ // Handle both pointer and non-pointer values
169
+ val := reflect.ValueOf(obj)
170
+ if val.Kind() != reflect.Ptr {
171
+ return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
172
+ }
173
+
174
+ v := val.Elem()
175
+ if v.Kind() != reflect.Struct {
176
+ return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
177
+ }
178
+
179
+ t := v.Type()
180
+ for i := 0; i < v.NumField(); i++ {
181
+ field := t.Field(i)
182
+ fieldName := field.Name
183
+ if value, ok := result[fieldName]; ok {
184
+ fieldValue := v.Field(i)
185
+
186
+ // Handle pointer types
187
+ if fieldValue.Kind() == reflect.Ptr {
188
+ if value == "" {
189
+ continue
190
+ }
191
+ if fieldValue.IsNil() {
192
+ fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
193
+ }
194
+ fieldValue = fieldValue.Elem()
195
+ }
196
+
197
+ // Enhanced type handling for Token struct
198
+ switch fieldValue.Kind() {
199
+ case reflect.String:
200
+ fieldValue.SetString(value)
201
+ case reflect.Int, reflect.Int64:
202
+ intValue, err := strconv.ParseInt(value, 10, 64)
203
+ if err != nil {
204
+ return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
205
+ }
206
+ fieldValue.SetInt(intValue)
207
+ case reflect.Bool:
208
+ boolValue, err := strconv.ParseBool(value)
209
+ if err != nil {
210
+ return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
211
+ }
212
+ fieldValue.SetBool(boolValue)
213
+ case reflect.Struct:
214
+ // Special handling for gorm.DeletedAt
215
+ if fieldValue.Type().String() == "gorm.DeletedAt" {
216
+ if value != "" {
217
+ timeValue, err := time.Parse(time.RFC3339, value)
218
+ if err != nil {
219
+ return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
220
+ }
221
+ fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
222
+ }
223
+ }
224
+ default:
225
+ return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
226
+ }
227
+ }
228
+ }
229
+
230
+ return nil
231
+ }
232
+
233
+ // RedisIncr Add this function to handle atomic increments
234
+ func RedisIncr(key string, delta int64) error {
235
+ if DebugEnabled {
236
+ SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
237
+ }
238
+ // 检查键的剩余生存时间
239
+ ttlCmd := RDB.TTL(context.Background(), key)
240
+ ttl, err := ttlCmd.Result()
241
+ if err != nil && !errors.Is(err, redis.Nil) {
242
+ return fmt.Errorf("failed to get TTL: %w", err)
243
+ }
244
+
245
+ // 只有在 key 存在且有 TTL 时才需要特殊处理
246
+ if ttl > 0 {
247
+ ctx := context.Background()
248
+ // 开始一个Redis事务
249
+ txn := RDB.TxPipeline()
250
+
251
+ // 减少余额
252
+ decrCmd := txn.IncrBy(ctx, key, delta)
253
+ if err := decrCmd.Err(); err != nil {
254
+ return err // 如果减少失败,则直接返回错误
255
+ }
256
+
257
+ // 重新设置过期时间,使用原来的过期时间
258
+ txn.Expire(ctx, key, ttl)
259
+
260
+ // 执行事务
261
+ _, err = txn.Exec(ctx)
262
+ return err
263
+ }
264
+ return nil
265
+ }
266
+
267
+ func RedisHIncrBy(key, field string, delta int64) error {
268
+ if DebugEnabled {
269
+ SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
270
+ }
271
+ ttlCmd := RDB.TTL(context.Background(), key)
272
+ ttl, err := ttlCmd.Result()
273
+ if err != nil && !errors.Is(err, redis.Nil) {
274
+ return fmt.Errorf("failed to get TTL: %w", err)
275
+ }
276
+
277
+ if ttl > 0 {
278
+ ctx := context.Background()
279
+ txn := RDB.TxPipeline()
280
+
281
+ incrCmd := txn.HIncrBy(ctx, key, field, delta)
282
+ if err := incrCmd.Err(); err != nil {
283
+ return err
284
+ }
285
+
286
+ txn.Expire(ctx, key, ttl)
287
+
288
+ _, err = txn.Exec(ctx)
289
+ return err
290
+ }
291
+ return nil
292
+ }
293
+
294
+ func RedisHSetField(key, field string, value interface{}) error {
295
+ if DebugEnabled {
296
+ SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
297
+ }
298
+ ttlCmd := RDB.TTL(context.Background(), key)
299
+ ttl, err := ttlCmd.Result()
300
+ if err != nil && !errors.Is(err, redis.Nil) {
301
+ return fmt.Errorf("failed to get TTL: %w", err)
302
+ }
303
+
304
+ if ttl > 0 {
305
+ ctx := context.Background()
306
+ txn := RDB.TxPipeline()
307
+
308
+ hsetCmd := txn.HSet(ctx, key, field, value)
309
+ if err := hsetCmd.Err(); err != nil {
310
+ return err
311
+ }
312
+
313
+ txn.Expire(ctx, key, ttl)
314
+
315
+ _, err = txn.Exec(ctx)
316
+ return err
317
+ }
318
+ return nil
319
+ }
common/str.go ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "encoding/json"
5
+ "math/rand"
6
+ "strconv"
7
+ "unsafe"
8
+ )
9
+
10
+ func GetStringIfEmpty(str string, defaultValue string) string {
11
+ if str == "" {
12
+ return defaultValue
13
+ }
14
+ return str
15
+ }
16
+
17
+ func GetRandomString(length int) string {
18
+ //rand.Seed(time.Now().UnixNano())
19
+ key := make([]byte, length)
20
+ for i := 0; i < length; i++ {
21
+ key[i] = keyChars[rand.Intn(len(keyChars))]
22
+ }
23
+ return string(key)
24
+ }
25
+
26
+ func MapToJsonStr(m map[string]interface{}) string {
27
+ bytes, err := json.Marshal(m)
28
+ if err != nil {
29
+ return ""
30
+ }
31
+ return string(bytes)
32
+ }
33
+
34
+ func StrToMap(str string) map[string]interface{} {
35
+ m := make(map[string]interface{})
36
+ err := json.Unmarshal([]byte(str), &m)
37
+ if err != nil {
38
+ return nil
39
+ }
40
+ return m
41
+ }
42
+
43
+ func IsJsonStr(str string) bool {
44
+ var js map[string]interface{}
45
+ return json.Unmarshal([]byte(str), &js) == nil
46
+ }
47
+
48
+ func String2Int(str string) int {
49
+ num, err := strconv.Atoi(str)
50
+ if err != nil {
51
+ return 0
52
+ }
53
+ return num
54
+ }
55
+
56
+ func StringsContains(strs []string, str string) bool {
57
+ for _, s := range strs {
58
+ if s == str {
59
+ return true
60
+ }
61
+ }
62
+ return false
63
+ }
64
+
65
+ // StringToByteSlice []byte only read, panic on append
66
+ func StringToByteSlice(s string) []byte {
67
+ tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
68
+ tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
69
+ return *(*[]byte)(unsafe.Pointer(&tmp2))
70
+ }
common/topup-ratio.go ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "encoding/json"
5
+ )
6
+
7
+ var TopupGroupRatio = map[string]float64{
8
+ "default": 1,
9
+ "vip": 1,
10
+ "svip": 1,
11
+ }
12
+
13
+ func TopupGroupRatio2JSONString() string {
14
+ jsonBytes, err := json.Marshal(TopupGroupRatio)
15
+ if err != nil {
16
+ SysError("error marshalling model ratio: " + err.Error())
17
+ }
18
+ return string(jsonBytes)
19
+ }
20
+
21
+ func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
22
+ TopupGroupRatio = make(map[string]float64)
23
+ return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
24
+ }
25
+
26
+ func GetTopupGroupRatio(name string) float64 {
27
+ ratio, ok := TopupGroupRatio[name]
28
+ if !ok {
29
+ SysError("topup group ratio not found: " + name)
30
+ return 1
31
+ }
32
+ return ratio
33
+ }
common/utils.go ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "bytes"
5
+ "context"
6
+ crand "crypto/rand"
7
+ "encoding/base64"
8
+ "encoding/json"
9
+ "fmt"
10
+ "github.com/pkg/errors"
11
+ "html/template"
12
+ "io"
13
+ "log"
14
+ "math/big"
15
+ "math/rand"
16
+ "net"
17
+ "os"
18
+ "os/exec"
19
+ "runtime"
20
+ "strconv"
21
+ "strings"
22
+ "time"
23
+
24
+ "github.com/google/uuid"
25
+ )
26
+
27
+ func OpenBrowser(url string) {
28
+ var err error
29
+
30
+ switch runtime.GOOS {
31
+ case "linux":
32
+ err = exec.Command("xdg-open", url).Start()
33
+ case "windows":
34
+ err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
35
+ case "darwin":
36
+ err = exec.Command("open", url).Start()
37
+ }
38
+ if err != nil {
39
+ log.Println(err)
40
+ }
41
+ }
42
+
43
+ func GetIp() (ip string) {
44
+ ips, err := net.InterfaceAddrs()
45
+ if err != nil {
46
+ log.Println(err)
47
+ return ip
48
+ }
49
+
50
+ for _, a := range ips {
51
+ if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
52
+ if ipNet.IP.To4() != nil {
53
+ ip = ipNet.IP.String()
54
+ if strings.HasPrefix(ip, "10") {
55
+ return
56
+ }
57
+ if strings.HasPrefix(ip, "172") {
58
+ return
59
+ }
60
+ if strings.HasPrefix(ip, "192.168") {
61
+ return
62
+ }
63
+ ip = ""
64
+ }
65
+ }
66
+ }
67
+ return
68
+ }
69
+
70
+ var sizeKB = 1024
71
+ var sizeMB = sizeKB * 1024
72
+ var sizeGB = sizeMB * 1024
73
+
74
+ func Bytes2Size(num int64) string {
75
+ numStr := ""
76
+ unit := "B"
77
+ if num/int64(sizeGB) > 1 {
78
+ numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
79
+ unit = "GB"
80
+ } else if num/int64(sizeMB) > 1 {
81
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
82
+ unit = "MB"
83
+ } else if num/int64(sizeKB) > 1 {
84
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
85
+ unit = "KB"
86
+ } else {
87
+ numStr = fmt.Sprintf("%d", num)
88
+ }
89
+ return numStr + " " + unit
90
+ }
91
+
92
+ func Seconds2Time(num int) (time string) {
93
+ if num/31104000 > 0 {
94
+ time += strconv.Itoa(num/31104000) + " 年 "
95
+ num %= 31104000
96
+ }
97
+ if num/2592000 > 0 {
98
+ time += strconv.Itoa(num/2592000) + " 个月 "
99
+ num %= 2592000
100
+ }
101
+ if num/86400 > 0 {
102
+ time += strconv.Itoa(num/86400) + " 天 "
103
+ num %= 86400
104
+ }
105
+ if num/3600 > 0 {
106
+ time += strconv.Itoa(num/3600) + " 小时 "
107
+ num %= 3600
108
+ }
109
+ if num/60 > 0 {
110
+ time += strconv.Itoa(num/60) + " 分钟 "
111
+ num %= 60
112
+ }
113
+ time += strconv.Itoa(num) + " 秒"
114
+ return
115
+ }
116
+
117
+ func Interface2String(inter interface{}) string {
118
+ switch inter.(type) {
119
+ case string:
120
+ return inter.(string)
121
+ case int:
122
+ return fmt.Sprintf("%d", inter.(int))
123
+ case float64:
124
+ return fmt.Sprintf("%f", inter.(float64))
125
+ }
126
+ return "Not Implemented"
127
+ }
128
+
129
+ func UnescapeHTML(x string) interface{} {
130
+ return template.HTML(x)
131
+ }
132
+
133
+ func IntMax(a int, b int) int {
134
+ if a >= b {
135
+ return a
136
+ } else {
137
+ return b
138
+ }
139
+ }
140
+
141
+ func IsIP(s string) bool {
142
+ ip := net.ParseIP(s)
143
+ return ip != nil
144
+ }
145
+
146
+ func GetUUID() string {
147
+ code := uuid.New().String()
148
+ code = strings.Replace(code, "-", "", -1)
149
+ return code
150
+ }
151
+
152
+ const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
153
+
154
+ func init() {
155
+ rand.New(rand.NewSource(time.Now().UnixNano()))
156
+ }
157
+
158
+ func GenerateRandomCharsKey(length int) (string, error) {
159
+ b := make([]byte, length)
160
+ maxI := big.NewInt(int64(len(keyChars)))
161
+
162
+ for i := range b {
163
+ n, err := crand.Int(crand.Reader, maxI)
164
+ if err != nil {
165
+ return "", err
166
+ }
167
+ b[i] = keyChars[n.Int64()]
168
+ }
169
+
170
+ return string(b), nil
171
+ }
172
+
173
+ func GenerateRandomKey(length int) (string, error) {
174
+ bytes := make([]byte, length*3/4) // 对于48位的输出,这里应该是36
175
+ if _, err := crand.Read(bytes); err != nil {
176
+ return "", err
177
+ }
178
+ return base64.StdEncoding.EncodeToString(bytes), nil
179
+ }
180
+
181
+ func GenerateKey() (string, error) {
182
+ //rand.Seed(time.Now().UnixNano())
183
+ return GenerateRandomCharsKey(48)
184
+ }
185
+
186
+ func GetRandomInt(max int) int {
187
+ //rand.Seed(time.Now().UnixNano())
188
+ return rand.Intn(max)
189
+ }
190
+
191
+ func GetTimestamp() int64 {
192
+ return time.Now().Unix()
193
+ }
194
+
195
+ func GetTimeString() string {
196
+ now := time.Now()
197
+ return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
198
+ }
199
+
200
+ func Max(a int, b int) int {
201
+ if a >= b {
202
+ return a
203
+ } else {
204
+ return b
205
+ }
206
+ }
207
+
208
+ func MessageWithRequestId(message string, id string) string {
209
+ return fmt.Sprintf("%s (request id: %s)", message, id)
210
+ }
211
+
212
+ func RandomSleep() {
213
+ // Sleep for 0-3000 ms
214
+ time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
215
+ }
216
+
217
+ func GetPointer[T any](v T) *T {
218
+ return &v
219
+ }
220
+
221
+ func Any2Type[T any](data any) (T, error) {
222
+ var zero T
223
+ bytes, err := json.Marshal(data)
224
+ if err != nil {
225
+ return zero, err
226
+ }
227
+ var res T
228
+ err = json.Unmarshal(bytes, &res)
229
+ if err != nil {
230
+ return zero, err
231
+ }
232
+ return res, nil
233
+ }
234
+
235
+ // SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
236
+ func SaveTmpFile(filename string, data io.Reader) (string, error) {
237
+ f, err := os.CreateTemp(os.TempDir(), filename)
238
+ if err != nil {
239
+ return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
240
+ }
241
+ defer f.Close()
242
+
243
+ _, err = io.Copy(f, data)
244
+ if err != nil {
245
+ return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
246
+ }
247
+
248
+ return f.Name(), nil
249
+ }
250
+
251
+ // GetAudioDuration returns the duration of an audio file in seconds.
252
+ func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
253
+ // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
254
+ c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
255
+ output, err := c.Output()
256
+ if err != nil {
257
+ return 0, errors.Wrap(err, "failed to get audio duration")
258
+ }
259
+
260
+ return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
261
+ }
common/validate.go ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import "github.com/go-playground/validator/v10"
4
+
5
+ var Validate *validator.Validate
6
+
7
+ func init() {
8
+ Validate = validator.New()
9
+ }
common/verification.go ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package common
2
+
3
+ import (
4
+ "github.com/google/uuid"
5
+ "strings"
6
+ "sync"
7
+ "time"
8
+ )
9
+
10
+ type verificationValue struct {
11
+ code string
12
+ time time.Time
13
+ }
14
+
15
+ const (
16
+ EmailVerificationPurpose = "v"
17
+ PasswordResetPurpose = "r"
18
+ )
19
+
20
+ var verificationMutex sync.Mutex
21
+ var verificationMap map[string]verificationValue
22
+ var verificationMapMaxSize = 10
23
+ var VerificationValidMinutes = 10
24
+
25
+ func GenerateVerificationCode(length int) string {
26
+ code := uuid.New().String()
27
+ code = strings.Replace(code, "-", "", -1)
28
+ if length == 0 {
29
+ return code
30
+ }
31
+ return code[:length]
32
+ }
33
+
34
+ func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
35
+ verificationMutex.Lock()
36
+ defer verificationMutex.Unlock()
37
+ verificationMap[purpose+key] = verificationValue{
38
+ code: code,
39
+ time: time.Now(),
40
+ }
41
+ if len(verificationMap) > verificationMapMaxSize {
42
+ removeExpiredPairs()
43
+ }
44
+ }
45
+
46
+ func VerifyCodeWithKey(key string, code string, purpose string) bool {
47
+ verificationMutex.Lock()
48
+ defer verificationMutex.Unlock()
49
+ value, okay := verificationMap[purpose+key]
50
+ now := time.Now()
51
+ if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
52
+ return false
53
+ }
54
+ return code == value.code
55
+ }
56
+
57
+ func DeleteKey(key string, purpose string) {
58
+ verificationMutex.Lock()
59
+ defer verificationMutex.Unlock()
60
+ delete(verificationMap, purpose+key)
61
+ }
62
+
63
+ // no lock inside, so the caller must lock the verificationMap before calling!
64
+ func removeExpiredPairs() {
65
+ now := time.Now()
66
+ for key := range verificationMap {
67
+ if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
68
+ delete(verificationMap, key)
69
+ }
70
+ }
71
+ }
72
+
73
+ func init() {
74
+ verificationMutex.Lock()
75
+ defer verificationMutex.Unlock()
76
+ verificationMap = make(map[string]verificationValue)
77
+ }
constant/cache_key.go ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ import "one-api/common"
4
+
5
+ var (
6
+ TokenCacheSeconds = common.SyncFrequency
7
+ UserId2GroupCacheSeconds = common.SyncFrequency
8
+ UserId2QuotaCacheSeconds = common.SyncFrequency
9
+ UserId2StatusCacheSeconds = common.SyncFrequency
10
+ )
11
+
12
+ // Cache keys
13
+ const (
14
+ UserGroupKeyFmt = "user_group:%d"
15
+ UserQuotaKeyFmt = "user_quota:%d"
16
+ UserEnabledKeyFmt = "user_enabled:%d"
17
+ UserUsernameKeyFmt = "user_name:%d"
18
+ )
19
+
20
+ const (
21
+ TokenFiledRemainQuota = "RemainQuota"
22
+ TokenFieldGroup = "Group"
23
+ )
constant/channel_setting.go ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ var (
4
+ ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
5
+ ChanelSettingProxy = "proxy" // Proxy 代理
6
+ ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
7
+ )
constant/context_key.go ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ const (
4
+ ContextKeyRequestStartTime = "request_start_time"
5
+ ContextKeyUserSetting = "user_setting"
6
+ ContextKeyUserQuota = "user_quota"
7
+ ContextKeyUserStatus = "user_status"
8
+ ContextKeyUserEmail = "user_email"
9
+ ContextKeyUserGroup = "user_group"
10
+ )
constant/env.go ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ import (
4
+ "one-api/common"
5
+ )
6
+
7
+ var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
8
+ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
9
+
10
+ var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
11
+
12
+ // ForceStreamOption 覆盖请求参数,强制返回usage信息
13
+ var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
14
+
15
+ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
16
+
17
+ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
18
+
19
+ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
20
+
21
+ var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
22
+
23
+ //var GeminiModelMap = map[string]string{
24
+ // "gemini-1.0-pro": "v1",
25
+ //}
26
+
27
+ var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
28
+
29
+ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
30
+ var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
31
+
32
+ func InitEnv() {
33
+ //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
34
+ //if modelVersionMapStr == "" {
35
+ // return
36
+ //}
37
+ //for _, pair := range strings.Split(modelVersionMapStr, ",") {
38
+ // parts := strings.Split(pair, ":")
39
+ // if len(parts) == 2 {
40
+ // GeminiModelMap[parts[0]] = parts[1]
41
+ // } else {
42
+ // common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
43
+ // }
44
+ //}
45
+ }
46
+
47
+ // GenerateDefaultToken 是否生成初始令牌,默认关闭。
48
+ var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
constant/finish_reason.go ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ var (
4
+ FinishReasonStop = "stop"
5
+ FinishReasonToolCalls = "tool_calls"
6
+ FinishReasonLength = "length"
7
+ FinishReasonFunctionCall = "function_call"
8
+ FinishReasonContentFilter = "content_filter"
9
+ )
constant/midjourney.go ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ const (
4
+ MjErrorUnknown = 5
5
+ MjRequestError = 4
6
+ )
7
+
8
+ const (
9
+ MjActionImagine = "IMAGINE"
10
+ MjActionDescribe = "DESCRIBE"
11
+ MjActionBlend = "BLEND"
12
+ MjActionUpscale = "UPSCALE"
13
+ MjActionVariation = "VARIATION"
14
+ MjActionReRoll = "REROLL"
15
+ MjActionInPaint = "INPAINT"
16
+ MjActionModal = "MODAL"
17
+ MjActionZoom = "ZOOM"
18
+ MjActionCustomZoom = "CUSTOM_ZOOM"
19
+ MjActionShorten = "SHORTEN"
20
+ MjActionHighVariation = "HIGH_VARIATION"
21
+ MjActionLowVariation = "LOW_VARIATION"
22
+ MjActionPan = "PAN"
23
+ MjActionSwapFace = "SWAP_FACE"
24
+ MjActionUpload = "UPLOAD"
25
+ )
26
+
27
+ var MidjourneyModel2Action = map[string]string{
28
+ "mj_imagine": MjActionImagine,
29
+ "mj_describe": MjActionDescribe,
30
+ "mj_blend": MjActionBlend,
31
+ "mj_upscale": MjActionUpscale,
32
+ "mj_variation": MjActionVariation,
33
+ "mj_reroll": MjActionReRoll,
34
+ "mj_modal": MjActionModal,
35
+ "mj_inpaint": MjActionInPaint,
36
+ "mj_zoom": MjActionZoom,
37
+ "mj_custom_zoom": MjActionCustomZoom,
38
+ "mj_shorten": MjActionShorten,
39
+ "mj_high_variation": MjActionHighVariation,
40
+ "mj_low_variation": MjActionLowVariation,
41
+ "mj_pan": MjActionPan,
42
+ "swap_face": MjActionSwapFace,
43
+ "mj_upload": MjActionUpload,
44
+ }
constant/setup.go ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ package constant
2
+
3
+ var Setup = false
constant/task.go ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ type TaskPlatform string
4
+
5
+ const (
6
+ TaskPlatformSuno TaskPlatform = "suno"
7
+ TaskPlatformMidjourney = "mj"
8
+ )
9
+
10
+ const (
11
+ SunoActionMusic = "MUSIC"
12
+ SunoActionLyrics = "LYRICS"
13
+ )
14
+
15
+ var SunoModel2Action = map[string]string{
16
+ "suno_music": SunoActionMusic,
17
+ "suno_lyrics": SunoActionLyrics,
18
+ }
constant/user_setting.go ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package constant
2
+
3
+ var (
4
+ UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
5
+ UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
6
+ UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
7
+ UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
8
+ UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
9
+ UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
10
+ )
11
+
12
+ var (
13
+ NotifyTypeEmail = "email" // Email 邮件
14
+ NotifyTypeWebhook = "webhook" // Webhook
15
+ )
controller/billing.go ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package controller
2
+
3
+ import (
4
+ "github.com/gin-gonic/gin"
5
+ "one-api/common"
6
+ "one-api/dto"
7
+ "one-api/model"
8
+ )
9
+
10
+ func GetSubscription(c *gin.Context) {
11
+ var remainQuota int
12
+ var usedQuota int
13
+ var err error
14
+ var token *model.Token
15
+ var expiredTime int64
16
+ if common.DisplayTokenStatEnabled {
17
+ tokenId := c.GetInt("token_id")
18
+ token, err = model.GetTokenById(tokenId)
19
+ expiredTime = token.ExpiredTime
20
+ remainQuota = token.RemainQuota
21
+ usedQuota = token.UsedQuota
22
+ } else {
23
+ userId := c.GetInt("id")
24
+ remainQuota, err = model.GetUserQuota(userId, false)
25
+ usedQuota, err = model.GetUserUsedQuota(userId)
26
+ }
27
+ if expiredTime <= 0 {
28
+ expiredTime = 0
29
+ }
30
+ if err != nil {
31
+ openAIError := dto.OpenAIError{
32
+ Message: err.Error(),
33
+ Type: "upstream_error",
34
+ }
35
+ c.JSON(200, gin.H{
36
+ "error": openAIError,
37
+ })
38
+ return
39
+ }
40
+ quota := remainQuota + usedQuota
41
+ amount := float64(quota)
42
+ if common.DisplayInCurrencyEnabled {
43
+ amount /= common.QuotaPerUnit
44
+ }
45
+ if token != nil && token.UnlimitedQuota {
46
+ amount = 100000000
47
+ }
48
+ subscription := OpenAISubscriptionResponse{
49
+ Object: "billing_subscription",
50
+ HasPaymentMethod: true,
51
+ SoftLimitUSD: amount,
52
+ HardLimitUSD: amount,
53
+ SystemHardLimitUSD: amount,
54
+ AccessUntil: expiredTime,
55
+ }
56
+ c.JSON(200, subscription)
57
+ return
58
+ }
59
+
60
+ func GetUsage(c *gin.Context) {
61
+ var quota int
62
+ var err error
63
+ var token *model.Token
64
+ if common.DisplayTokenStatEnabled {
65
+ tokenId := c.GetInt("token_id")
66
+ token, err = model.GetTokenById(tokenId)
67
+ quota = token.UsedQuota
68
+ } else {
69
+ userId := c.GetInt("id")
70
+ quota, err = model.GetUserUsedQuota(userId)
71
+ }
72
+ if err != nil {
73
+ openAIError := dto.OpenAIError{
74
+ Message: err.Error(),
75
+ Type: "new_api_error",
76
+ }
77
+ c.JSON(200, gin.H{
78
+ "error": openAIError,
79
+ })
80
+ return
81
+ }
82
+ amount := float64(quota)
83
+ if common.DisplayInCurrencyEnabled {
84
+ amount /= common.QuotaPerUnit
85
+ }
86
+ usage := OpenAIUsageResponse{
87
+ Object: "list",
88
+ TotalUsage: amount * 100,
89
+ }
90
+ c.JSON(200, usage)
91
+ return
92
+ }
controller/channel-billing.go ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package controller
2
+
3
+ import (
4
+ "encoding/json"
5
+ "errors"
6
+ "fmt"
7
+ "io"
8
+ "net/http"
9
+ "one-api/common"
10
+ "one-api/model"
11
+ "one-api/service"
12
+ "strconv"
13
+ "time"
14
+
15
+ "github.com/gin-gonic/gin"
16
+ )
17
+
18
+ // https://github.com/songquanpeng/one-api/issues/79
19
+
20
+ type OpenAISubscriptionResponse struct {
21
+ Object string `json:"object"`
22
+ HasPaymentMethod bool `json:"has_payment_method"`
23
+ SoftLimitUSD float64 `json:"soft_limit_usd"`
24
+ HardLimitUSD float64 `json:"hard_limit_usd"`
25
+ SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
26
+ AccessUntil int64 `json:"access_until"`
27
+ }
28
+
29
+ type OpenAIUsageDailyCost struct {
30
+ Timestamp float64 `json:"timestamp"`
31
+ LineItems []struct {
32
+ Name string `json:"name"`
33
+ Cost float64 `json:"cost"`
34
+ }
35
+ }
36
+
37
+ type OpenAICreditGrants struct {
38
+ Object string `json:"object"`
39
+ TotalGranted float64 `json:"total_granted"`
40
+ TotalUsed float64 `json:"total_used"`
41
+ TotalAvailable float64 `json:"total_available"`
42
+ }
43
+
44
+ type OpenAIUsageResponse struct {
45
+ Object string `json:"object"`
46
+ //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
47
+ TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
48
+ }
49
+
50
+ type OpenAISBUsageResponse struct {
51
+ Msg string `json:"msg"`
52
+ Data *struct {
53
+ Credit string `json:"credit"`
54
+ } `json:"data"`
55
+ }
56
+
57
+ type AIProxyUserOverviewResponse struct {
58
+ Success bool `json:"success"`
59
+ Message string `json:"message"`
60
+ ErrorCode int `json:"error_code"`
61
+ Data struct {
62
+ TotalPoints float64 `json:"totalPoints"`
63
+ } `json:"data"`
64
+ }
65
+
66
+ type API2GPTUsageResponse struct {
67
+ Object string `json:"object"`
68
+ TotalGranted float64 `json:"total_granted"`
69
+ TotalUsed float64 `json:"total_used"`
70
+ TotalRemaining float64 `json:"total_remaining"`
71
+ }
72
+
73
+ type APGC2DGPTUsageResponse struct {
74
+ //Grants interface{} `json:"grants"`
75
+ Object string `json:"object"`
76
+ TotalAvailable float64 `json:"total_available"`
77
+ TotalGranted float64 `json:"total_granted"`
78
+ TotalUsed float64 `json:"total_used"`
79
+ }
80
+
81
+ type SiliconFlowUsageResponse struct {
82
+ Code int `json:"code"`
83
+ Message string `json:"message"`
84
+ Status bool `json:"status"`
85
+ Data struct {
86
+ ID string `json:"id"`
87
+ Name string `json:"name"`
88
+ Image string `json:"image"`
89
+ Email string `json:"email"`
90
+ IsAdmin bool `json:"isAdmin"`
91
+ Balance string `json:"balance"`
92
+ Status string `json:"status"`
93
+ Introduction string `json:"introduction"`
94
+ Role string `json:"role"`
95
+ ChargeBalance string `json:"chargeBalance"`
96
+ TotalBalance string `json:"totalBalance"`
97
+ Category string `json:"category"`
98
+ } `json:"data"`
99
+ }
100
+
101
+ type DeepSeekUsageResponse struct {
102
+ IsAvailable bool `json:"is_available"`
103
+ BalanceInfos []struct {
104
+ Currency string `json:"currency"`
105
+ TotalBalance string `json:"total_balance"`
106
+ GrantedBalance string `json:"granted_balance"`
107
+ ToppedUpBalance string `json:"topped_up_balance"`
108
+ } `json:"balance_infos"`
109
+ }
110
+
111
+ // GetAuthHeader get auth header
112
+ func GetAuthHeader(token string) http.Header {
113
+ h := http.Header{}
114
+ h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
115
+ return h
116
+ }
117
+
118
+ func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
119
+ req, err := http.NewRequest(method, url, nil)
120
+ if err != nil {
121
+ return nil, err
122
+ }
123
+ for k := range headers {
124
+ req.Header.Add(k, headers.Get(k))
125
+ }
126
+ res, err := service.GetHttpClient().Do(req)
127
+ if err != nil {
128
+ return nil, err
129
+ }
130
+ if res.StatusCode != http.StatusOK {
131
+ return nil, fmt.Errorf("status code: %d", res.StatusCode)
132
+ }
133
+ body, err := io.ReadAll(res.Body)
134
+ if err != nil {
135
+ return nil, err
136
+ }
137
+ err = res.Body.Close()
138
+ if err != nil {
139
+ return nil, err
140
+ }
141
+ return body, nil
142
+ }
143
+
144
+ func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
145
+ url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
146
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
147
+
148
+ if err != nil {
149
+ return 0, err
150
+ }
151
+ response := OpenAICreditGrants{}
152
+ err = json.Unmarshal(body, &response)
153
+ if err != nil {
154
+ return 0, err
155
+ }
156
+ channel.UpdateBalance(response.TotalAvailable)
157
+ return response.TotalAvailable, nil
158
+ }
159
+
160
+ func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
161
+ url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
162
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
163
+ if err != nil {
164
+ return 0, err
165
+ }
166
+ response := OpenAISBUsageResponse{}
167
+ err = json.Unmarshal(body, &response)
168
+ if err != nil {
169
+ return 0, err
170
+ }
171
+ if response.Data == nil {
172
+ return 0, errors.New(response.Msg)
173
+ }
174
+ balance, err := strconv.ParseFloat(response.Data.Credit, 64)
175
+ if err != nil {
176
+ return 0, err
177
+ }
178
+ channel.UpdateBalance(balance)
179
+ return balance, nil
180
+ }
181
+
182
+ func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
183
+ url := "https://aiproxy.io/api/report/getUserOverview"
184
+ headers := http.Header{}
185
+ headers.Add("Api-Key", channel.Key)
186
+ body, err := GetResponseBody("GET", url, channel, headers)
187
+ if err != nil {
188
+ return 0, err
189
+ }
190
+ response := AIProxyUserOverviewResponse{}
191
+ err = json.Unmarshal(body, &response)
192
+ if err != nil {
193
+ return 0, err
194
+ }
195
+ if !response.Success {
196
+ return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
197
+ }
198
+ channel.UpdateBalance(response.Data.TotalPoints)
199
+ return response.Data.TotalPoints, nil
200
+ }
201
+
202
+ func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
203
+ url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
204
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
205
+
206
+ if err != nil {
207
+ return 0, err
208
+ }
209
+ response := API2GPTUsageResponse{}
210
+ err = json.Unmarshal(body, &response)
211
+ if err != nil {
212
+ return 0, err
213
+ }
214
+ channel.UpdateBalance(response.TotalRemaining)
215
+ return response.TotalRemaining, nil
216
+ }
217
+
218
+ func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
219
+ url := "https://api.siliconflow.cn/v1/user/info"
220
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
221
+ if err != nil {
222
+ return 0, err
223
+ }
224
+ response := SiliconFlowUsageResponse{}
225
+ err = json.Unmarshal(body, &response)
226
+ if err != nil {
227
+ return 0, err
228
+ }
229
+ if response.Code != 20000 {
230
+ return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
231
+ }
232
+ balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
233
+ if err != nil {
234
+ return 0, err
235
+ }
236
+ channel.UpdateBalance(balance)
237
+ return balance, nil
238
+ }
239
+
240
+ func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
241
+ url := "https://api.deepseek.com/user/balance"
242
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
243
+ if err != nil {
244
+ return 0, err
245
+ }
246
+ response := DeepSeekUsageResponse{}
247
+ err = json.Unmarshal(body, &response)
248
+ if err != nil {
249
+ return 0, err
250
+ }
251
+ index := -1
252
+ for i, balanceInfo := range response.BalanceInfos {
253
+ if balanceInfo.Currency == "CNY" {
254
+ index = i
255
+ break
256
+ }
257
+ }
258
+ if index == -1 {
259
+ return 0, errors.New("currency CNY not found")
260
+ }
261
+ balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
262
+ if err != nil {
263
+ return 0, err
264
+ }
265
+ channel.UpdateBalance(balance)
266
+ return balance, nil
267
+ }
268
+
269
+ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
270
+ url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
271
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
272
+ if err != nil {
273
+ return 0, err
274
+ }
275
+ response := APGC2DGPTUsageResponse{}
276
+ err = json.Unmarshal(body, &response)
277
+ if err != nil {
278
+ return 0, err
279
+ }
280
+ channel.UpdateBalance(response.TotalAvailable)
281
+ return response.TotalAvailable, nil
282
+ }
283
+
284
+ func updateChannelBalance(channel *model.Channel) (float64, error) {
285
+ baseURL := common.ChannelBaseURLs[channel.Type]
286
+ if channel.GetBaseURL() == "" {
287
+ channel.BaseURL = &baseURL
288
+ }
289
+ switch channel.Type {
290
+ case common.ChannelTypeOpenAI:
291
+ if channel.GetBaseURL() != "" {
292
+ baseURL = channel.GetBaseURL()
293
+ }
294
+ case common.ChannelTypeAzure:
295
+ return 0, errors.New("尚未实现")
296
+ case common.ChannelTypeCustom:
297
+ baseURL = channel.GetBaseURL()
298
+ //case common.ChannelTypeOpenAISB:
299
+ // return updateChannelOpenAISBBalance(channel)
300
+ case common.ChannelTypeAIProxy:
301
+ return updateChannelAIProxyBalance(channel)
302
+ case common.ChannelTypeAPI2GPT:
303
+ return updateChannelAPI2GPTBalance(channel)
304
+ case common.ChannelTypeAIGC2D:
305
+ return updateChannelAIGC2DBalance(channel)
306
+ case common.ChannelTypeSiliconFlow:
307
+ return updateChannelSiliconFlowBalance(channel)
308
+ case common.ChannelTypeDeepSeek:
309
+ return updateChannelDeepSeekBalance(channel)
310
+ default:
311
+ return 0, errors.New("尚未实现")
312
+ }
313
+ url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
314
+
315
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
316
+ if err != nil {
317
+ return 0, err
318
+ }
319
+ subscription := OpenAISubscriptionResponse{}
320
+ err = json.Unmarshal(body, &subscription)
321
+ if err != nil {
322
+ return 0, err
323
+ }
324
+ now := time.Now()
325
+ startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
326
+ endDate := now.Format("2006-01-02")
327
+ if !subscription.HasPaymentMethod {
328
+ startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
329
+ }
330
+ url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
331
+ body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
332
+ if err != nil {
333
+ return 0, err
334
+ }
335
+ usage := OpenAIUsageResponse{}
336
+ err = json.Unmarshal(body, &usage)
337
+ if err != nil {
338
+ return 0, err
339
+ }
340
+ balance := subscription.HardLimitUSD - usage.TotalUsage/100
341
+ channel.UpdateBalance(balance)
342
+ return balance, nil
343
+ }
344
+
345
+ func UpdateChannelBalance(c *gin.Context) {
346
+ id, err := strconv.Atoi(c.Param("id"))
347
+ if err != nil {
348
+ c.JSON(http.StatusOK, gin.H{
349
+ "success": false,
350
+ "message": err.Error(),
351
+ })
352
+ return
353
+ }
354
+ channel, err := model.GetChannelById(id, true)
355
+ if err != nil {
356
+ c.JSON(http.StatusOK, gin.H{
357
+ "success": false,
358
+ "message": err.Error(),
359
+ })
360
+ return
361
+ }
362
+ balance, err := updateChannelBalance(channel)
363
+ if err != nil {
364
+ c.JSON(http.StatusOK, gin.H{
365
+ "success": false,
366
+ "message": err.Error(),
367
+ })
368
+ return
369
+ }
370
+ c.JSON(http.StatusOK, gin.H{
371
+ "success": true,
372
+ "message": "",
373
+ "balance": balance,
374
+ })
375
+ return
376
+ }
377
+
378
+ func updateAllChannelsBalance() error {
379
+ channels, err := model.GetAllChannels(0, 0, true, false)
380
+ if err != nil {
381
+ return err
382
+ }
383
+ for _, channel := range channels {
384
+ if channel.Status != common.ChannelStatusEnabled {
385
+ continue
386
+ }
387
+ // TODO: support Azure
388
+ //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
389
+ // continue
390
+ //}
391
+ balance, err := updateChannelBalance(channel)
392
+ if err != nil {
393
+ continue
394
+ } else {
395
+ // err is nil & balance <= 0 means quota is used up
396
+ if balance <= 0 {
397
+ service.DisableChannel(channel.Id, channel.Name, "余额不足")
398
+ }
399
+ }
400
+ time.Sleep(common.RequestInterval)
401
+ }
402
+ return nil
403
+ }
404
+
405
+ func UpdateAllChannelsBalance(c *gin.Context) {
406
+ // TODO: make it async
407
+ err := updateAllChannelsBalance()
408
+ if err != nil {
409
+ c.JSON(http.StatusOK, gin.H{
410
+ "success": false,
411
+ "message": err.Error(),
412
+ })
413
+ return
414
+ }
415
+ c.JSON(http.StatusOK, gin.H{
416
+ "success": true,
417
+ "message": "",
418
+ })
419
+ return
420
+ }
421
+
422
+ func AutomaticallyUpdateChannels(frequency int) {
423
+ for {
424
+ time.Sleep(time.Duration(frequency) * time.Minute)
425
+ common.SysLog("updating all channels")
426
+ _ = updateAllChannelsBalance()
427
+ common.SysLog("channels update done")
428
+ }
429
+ }
controller/channel-test.go ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package controller
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ "errors"
7
+ "fmt"
8
+ "io"
9
+ "math"
10
+ "net/http"
11
+ "net/http/httptest"
12
+ "net/url"
13
+ "one-api/common"
14
+ "one-api/dto"
15
+ "one-api/middleware"
16
+ "one-api/model"
17
+ "one-api/relay"
18
+ relaycommon "one-api/relay/common"
19
+ "one-api/relay/constant"
20
+ "one-api/relay/helper"
21
+ "one-api/service"
22
+ "strconv"
23
+ "strings"
24
+ "sync"
25
+ "time"
26
+
27
+ "github.com/bytedance/gopkg/util/gopool"
28
+
29
+ "github.com/gin-gonic/gin"
30
+ )
31
+
32
+ func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
33
+ tik := time.Now()
34
+ if channel.Type == common.ChannelTypeMidjourney {
35
+ return errors.New("midjourney channel test is not supported"), nil
36
+ }
37
+ if channel.Type == common.ChannelTypeMidjourneyPlus {
38
+ return errors.New("midjourney plus channel test is not supported!!!"), nil
39
+ }
40
+ if channel.Type == common.ChannelTypeSunoAPI {
41
+ return errors.New("suno channel test is not supported"), nil
42
+ }
43
+ w := httptest.NewRecorder()
44
+ c, _ := gin.CreateTestContext(w)
45
+
46
+ requestPath := "/v1/chat/completions"
47
+
48
+ // 先判断是否为 Embedding 模型
49
+ if strings.Contains(strings.ToLower(testModel), "embedding") ||
50
+ strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
51
+ strings.Contains(testModel, "bge-") || // bge 系列模型
52
+ strings.Contains(testModel, "embed") ||
53
+ channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
54
+ requestPath = "/v1/embeddings" // 修改请求路径
55
+ }
56
+
57
+ c.Request = &http.Request{
58
+ Method: "POST",
59
+ URL: &url.URL{Path: requestPath}, // 使用动态路径
60
+ Body: nil,
61
+ Header: make(http.Header),
62
+ }
63
+
64
+ if testModel == "" {
65
+ if channel.TestModel != nil && *channel.TestModel != "" {
66
+ testModel = *channel.TestModel
67
+ } else {
68
+ if len(channel.GetModels()) > 0 {
69
+ testModel = channel.GetModels()[0]
70
+ } else {
71
+ testModel = "gpt-4o-mini"
72
+ }
73
+ }
74
+ }
75
+
76
+ cache, err := model.GetUserCache(1)
77
+ if err != nil {
78
+ return err, nil
79
+ }
80
+ cache.WriteContext(c)
81
+
82
+ c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
83
+ c.Request.Header.Set("Content-Type", "application/json")
84
+ c.Set("channel", channel.Type)
85
+ c.Set("base_url", channel.GetBaseURL())
86
+ group, _ := model.GetUserGroup(1, false)
87
+ c.Set("group", group)
88
+
89
+ middleware.SetupContextForSelectedChannel(c, channel, testModel)
90
+
91
+ info := relaycommon.GenRelayInfo(c)
92
+
93
+ err = helper.ModelMappedHelper(c, info)
94
+ if err != nil {
95
+ return err, nil
96
+ }
97
+ testModel = info.UpstreamModelName
98
+
99
+ apiType, _ := constant.ChannelType2APIType(channel.Type)
100
+ adaptor := relay.GetAdaptor(apiType)
101
+ if adaptor == nil {
102
+ return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
103
+ }
104
+
105
+ request := buildTestRequest(testModel)
106
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
107
+
108
+ priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
109
+ if err != nil {
110
+ return err, nil
111
+ }
112
+
113
+ adaptor.Init(info)
114
+
115
+ convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
116
+ if err != nil {
117
+ return err, nil
118
+ }
119
+ jsonData, err := json.Marshal(convertedRequest)
120
+ if err != nil {
121
+ return err, nil
122
+ }
123
+ requestBody := bytes.NewBuffer(jsonData)
124
+ c.Request.Body = io.NopCloser(requestBody)
125
+ resp, err := adaptor.DoRequest(c, info, requestBody)
126
+ if err != nil {
127
+ return err, nil
128
+ }
129
+ var httpResp *http.Response
130
+ if resp != nil {
131
+ httpResp = resp.(*http.Response)
132
+ if httpResp.StatusCode != http.StatusOK {
133
+ err := service.RelayErrorHandler(httpResp, true)
134
+ return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
135
+ }
136
+ }
137
+ usageA, respErr := adaptor.DoResponse(c, httpResp, info)
138
+ if respErr != nil {
139
+ return fmt.Errorf("%s", respErr.Error.Message), respErr
140
+ }
141
+ if usageA == nil {
142
+ return errors.New("usage is nil"), nil
143
+ }
144
+ usage := usageA.(*dto.Usage)
145
+ result := w.Result()
146
+ respBody, err := io.ReadAll(result.Body)
147
+ if err != nil {
148
+ return err, nil
149
+ }
150
+ info.PromptTokens = usage.PromptTokens
151
+
152
+ quota := 0
153
+ if !priceData.UsePrice {
154
+ quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
155
+ quota = int(math.Round(float64(quota) * priceData.ModelRatio))
156
+ if priceData.ModelRatio != 0 && quota <= 0 {
157
+ quota = 1
158
+ }
159
+ } else {
160
+ quota = int(priceData.ModelPrice * common.QuotaPerUnit)
161
+ }
162
+ tok := time.Now()
163
+ milliseconds := tok.Sub(tik).Milliseconds()
164
+ consumedTime := float64(milliseconds) / 1000.0
165
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
166
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
167
+ model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
168
+ quota, "模型测��", 0, quota, int(consumedTime), false, info.Group, other)
169
+ common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
170
+ return nil, nil
171
+ }
172
+
173
+ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
174
+ testRequest := &dto.GeneralOpenAIRequest{
175
+ Model: "", // this will be set later
176
+ Stream: false,
177
+ }
178
+
179
+ // 先判断是否为 Embedding 模型
180
+ if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
181
+ strings.HasPrefix(model, "m3e") || // m3e 系列模型
182
+ strings.Contains(model, "bge-") {
183
+ testRequest.Model = model
184
+ // Embedding 请求
185
+ testRequest.Input = []string{"hello world"}
186
+ return testRequest
187
+ }
188
+ // 并非Embedding 模型
189
+ if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
190
+ testRequest.MaxCompletionTokens = 10
191
+ } else if strings.Contains(model, "thinking") {
192
+ if !strings.Contains(model, "claude") {
193
+ testRequest.MaxTokens = 50
194
+ }
195
+ } else {
196
+ testRequest.MaxTokens = 10
197
+ }
198
+ content, _ := json.Marshal("hi")
199
+ testMessage := dto.Message{
200
+ Role: "user",
201
+ Content: content,
202
+ }
203
+ testRequest.Model = model
204
+ testRequest.Messages = append(testRequest.Messages, testMessage)
205
+ return testRequest
206
+ }
207
+
208
+ func TestChannel(c *gin.Context) {
209
+ channelId, err := strconv.Atoi(c.Param("id"))
210
+ if err != nil {
211
+ c.JSON(http.StatusOK, gin.H{
212
+ "success": false,
213
+ "message": err.Error(),
214
+ })
215
+ return
216
+ }
217
+ channel, err := model.GetChannelById(channelId, true)
218
+ if err != nil {
219
+ c.JSON(http.StatusOK, gin.H{
220
+ "success": false,
221
+ "message": err.Error(),
222
+ })
223
+ return
224
+ }
225
+ testModel := c.Query("model")
226
+ tik := time.Now()
227
+ err, _ = testChannel(channel, testModel)
228
+ tok := time.Now()
229
+ milliseconds := tok.Sub(tik).Milliseconds()
230
+ go channel.UpdateResponseTime(milliseconds)
231
+ consumedTime := float64(milliseconds) / 1000.0
232
+ if err != nil {
233
+ c.JSON(http.StatusOK, gin.H{
234
+ "success": false,
235
+ "message": err.Error(),
236
+ "time": consumedTime,
237
+ })
238
+ return
239
+ }
240
+ c.JSON(http.StatusOK, gin.H{
241
+ "success": true,
242
+ "message": "",
243
+ "time": consumedTime,
244
+ })
245
+ return
246
+ }
247
+
248
+ var testAllChannelsLock sync.Mutex
249
+ var testAllChannelsRunning bool = false
250
+
251
+ func testAllChannels(notify bool) error {
252
+
253
+ testAllChannelsLock.Lock()
254
+ if testAllChannelsRunning {
255
+ testAllChannelsLock.Unlock()
256
+ return errors.New("测试已在运行中")
257
+ }
258
+ testAllChannelsRunning = true
259
+ testAllChannelsLock.Unlock()
260
+ channels, err := model.GetAllChannels(0, 0, true, false)
261
+ if err != nil {
262
+ return err
263
+ }
264
+ var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
265
+ if disableThreshold == 0 {
266
+ disableThreshold = 10000000 // a impossible value
267
+ }
268
+ gopool.Go(func() {
269
+ for _, channel := range channels {
270
+ isChannelEnabled := channel.Status == common.ChannelStatusEnabled
271
+ tik := time.Now()
272
+ err, openaiWithStatusErr := testChannel(channel, "")
273
+ tok := time.Now()
274
+ milliseconds := tok.Sub(tik).Milliseconds()
275
+
276
+ shouldBanChannel := false
277
+
278
+ // request error disables the channel
279
+ if openaiWithStatusErr != nil {
280
+ oaiErr := openaiWithStatusErr.Error
281
+ err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
282
+ shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
283
+ }
284
+
285
+ if milliseconds > disableThreshold {
286
+ err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
287
+ shouldBanChannel = true
288
+ }
289
+
290
+ // disable channel
291
+ if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
292
+ service.DisableChannel(channel.Id, channel.Name, err.Error())
293
+ }
294
+
295
+ // enable channel
296
+ if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
297
+ service.EnableChannel(channel.Id, channel.Name)
298
+ }
299
+
300
+ channel.UpdateResponseTime(milliseconds)
301
+ time.Sleep(common.RequestInterval)
302
+ }
303
+ testAllChannelsLock.Lock()
304
+ testAllChannelsRunning = false
305
+ testAllChannelsLock.Unlock()
306
+ if notify {
307
+ service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
308
+ }
309
+ })
310
+ return nil
311
+ }
312
+
313
+ func TestAllChannels(c *gin.Context) {
314
+ err := testAllChannels(true)
315
+ if err != nil {
316
+ c.JSON(http.StatusOK, gin.H{
317
+ "success": false,
318
+ "message": err.Error(),
319
+ })
320
+ return
321
+ }
322
+ c.JSON(http.StatusOK, gin.H{
323
+ "success": true,
324
+ "message": "",
325
+ })
326
+ return
327
+ }
328
+
329
+ func AutomaticallyTestChannels(frequency int) {
330
+ for {
331
+ time.Sleep(time.Duration(frequency) * time.Minute)
332
+ common.SysLog("testing all channels")
333
+ _ = testAllChannels(false)
334
+ common.SysLog("channel test finished")
335
+ }
336
+ }
controller/channel.go ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package controller
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "net/http"
7
+ "one-api/common"
8
+ "one-api/model"
9
+ "strconv"
10
+ "strings"
11
+
12
+ "github.com/gin-gonic/gin"
13
+ )
14
+
15
+ type OpenAIModel struct {
16
+ ID string `json:"id"`
17
+ Object string `json:"object"`
18
+ Created int64 `json:"created"`
19
+ OwnedBy string `json:"owned_by"`
20
+ Permission []struct {
21
+ ID string `json:"id"`
22
+ Object string `json:"object"`
23
+ Created int64 `json:"created"`
24
+ AllowCreateEngine bool `json:"allow_create_engine"`
25
+ AllowSampling bool `json:"allow_sampling"`
26
+ AllowLogprobs bool `json:"allow_logprobs"`
27
+ AllowSearchIndices bool `json:"allow_search_indices"`
28
+ AllowView bool `json:"allow_view"`
29
+ AllowFineTuning bool `json:"allow_fine_tuning"`
30
+ Organization string `json:"organization"`
31
+ Group string `json:"group"`
32
+ IsBlocking bool `json:"is_blocking"`
33
+ } `json:"permission"`
34
+ Root string `json:"root"`
35
+ Parent string `json:"parent"`
36
+ }
37
+
38
+ type OpenAIModelsResponse struct {
39
+ Data []OpenAIModel `json:"data"`
40
+ Success bool `json:"success"`
41
+ }
42
+
43
+ func GetAllChannels(c *gin.Context) {
44
+ p, _ := strconv.Atoi(c.Query("p"))
45
+ pageSize, _ := strconv.Atoi(c.Query("page_size"))
46
+ if p < 0 {
47
+ p = 0
48
+ }
49
+ if pageSize < 0 {
50
+ pageSize = common.ItemsPerPage
51
+ }
52
+ channelData := make([]*model.Channel, 0)
53
+ idSort, _ := strconv.ParseBool(c.Query("id_sort"))
54
+ enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
55
+ if enableTagMode {
56
+ tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
57
+ if err != nil {
58
+ c.JSON(http.StatusOK, gin.H{
59
+ "success": false,
60
+ "message": err.Error(),
61
+ })
62
+ return
63
+ }
64
+ for _, tag := range tags {
65
+ if tag != nil && *tag != "" {
66
+ tagChannel, err := model.GetChannelsByTag(*tag, idSort)
67
+ if err == nil {
68
+ channelData = append(channelData, tagChannel...)
69
+ }
70
+ }
71
+ }
72
+ } else {
73
+ channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
74
+ if err != nil {
75
+ c.JSON(http.StatusOK, gin.H{
76
+ "success": false,
77
+ "message": err.Error(),
78
+ })
79
+ return
80
+ }
81
+ channelData = channels
82
+ }
83
+ c.JSON(http.StatusOK, gin.H{
84
+ "success": true,
85
+ "message": "",
86
+ "data": channelData,
87
+ })
88
+ return
89
+ }
90
+
91
+ func FetchUpstreamModels(c *gin.Context) {
92
+ id, err := strconv.Atoi(c.Param("id"))
93
+ if err != nil {
94
+ c.JSON(http.StatusOK, gin.H{
95
+ "success": false,
96
+ "message": err.Error(),
97
+ })
98
+ return
99
+ }
100
+
101
+ channel, err := model.GetChannelById(id, true)
102
+ if err != nil {
103
+ c.JSON(http.StatusOK, gin.H{
104
+ "success": false,
105
+ "message": err.Error(),
106
+ })
107
+ return
108
+ }
109
+
110
+ //if channel.Type != common.ChannelTypeOpenAI {
111
+ // c.JSON(http.StatusOK, gin.H{
112
+ // "success": false,
113
+ // "message": "仅支持 OpenAI 类型渠道",
114
+ // })
115
+ // return
116
+ //}
117
+ baseURL := common.ChannelBaseURLs[channel.Type]
118
+ if channel.GetBaseURL() != "" {
119
+ baseURL = channel.GetBaseURL()
120
+ }
121
+ url := fmt.Sprintf("%s/v1/models", baseURL)
122
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
123
+ if err != nil {
124
+ c.JSON(http.StatusOK, gin.H{
125
+ "success": false,
126
+ "message": err.Error(),
127
+ })
128
+ return
129
+ }
130
+
131
+ var result OpenAIModelsResponse
132
+ if err = json.Unmarshal(body, &result); err != nil {
133
+ c.JSON(http.StatusOK, gin.H{
134
+ "success": false,
135
+ "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
136
+ })
137
+ return
138
+ }
139
+
140
+ var ids []string
141
+ for _, model := range result.Data {
142
+ ids = append(ids, model.ID)
143
+ }
144
+
145
+ c.JSON(http.StatusOK, gin.H{
146
+ "success": true,
147
+ "message": "",
148
+ "data": ids,
149
+ })
150
+ }
151
+
152
+ func FixChannelsAbilities(c *gin.Context) {
153
+ count, err := model.FixAbility()
154
+ if err != nil {
155
+ c.JSON(http.StatusOK, gin.H{
156
+ "success": false,
157
+ "message": err.Error(),
158
+ })
159
+ return
160
+ }
161
+ c.JSON(http.StatusOK, gin.H{
162
+ "success": true,
163
+ "message": "",
164
+ "data": count,
165
+ })
166
+ }
167
+
168
+ func SearchChannels(c *gin.Context) {
169
+ keyword := c.Query("keyword")
170
+ group := c.Query("group")
171
+ modelKeyword := c.Query("model")
172
+ idSort, _ := strconv.ParseBool(c.Query("id_sort"))
173
+ enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
174
+ channelData := make([]*model.Channel, 0)
175
+ if enableTagMode {
176
+ tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
177
+ if err != nil {
178
+ c.JSON(http.StatusOK, gin.H{
179
+ "success": false,
180
+ "message": err.Error(),
181
+ })
182
+ return
183
+ }
184
+ for _, tag := range tags {
185
+ if tag != nil && *tag != "" {
186
+ tagChannel, err := model.GetChannelsByTag(*tag, idSort)
187
+ if err == nil {
188
+ channelData = append(channelData, tagChannel...)
189
+ }
190
+ }
191
+ }
192
+ } else {
193
+ channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
194
+ if err != nil {
195
+ c.JSON(http.StatusOK, gin.H{
196
+ "success": false,
197
+ "message": err.Error(),
198
+ })
199
+ return
200
+ }
201
+ channelData = channels
202
+ }
203
+ c.JSON(http.StatusOK, gin.H{
204
+ "success": true,
205
+ "message": "",
206
+ "data": channelData,
207
+ })
208
+ return
209
+ }
210
+
211
+ func GetChannel(c *gin.Context) {
212
+ id, err := strconv.Atoi(c.Param("id"))
213
+ if err != nil {
214
+ c.JSON(http.StatusOK, gin.H{
215
+ "success": false,
216
+ "message": err.Error(),
217
+ })
218
+ return
219
+ }
220
+ channel, err := model.GetChannelById(id, false)
221
+ if err != nil {
222
+ c.JSON(http.StatusOK, gin.H{
223
+ "success": false,
224
+ "message": err.Error(),
225
+ })
226
+ return
227
+ }
228
+ c.JSON(http.StatusOK, gin.H{
229
+ "success": true,
230
+ "message": "",
231
+ "data": channel,
232
+ })
233
+ return
234
+ }
235
+
236
+ func AddChannel(c *gin.Context) {
237
+ channel := model.Channel{}
238
+ err := c.ShouldBindJSON(&channel)
239
+ if err != nil {
240
+ c.JSON(http.StatusOK, gin.H{
241
+ "success": false,
242
+ "message": err.Error(),
243
+ })
244
+ return
245
+ }
246
+ channel.CreatedTime = common.GetTimestamp()
247
+ keys := strings.Split(channel.Key, "\n")
248
+ if channel.Type == common.ChannelTypeVertexAi {
249
+ if channel.Other == "" {
250
+ c.JSON(http.StatusOK, gin.H{
251
+ "success": false,
252
+ "message": "部署地区不能为空",
253
+ })
254
+ return
255
+ } else {
256
+ if common.IsJsonStr(channel.Other) {
257
+ // must have default
258
+ regionMap := common.StrToMap(channel.Other)
259
+ if regionMap["default"] == nil {
260
+ c.JSON(http.StatusOK, gin.H{
261
+ "success": false,
262
+ "message": "部署地区必须包含default字段",
263
+ })
264
+ return
265
+ }
266
+ }
267
+ }
268
+ keys = []string{channel.Key}
269
+ }
270
+ channels := make([]model.Channel, 0, len(keys))
271
+ for _, key := range keys {
272
+ if key == "" {
273
+ continue
274
+ }
275
+ localChannel := channel
276
+ localChannel.Key = key
277
+ // Validate the length of the model name
278
+ models := strings.Split(localChannel.Models, ",")
279
+ for _, model := range models {
280
+ if len(model) > 255 {
281
+ c.JSON(http.StatusOK, gin.H{
282
+ "success": false,
283
+ "message": fmt.Sprintf("模型名称过长: %s", model),
284
+ })
285
+ return
286
+ }
287
+ }
288
+ channels = append(channels, localChannel)
289
+ }
290
+ err = model.BatchInsertChannels(channels)
291
+ if err != nil {
292
+ c.JSON(http.StatusOK, gin.H{
293
+ "success": false,
294
+ "message": err.Error(),
295
+ })
296
+ return
297
+ }
298
+ c.JSON(http.StatusOK, gin.H{
299
+ "success": true,
300
+ "message": "",
301
+ })
302
+ return
303
+ }
304
+
305
+ func DeleteChannel(c *gin.Context) {
306
+ id, _ := strconv.Atoi(c.Param("id"))
307
+ channel := model.Channel{Id: id}
308
+ err := channel.Delete()
309
+ if err != nil {
310
+ c.JSON(http.StatusOK, gin.H{
311
+ "success": false,
312
+ "message": err.Error(),
313
+ })
314
+ return
315
+ }
316
+ c.JSON(http.StatusOK, gin.H{
317
+ "success": true,
318
+ "message": "",
319
+ })
320
+ return
321
+ }
322
+
323
+ func DeleteDisabledChannel(c *gin.Context) {
324
+ rows, err := model.DeleteDisabledChannel()
325
+ if err != nil {
326
+ c.JSON(http.StatusOK, gin.H{
327
+ "success": false,
328
+ "message": err.Error(),
329
+ })
330
+ return
331
+ }
332
+ c.JSON(http.StatusOK, gin.H{
333
+ "success": true,
334
+ "message": "",
335
+ "data": rows,
336
+ })
337
+ return
338
+ }
339
+
340
+ type ChannelTag struct {
341
+ Tag string `json:"tag"`
342
+ NewTag *string `json:"new_tag"`
343
+ Priority *int64 `json:"priority"`
344
+ Weight *uint `json:"weight"`
345
+ ModelMapping *string `json:"model_mapping"`
346
+ Models *string `json:"models"`
347
+ Groups *string `json:"groups"`
348
+ }
349
+
350
+ func DisableTagChannels(c *gin.Context) {
351
+ channelTag := ChannelTag{}
352
+ err := c.ShouldBindJSON(&channelTag)
353
+ if err != nil || channelTag.Tag == "" {
354
+ c.JSON(http.StatusOK, gin.H{
355
+ "success": false,
356
+ "message": "参数错误",
357
+ })
358
+ return
359
+ }
360
+ err = model.DisableChannelByTag(channelTag.Tag)
361
+ if err != nil {
362
+ c.JSON(http.StatusOK, gin.H{
363
+ "success": false,
364
+ "message": err.Error(),
365
+ })
366
+ return
367
+ }
368
+ c.JSON(http.StatusOK, gin.H{
369
+ "success": true,
370
+ "message": "",
371
+ })
372
+ return
373
+ }
374
+
375
+ func EnableTagChannels(c *gin.Context) {
376
+ channelTag := ChannelTag{}
377
+ err := c.ShouldBindJSON(&channelTag)
378
+ if err != nil || channelTag.Tag == "" {
379
+ c.JSON(http.StatusOK, gin.H{
380
+ "success": false,
381
+ "message": "参数错误",
382
+ })
383
+ return
384
+ }
385
+ err = model.EnableChannelByTag(channelTag.Tag)
386
+ if err != nil {
387
+ c.JSON(http.StatusOK, gin.H{
388
+ "success": false,
389
+ "message": err.Error(),
390
+ })
391
+ return
392
+ }
393
+ c.JSON(http.StatusOK, gin.H{
394
+ "success": true,
395
+ "message": "",
396
+ })
397
+ return
398
+ }
399
+
400
+ func EditTagChannels(c *gin.Context) {
401
+ channelTag := ChannelTag{}
402
+ err := c.ShouldBindJSON(&channelTag)
403
+ if err != nil {
404
+ c.JSON(http.StatusOK, gin.H{
405
+ "success": false,
406
+ "message": "参数错误",
407
+ })
408
+ return
409
+ }
410
+ if channelTag.Tag == "" {
411
+ c.JSON(http.StatusOK, gin.H{
412
+ "success": false,
413
+ "message": "tag不能为空",
414
+ })
415
+ return
416
+ }
417
+ err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
418
+ if err != nil {
419
+ c.JSON(http.StatusOK, gin.H{
420
+ "success": false,
421
+ "message": err.Error(),
422
+ })
423
+ return
424
+ }
425
+ c.JSON(http.StatusOK, gin.H{
426
+ "success": true,
427
+ "message": "",
428
+ })
429
+ return
430
+ }
431
+
432
+ type ChannelBatch struct {
433
+ Ids []int `json:"ids"`
434
+ Tag *string `json:"tag"`
435
+ }
436
+
437
+ func DeleteChannelBatch(c *gin.Context) {
438
+ channelBatch := ChannelBatch{}
439
+ err := c.ShouldBindJSON(&channelBatch)
440
+ if err != nil || len(channelBatch.Ids) == 0 {
441
+ c.JSON(http.StatusOK, gin.H{
442
+ "success": false,
443
+ "message": "参数错误",
444
+ })
445
+ return
446
+ }
447
+ err = model.BatchDeleteChannels(channelBatch.Ids)
448
+ if err != nil {
449
+ c.JSON(http.StatusOK, gin.H{
450
+ "success": false,
451
+ "message": err.Error(),
452
+ })
453
+ return
454
+ }
455
+ c.JSON(http.StatusOK, gin.H{
456
+ "success": true,
457
+ "message": "",
458
+ "data": len(channelBatch.Ids),
459
+ })
460
+ return
461
+ }
462
+
463
+ func UpdateChannel(c *gin.Context) {
464
+ channel := model.Channel{}
465
+ err := c.ShouldBindJSON(&channel)
466
+ if err != nil {
467
+ c.JSON(http.StatusOK, gin.H{
468
+ "success": false,
469
+ "message": err.Error(),
470
+ })
471
+ return
472
+ }
473
+ if channel.Type == common.ChannelTypeVertexAi {
474
+ if channel.Other == "" {
475
+ c.JSON(http.StatusOK, gin.H{
476
+ "success": false,
477
+ "message": "部署地区不能为空",
478
+ })
479
+ return
480
+ } else {
481
+ if common.IsJsonStr(channel.Other) {
482
+ // must have default
483
+ regionMap := common.StrToMap(channel.Other)
484
+ if regionMap["default"] == nil {
485
+ c.JSON(http.StatusOK, gin.H{
486
+ "success": false,
487
+ "message": "部署地区必须包含default字段",
488
+ })
489
+ return
490
+ }
491
+ }
492
+ }
493
+ }
494
+ err = channel.Update()
495
+ if err != nil {
496
+ c.JSON(http.StatusOK, gin.H{
497
+ "success": false,
498
+ "message": err.Error(),
499
+ })
500
+ return
501
+ }
502
+ c.JSON(http.StatusOK, gin.H{
503
+ "success": true,
504
+ "message": "",
505
+ "data": channel,
506
+ })
507
+ return
508
+ }
509
+
510
+ func FetchModels(c *gin.Context) {
511
+ var req struct {
512
+ BaseURL string `json:"base_url"`
513
+ Type int `json:"type"`
514
+ Key string `json:"key"`
515
+ }
516
+
517
+ if err := c.ShouldBindJSON(&req); err != nil {
518
+ c.JSON(http.StatusBadRequest, gin.H{
519
+ "success": false,
520
+ "message": "Invalid request",
521
+ })
522
+ return
523
+ }
524
+
525
+ baseURL := req.BaseURL
526
+ if baseURL == "" {
527
+ baseURL = common.ChannelBaseURLs[req.Type]
528
+ }
529
+
530
+ client := &http.Client{}
531
+ url := fmt.Sprintf("%s/v1/models", baseURL)
532
+
533
+ request, err := http.NewRequest("GET", url, nil)
534
+ if err != nil {
535
+ c.JSON(http.StatusInternalServerError, gin.H{
536
+ "success": false,
537
+ "message": err.Error(),
538
+ })
539
+ return
540
+ }
541
+
542
+ // remove line breaks and extra spaces.
543
+ key := strings.TrimSpace(req.Key)
544
+ // If the key contains a line break, only take the first part.
545
+ key = strings.Split(key, "\n")[0]
546
+ request.Header.Set("Authorization", "Bearer "+key)
547
+
548
+ response, err := client.Do(request)
549
+ if err != nil {
550
+ c.JSON(http.StatusInternalServerError, gin.H{
551
+ "success": false,
552
+ "message": err.Error(),
553
+ })
554
+ return
555
+ }
556
+ //check status code
557
+ if response.StatusCode != http.StatusOK {
558
+ c.JSON(http.StatusInternalServerError, gin.H{
559
+ "success": false,
560
+ "message": "Failed to fetch models",
561
+ })
562
+ return
563
+ }
564
+ defer response.Body.Close()
565
+
566
+ var result struct {
567
+ Data []struct {
568
+ ID string `json:"id"`
569
+ } `json:"data"`
570
+ }
571
+
572
+ if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
573
+ c.JSON(http.StatusInternalServerError, gin.H{
574
+ "success": false,
575
+ "message": err.Error(),
576
+ })
577
+ return
578
+ }
579
+
580
+ var models []string
581
+ for _, model := range result.Data {
582
+ models = append(models, model.ID)
583
+ }
584
+
585
+ c.JSON(http.StatusOK, gin.H{
586
+ "success": true,
587
+ "data": models,
588
+ })
589
+ }
590
+
591
+ func BatchSetChannelTag(c *gin.Context) {
592
+ channelBatch := ChannelBatch{}
593
+ err := c.ShouldBindJSON(&channelBatch)
594
+ if err != nil || len(channelBatch.Ids) == 0 {
595
+ c.JSON(http.StatusOK, gin.H{
596
+ "success": false,
597
+ "message": "参数错误",
598
+ })
599
+ return
600
+ }
601
+ err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
602
+ if err != nil {
603
+ c.JSON(http.StatusOK, gin.H{
604
+ "success": false,
605
+ "message": err.Error(),
606
+ })
607
+ return
608
+ }
609
+ c.JSON(http.StatusOK, gin.H{
610
+ "success": true,
611
+ "message": "",
612
+ "data": len(channelBatch.Ids),
613
+ })
614
+ return
615
+ }
controller/github.go ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package controller
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ "errors"
7
+ "fmt"
8
+ "github.com/gin-contrib/sessions"
9
+ "github.com/gin-gonic/gin"
10
+ "net/http"
11
+ "one-api/common"
12
+ "one-api/model"
13
+ "strconv"
14
+ "time"
15
+ )
16
+
17
+ type GitHubOAuthResponse struct {
18
+ AccessToken string `json:"access_token"`
19
+ Scope string `json:"scope"`
20
+ TokenType string `json:"token_type"`
21
+ }
22
+
23
+ type GitHubUser struct {
24
+ Login string `json:"login"`
25
+ Name string `json:"name"`
26
+ Email string `json:"email"`
27
+ }
28
+
29
+ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
30
+ if code == "" {
31
+ return nil, errors.New("无效的参数")
32
+ }
33
+ values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
34
+ jsonData, err := json.Marshal(values)
35
+ if err != nil {
36
+ return nil, err
37
+ }
38
+ req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
39
+ if err != nil {
40
+ return nil, err
41
+ }
42
+ req.Header.Set("Content-Type", "application/json")
43
+ req.Header.Set("Accept", "application/json")
44
+ client := http.Client{
45
+ Timeout: 5 * time.Second,
46
+ }
47
+ res, err := client.Do(req)
48
+ if err != nil {
49
+ common.SysLog(err.Error())
50
+ return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
51
+ }
52
+ defer res.Body.Close()
53
+ var oAuthResponse GitHubOAuthResponse
54
+ err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
55
+ if err != nil {
56
+ return nil, err
57
+ }
58
+ req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
59
+ if err != nil {
60
+ return nil, err
61
+ }
62
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
63
+ res2, err := client.Do(req)
64
+ if err != nil {
65
+ common.SysLog(err.Error())
66
+ return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
67
+ }
68
+ defer res2.Body.Close()
69
+ var githubUser GitHubUser
70
+ err = json.NewDecoder(res2.Body).Decode(&githubUser)
71
+ if err != nil {
72
+ return nil, err
73
+ }
74
+ if githubUser.Login == "" {
75
+ return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
76
+ }
77
+ return &githubUser, nil
78
+ }
79
+
80
+ func GitHubOAuth(c *gin.Context) {
81
+ session := sessions.Default(c)
82
+ state := c.Query("state")
83
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
84
+ c.JSON(http.StatusForbidden, gin.H{
85
+ "success": false,
86
+ "message": "state is empty or not same",
87
+ })
88
+ return
89
+ }
90
+ username := session.Get("username")
91
+ if username != nil {
92
+ GitHubBind(c)
93
+ return
94
+ }
95
+
96
+ if !common.GitHubOAuthEnabled {
97
+ c.JSON(http.StatusOK, gin.H{
98
+ "success": false,
99
+ "message": "管理员未开启通过 GitHub 登录以及注册",
100
+ })
101
+ return
102
+ }
103
+ code := c.Query("code")
104
+ githubUser, err := getGitHubUserInfoByCode(code)
105
+ if err != nil {
106
+ c.JSON(http.StatusOK, gin.H{
107
+ "success": false,
108
+ "message": err.Error(),
109
+ })
110
+ return
111
+ }
112
+ user := model.User{
113
+ GitHubId: githubUser.Login,
114
+ }
115
+ // IsGitHubIdAlreadyTaken is unscoped
116
+ if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
117
+ // FillUserByGitHubId is scoped
118
+ err := user.FillUserByGitHubId()
119
+ if err != nil {
120
+ c.JSON(http.StatusOK, gin.H{
121
+ "success": false,
122
+ "message": err.Error(),
123
+ })
124
+ return
125
+ }
126
+ // if user.Id == 0 , user has been deleted
127
+ if user.Id == 0 {
128
+ c.JSON(http.StatusOK, gin.H{
129
+ "success": false,
130
+ "message": "用户已注销",
131
+ })
132
+ return
133
+ }
134
+ } else {
135
+ if common.RegisterEnabled {
136
+ user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
137
+ if githubUser.Name != "" {
138
+ user.DisplayName = githubUser.Name
139
+ } else {
140
+ user.DisplayName = "GitHub User"
141
+ }
142
+ user.Email = githubUser.Email
143
+ user.Role = common.RoleCommonUser
144
+ user.Status = common.UserStatusEnabled
145
+ affCode := session.Get("aff")
146
+ inviterId := 0
147
+ if affCode != nil {
148
+ inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
149
+ }
150
+
151
+ if err := user.Insert(inviterId); err != nil {
152
+ c.JSON(http.StatusOK, gin.H{
153
+ "success": false,
154
+ "message": err.Error(),
155
+ })
156
+ return
157
+ }
158
+ } else {
159
+ c.JSON(http.StatusOK, gin.H{
160
+ "success": false,
161
+ "message": "管理员关闭了新用户注册",
162
+ })
163
+ return
164
+ }
165
+ }
166
+
167
+ if user.Status != common.UserStatusEnabled {
168
+ c.JSON(http.StatusOK, gin.H{
169
+ "message": "用户已被封禁",
170
+ "success": false,
171
+ })
172
+ return
173
+ }
174
+ setupLogin(&user, c)
175
+ }
176
+
177
+ func GitHubBind(c *gin.Context) {
178
+ if !common.GitHubOAuthEnabled {
179
+ c.JSON(http.StatusOK, gin.H{
180
+ "success": false,
181
+ "message": "管理员未开启通过 GitHub 登录以及注册",
182
+ })
183
+ return
184
+ }
185
+ code := c.Query("code")
186
+ githubUser, err := getGitHubUserInfoByCode(code)
187
+ if err != nil {
188
+ c.JSON(http.StatusOK, gin.H{
189
+ "success": false,
190
+ "message": err.Error(),
191
+ })
192
+ return
193
+ }
194
+ user := model.User{
195
+ GitHubId: githubUser.Login,
196
+ }
197
+ if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
198
+ c.JSON(http.StatusOK, gin.H{
199
+ "success": false,
200
+ "message": "该 GitHub 账户已被绑定",
201
+ })
202
+ return
203
+ }
204
+ session := sessions.Default(c)
205
+ id := session.Get("id")
206
+ // id := c.GetInt("id") // critical bug!
207
+ user.Id = id.(int)
208
+ err = user.FillUserById()
209
+ if err != nil {
210
+ c.JSON(http.StatusOK, gin.H{
211
+ "success": false,
212
+ "message": err.Error(),
213
+ })
214
+ return
215
+ }
216
+ user.GitHubId = githubUser.Login
217
+ err = user.Update(false)
218
+ if err != nil {
219
+ c.JSON(http.StatusOK, gin.H{
220
+ "success": false,
221
+ "message": err.Error(),
222
+ })
223
+ return
224
+ }
225
+ c.JSON(http.StatusOK, gin.H{
226
+ "success": true,
227
+ "message": "bind",
228
+ })
229
+ return
230
+ }
231
+
232
+ func GenerateOAuthCode(c *gin.Context) {
233
+ session := sessions.Default(c)
234
+ state := common.GetRandomString(12)
235
+ affCode := c.Query("aff")
236
+ if affCode != "" {
237
+ session.Set("aff", affCode)
238
+ }
239
+ session.Set("oauth_state", state)
240
+ err := session.Save()
241
+ if err != nil {
242
+ c.JSON(http.StatusOK, gin.H{
243
+ "success": false,
244
+ "message": err.Error(),
245
+ })
246
+ return
247
+ }
248
+ c.JSON(http.StatusOK, gin.H{
249
+ "success": true,
250
+ "message": "",
251
+ "data": state,
252
+ })
253
+ }