Spaces:
Sleeping
Sleeping
force update files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env +19 -0
- .gitattributes +0 -0
- README.md +0 -0
- go.mod +78 -0
- internal/database/db.go +31 -0
- internal/handler/account.go +774 -0
- internal/handler/anthropic.go +57 -0
- internal/handler/chat.go +48 -0
- internal/handler/external.go +209 -0
- internal/handler/gemini.go +77 -0
- internal/handler/grok.go +53 -0
- internal/handler/oauth.go +331 -0
- internal/handler/openai.go +96 -0
- internal/handler/token.go +270 -0
- internal/middleware/auth.go +126 -0
- internal/model/account.go +65 -0
- internal/model/debug.go +29 -0
- internal/model/openai.go +53 -0
- internal/model/token_record.go +48 -0
- internal/model/zenmodel.go +223 -0
- internal/service/anthropic.go +1602 -0
- internal/service/api.go +143 -0
- internal/service/autogen.go +453 -0
- internal/service/credential.go +122 -0
- internal/service/debug.go +185 -0
- internal/service/errors.go +10 -0
- internal/service/gemini.go +360 -0
- internal/service/grok.go +273 -0
- internal/service/headers.go +84 -0
- internal/service/jwt.go +74 -0
- internal/service/openai.go +868 -0
- internal/service/pool.go +766 -0
- internal/service/provider/anthropic.go +164 -0
- internal/service/provider/client.go +69 -0
- internal/service/provider/errors.go +10 -0
- internal/service/provider/factory.go +27 -0
- internal/service/provider/gemini.go +124 -0
- internal/service/provider/grok.go +149 -0
- internal/service/provider/manager.go +73 -0
- internal/service/provider/openai.go +149 -0
- internal/service/provider/provider.go +77 -0
- internal/service/provider/proxy.go +251 -0
- internal/service/proxy_client.go +143 -0
- internal/service/refresh.go +764 -0
- internal/service/request.go +3 -0
- internal/service/scheduler.go +39 -0
- internal/service/stream.go +57 -0
- internal/service/token.go +128 -0
- internal/service/zencoder.go +217 -0
- main.go +112 -0
.env
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Server Configuration
|
| 2 |
+
PORT=8080
|
| 3 |
+
DEBUG=false
|
| 4 |
+
|
| 5 |
+
# Database Configuration
|
| 6 |
+
DB_PATH=data.db
|
| 7 |
+
|
| 8 |
+
# Authentication
|
| 9 |
+
# Global authentication token for accessing the API
|
| 10 |
+
AUTH_TOKEN=your_secret_token_here
|
| 11 |
+
|
| 12 |
+
# Admin Management Password
|
| 13 |
+
# Password for accessing account management endpoints
|
| 14 |
+
ADMIN_PASSWORD=your_secret_token_here
|
| 15 |
+
|
| 16 |
+
# SOCKS Proxy Pool for API Requests Retry
|
| 17 |
+
# Comma-separated list of SOCKS5 proxies with format: socks5://host:port:username:password
|
| 18 |
+
SOCKS_PROXY_POOL=
|
| 19 |
+
|
.gitattributes
CHANGED
|
File without changes
|
README.md
CHANGED
|
File without changes
|
go.mod
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module zencoder-2api
|
| 2 |
+
|
| 3 |
+
go 1.21
|
| 4 |
+
|
| 5 |
+
require (
|
| 6 |
+
github.com/anthropics/anthropic-sdk-go v0.2.0-alpha.13
|
| 7 |
+
github.com/gin-gonic/gin v1.9.1
|
| 8 |
+
github.com/glebarez/sqlite v1.10.0
|
| 9 |
+
github.com/google/generative-ai-go v0.19.0
|
| 10 |
+
github.com/google/uuid v1.6.0
|
| 11 |
+
github.com/joho/godotenv v1.5.1
|
| 12 |
+
github.com/openai/openai-go v0.1.0-alpha.44
|
| 13 |
+
google.golang.org/api v0.214.0
|
| 14 |
+
gorm.io/gorm v1.25.7
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
require (
|
| 18 |
+
cloud.google.com/go v0.115.0 // indirect
|
| 19 |
+
cloud.google.com/go/ai v0.8.0 // indirect
|
| 20 |
+
cloud.google.com/go/auth v0.13.0 // indirect
|
| 21 |
+
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
|
| 22 |
+
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
| 23 |
+
cloud.google.com/go/longrunning v0.5.7 // indirect
|
| 24 |
+
github.com/bytedance/sonic v1.9.1 // indirect
|
| 25 |
+
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
| 26 |
+
github.com/dustin/go-humanize v1.0.1 // indirect
|
| 27 |
+
github.com/felixge/httpsnoop v1.0.4 // indirect
|
| 28 |
+
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
| 29 |
+
github.com/gin-contrib/sse v0.1.0 // indirect
|
| 30 |
+
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
| 31 |
+
github.com/go-logr/logr v1.4.2 // indirect
|
| 32 |
+
github.com/go-logr/stdr v1.2.2 // indirect
|
| 33 |
+
github.com/go-playground/locales v0.14.1 // indirect
|
| 34 |
+
github.com/go-playground/universal-translator v0.18.1 // indirect
|
| 35 |
+
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
| 36 |
+
github.com/goccy/go-json v0.10.2 // indirect
|
| 37 |
+
github.com/google/s2a-go v0.1.8 // indirect
|
| 38 |
+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
| 39 |
+
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
|
| 40 |
+
github.com/jinzhu/inflection v1.0.0 // indirect
|
| 41 |
+
github.com/jinzhu/now v1.1.5 // indirect
|
| 42 |
+
github.com/json-iterator/go v1.1.12 // indirect
|
| 43 |
+
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
| 44 |
+
github.com/leodido/go-urn v1.2.4 // indirect
|
| 45 |
+
github.com/mattn/go-isatty v0.0.19 // indirect
|
| 46 |
+
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
| 47 |
+
github.com/modern-go/reflect2 v1.0.2 // indirect
|
| 48 |
+
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
| 49 |
+
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
| 50 |
+
github.com/tidwall/gjson v1.14.4 // indirect
|
| 51 |
+
github.com/tidwall/match v1.1.1 // indirect
|
| 52 |
+
github.com/tidwall/pretty v1.2.1 // indirect
|
| 53 |
+
github.com/tidwall/sjson v1.2.5 // indirect
|
| 54 |
+
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
| 55 |
+
github.com/ugorji/go/codec v1.2.11 // indirect
|
| 56 |
+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
|
| 57 |
+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
|
| 58 |
+
go.opentelemetry.io/otel v1.29.0 // indirect
|
| 59 |
+
go.opentelemetry.io/otel/metric v1.29.0 // indirect
|
| 60 |
+
go.opentelemetry.io/otel/trace v1.29.0 // indirect
|
| 61 |
+
golang.org/x/arch v0.3.0 // indirect
|
| 62 |
+
golang.org/x/crypto v0.31.0 // indirect
|
| 63 |
+
golang.org/x/net v0.33.0 // indirect
|
| 64 |
+
golang.org/x/oauth2 v0.24.0 // indirect
|
| 65 |
+
golang.org/x/sync v0.10.0 // indirect
|
| 66 |
+
golang.org/x/sys v0.28.0 // indirect
|
| 67 |
+
golang.org/x/text v0.21.0 // indirect
|
| 68 |
+
golang.org/x/time v0.8.0 // indirect
|
| 69 |
+
google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect
|
| 70 |
+
google.golang.org/genproto/googleapis/rpc v0.0.0-20241209162323-e6fa225c2576 // indirect
|
| 71 |
+
google.golang.org/grpc v1.67.1 // indirect
|
| 72 |
+
google.golang.org/protobuf v1.35.2 // indirect
|
| 73 |
+
gopkg.in/yaml.v3 v3.0.1 // indirect
|
| 74 |
+
modernc.org/libc v1.22.5 // indirect
|
| 75 |
+
modernc.org/mathutil v1.5.0 // indirect
|
| 76 |
+
modernc.org/memory v1.5.0 // indirect
|
| 77 |
+
modernc.org/sqlite v1.23.1 // indirect
|
| 78 |
+
)
|
internal/database/db.go
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package database
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"zencoder-2api/internal/model"
|
| 5 |
+
|
| 6 |
+
"github.com/glebarez/sqlite"
|
| 7 |
+
"gorm.io/gorm"
|
| 8 |
+
"gorm.io/gorm/logger"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
var DB *gorm.DB
|
| 12 |
+
|
| 13 |
+
func Init(dbPath string) error {
|
| 14 |
+
var err error
|
| 15 |
+
DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
| 16 |
+
Logger: logger.Default.LogMode(logger.Silent), // 完全关闭日志输出
|
| 17 |
+
})
|
| 18 |
+
if err != nil {
|
| 19 |
+
return err
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
return DB.AutoMigrate(
|
| 23 |
+
&model.Account{},
|
| 24 |
+
&model.TokenRecord{},
|
| 25 |
+
&model.GenerationTask{},
|
| 26 |
+
)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
func GetDB() *gorm.DB {
|
| 30 |
+
return DB
|
| 31 |
+
}
|
internal/handler/account.go
ADDED
|
@@ -0,0 +1,774 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"log"
|
| 6 |
+
"net/http"
|
| 7 |
+
"strconv"
|
| 8 |
+
"strings"
|
| 9 |
+
"time"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"zencoder-2api/internal/database"
|
| 13 |
+
"zencoder-2api/internal/model"
|
| 14 |
+
"zencoder-2api/internal/service"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
type AccountHandler struct{}
|
| 18 |
+
|
| 19 |
+
func NewAccountHandler() *AccountHandler {
|
| 20 |
+
return &AccountHandler{}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
func (h *AccountHandler) List(c *gin.Context) {
|
| 24 |
+
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
| 25 |
+
size, _ := strconv.Atoi(c.DefaultQuery("size", "10"))
|
| 26 |
+
|
| 27 |
+
// 兼容旧的 category 参数,优先使用 status
|
| 28 |
+
status := c.DefaultQuery("status", "")
|
| 29 |
+
if status == "" {
|
| 30 |
+
category := c.DefaultQuery("category", "normal")
|
| 31 |
+
if category == "abnormal" {
|
| 32 |
+
status = "cooling"
|
| 33 |
+
} else {
|
| 34 |
+
status = category
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
if page < 1 {
|
| 39 |
+
page = 1
|
| 40 |
+
}
|
| 41 |
+
if size < 1 {
|
| 42 |
+
size = 10
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
var accounts []model.Account
|
| 46 |
+
var total int64
|
| 47 |
+
|
| 48 |
+
query := database.GetDB().Model(&model.Account{})
|
| 49 |
+
if status != "all" {
|
| 50 |
+
query = query.Where("status = ?", status)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
if err := query.Count(&total).Error; err != nil {
|
| 54 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 55 |
+
return
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
offset := (page - 1) * size
|
| 59 |
+
if err := query.Offset(offset).Limit(size).Order("id desc").Find(&accounts).Error; err != nil {
|
| 60 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 61 |
+
return
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// 调试日志:输出冷却账号的信息
|
| 65 |
+
if status == "cooling" {
|
| 66 |
+
for _, acc := range accounts {
|
| 67 |
+
if !acc.CoolingUntil.IsZero() {
|
| 68 |
+
log.Printf("[DEBUG] 冷却账号 %s (ID:%d) - CoolingUntil: %s (UTC), 现在: %s (UTC)",
|
| 69 |
+
acc.Email, acc.ID,
|
| 70 |
+
acc.CoolingUntil.Format("2006-01-02 15:04:05"),
|
| 71 |
+
time.Now().UTC().Format("2006-01-02 15:04:05"))
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Calculate Stats
|
| 77 |
+
var stats struct {
|
| 78 |
+
TotalAccounts int64 `json:"total_accounts"`
|
| 79 |
+
NormalAccounts int64 `json:"normal_accounts"` // 原 active_accounts
|
| 80 |
+
BannedAccounts int64 `json:"banned_accounts"`
|
| 81 |
+
ErrorAccounts int64 `json:"error_accounts"`
|
| 82 |
+
CoolingAccounts int64 `json:"cooling_accounts"`
|
| 83 |
+
DisabledAccounts int64 `json:"disabled_accounts"`
|
| 84 |
+
TodayUsage float64 `json:"today_usage"`
|
| 85 |
+
TotalUsage float64 `json:"total_usage"`
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
db := database.GetDB()
|
| 89 |
+
|
| 90 |
+
db.Model(&model.Account{}).Count(&stats.TotalAccounts)
|
| 91 |
+
db.Model(&model.Account{}).Where("status = ?", "normal").Count(&stats.NormalAccounts)
|
| 92 |
+
db.Model(&model.Account{}).Where("status = ?", "banned").Count(&stats.BannedAccounts)
|
| 93 |
+
db.Model(&model.Account{}).Where("status = ?", "error").Count(&stats.ErrorAccounts)
|
| 94 |
+
db.Model(&model.Account{}).Where("status = ?", "cooling").Count(&stats.CoolingAccounts)
|
| 95 |
+
db.Model(&model.Account{}).Where("status = ?", "disabled").Count(&stats.DisabledAccounts)
|
| 96 |
+
|
| 97 |
+
db.Model(&model.Account{}).Select("COALESCE(SUM(daily_used), 0)").Scan(&stats.TodayUsage)
|
| 98 |
+
db.Model(&model.Account{}).Select("COALESCE(SUM(total_used), 0)").Scan(&stats.TotalUsage)
|
| 99 |
+
|
| 100 |
+
// 兼容前端旧字段
|
| 101 |
+
statsMap := map[string]interface{}{
|
| 102 |
+
"total_accounts": stats.TotalAccounts,
|
| 103 |
+
"active_accounts": stats.NormalAccounts,
|
| 104 |
+
"banned_accounts": stats.BannedAccounts,
|
| 105 |
+
"error_accounts": stats.ErrorAccounts,
|
| 106 |
+
"cooling_accounts": stats.CoolingAccounts,
|
| 107 |
+
"disabled_accounts": stats.DisabledAccounts,
|
| 108 |
+
"today_usage": stats.TodayUsage,
|
| 109 |
+
"total_usage": stats.TotalUsage,
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
c.JSON(http.StatusOK, gin.H{
|
| 113 |
+
"items": accounts,
|
| 114 |
+
"total": total,
|
| 115 |
+
"page": page,
|
| 116 |
+
"size": size,
|
| 117 |
+
"stats": statsMap,
|
| 118 |
+
})
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
type BatchCategoryRequest struct {
|
| 122 |
+
IDs []uint `json:"ids"`
|
| 123 |
+
Category string `json:"category"` // 前端可能还传 category
|
| 124 |
+
Status string `json:"status"`
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
func (h *AccountHandler) BatchUpdateCategory(c *gin.Context) {
|
| 128 |
+
var req BatchCategoryRequest
|
| 129 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 130 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 131 |
+
return
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
if len(req.IDs) == 0 {
|
| 135 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "no ids provided"})
|
| 136 |
+
return
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
status := req.Status
|
| 140 |
+
if status == "" {
|
| 141 |
+
status = req.Category
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
updates := map[string]interface{}{
|
| 145 |
+
"status": status,
|
| 146 |
+
// 兼容旧字段
|
| 147 |
+
"category": status,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
switch status {
|
| 151 |
+
case "normal":
|
| 152 |
+
updates["is_active"] = true
|
| 153 |
+
updates["is_cooling"] = false
|
| 154 |
+
case "cooling":
|
| 155 |
+
updates["is_active"] = true // cooling 也是 active 的一种? 不,cooling 不参与轮询
|
| 156 |
+
updates["is_cooling"] = true
|
| 157 |
+
case "disabled":
|
| 158 |
+
updates["is_active"] = false
|
| 159 |
+
updates["is_cooling"] = false
|
| 160 |
+
default: // banned, error
|
| 161 |
+
updates["is_active"] = false
|
| 162 |
+
updates["is_cooling"] = false
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
if err := database.GetDB().Model(&model.Account{}).Where("id IN ?", req.IDs).Updates(updates).Error; err != nil {
|
| 166 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 167 |
+
return
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
// 触发 refresh? 为了性能这里不触发,等待自动刷新
|
| 171 |
+
c.JSON(http.StatusOK, gin.H{"message": "updated", "count": len(req.IDs)})
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
type MoveAllRequest struct {
|
| 175 |
+
FromStatus string `json:"from_status"`
|
| 176 |
+
ToStatus string `json:"to_status"`
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
// BatchMoveAll 一键移动某个分类的所有账号到另一个分类
|
| 180 |
+
func (h *AccountHandler) BatchMoveAll(c *gin.Context) {
|
| 181 |
+
var req MoveAllRequest
|
| 182 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 183 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 184 |
+
return
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
if req.FromStatus == "" || req.ToStatus == "" {
|
| 188 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "from_status and to_status are required"})
|
| 189 |
+
return
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
if req.FromStatus == req.ToStatus {
|
| 193 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "from_status and to_status cannot be the same"})
|
| 194 |
+
return
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
updates := map[string]interface{}{
|
| 198 |
+
"status": req.ToStatus,
|
| 199 |
+
// 兼容旧字段
|
| 200 |
+
"category": req.ToStatus,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
// 根据目标状态设置相应的标志
|
| 204 |
+
switch req.ToStatus {
|
| 205 |
+
case "normal":
|
| 206 |
+
updates["is_active"] = true
|
| 207 |
+
updates["is_cooling"] = false
|
| 208 |
+
updates["error_count"] = 0
|
| 209 |
+
updates["ban_reason"] = ""
|
| 210 |
+
case "cooling":
|
| 211 |
+
updates["is_active"] = false
|
| 212 |
+
updates["is_cooling"] = true
|
| 213 |
+
updates["ban_reason"] = ""
|
| 214 |
+
case "disabled":
|
| 215 |
+
updates["is_active"] = false
|
| 216 |
+
updates["is_cooling"] = false
|
| 217 |
+
updates["ban_reason"] = ""
|
| 218 |
+
case "banned":
|
| 219 |
+
updates["is_active"] = false
|
| 220 |
+
updates["is_cooling"] = false
|
| 221 |
+
case "error":
|
| 222 |
+
updates["is_active"] = false
|
| 223 |
+
updates["is_cooling"] = false
|
| 224 |
+
default:
|
| 225 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid to_status"})
|
| 226 |
+
return
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
// 执行批量更新
|
| 230 |
+
result := database.GetDB().Model(&model.Account{}).Where("status = ?", req.FromStatus).Updates(updates)
|
| 231 |
+
if result.Error != nil {
|
| 232 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
| 233 |
+
return
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
log.Printf("[批量移动] 从 %s 移动到 %s,影响 %d 个账号", req.FromStatus, req.ToStatus, result.RowsAffected)
|
| 237 |
+
|
| 238 |
+
c.JSON(http.StatusOK, gin.H{
|
| 239 |
+
"message": "moved successfully",
|
| 240 |
+
"moved_count": result.RowsAffected,
|
| 241 |
+
"from_status": req.FromStatus,
|
| 242 |
+
"to_status": req.ToStatus,
|
| 243 |
+
})
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
type BatchRefreshTokenRequest struct {
|
| 247 |
+
IDs []uint `json:"ids"` // 选中的账号IDs,如果为空则刷新所有账号
|
| 248 |
+
All bool `json:"all"` // 是否刷新所有账号
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
type BatchDeleteRequest struct {
|
| 252 |
+
IDs []uint `json:"ids"` // 选中的账号IDs
|
| 253 |
+
DeleteAll bool `json:"delete_all"` // 是否删除分类中的所有账号
|
| 254 |
+
Status string `json:"status"` // 要删除的分类状态
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// BatchRefreshToken 批量刷新账号token
|
| 258 |
+
func (h *AccountHandler) BatchRefreshToken(c *gin.Context) {
|
| 259 |
+
var req BatchRefreshTokenRequest
|
| 260 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 261 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 262 |
+
return
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// 设置流式响应头
|
| 266 |
+
c.Header("Content-Type", "text/event-stream")
|
| 267 |
+
c.Header("Cache-Control", "no-cache")
|
| 268 |
+
c.Header("Connection", "keep-alive")
|
| 269 |
+
c.Header("X-Accel-Buffering", "no")
|
| 270 |
+
|
| 271 |
+
flusher, ok := c.Writer.(http.Flusher)
|
| 272 |
+
if !ok {
|
| 273 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "流式传输不支持"})
|
| 274 |
+
return
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
var accounts []model.Account
|
| 278 |
+
var err error
|
| 279 |
+
|
| 280 |
+
// 根据请求类型获取要刷新的账号
|
| 281 |
+
if req.All {
|
| 282 |
+
// 刷新所有状态为normal且有refresh_token的账号
|
| 283 |
+
err = database.GetDB().Where("status = ? AND (client_id != '' AND client_secret != '')", "normal").Find(&accounts).Error
|
| 284 |
+
if err != nil {
|
| 285 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取账号列表失败: " + err.Error()})
|
| 286 |
+
return
|
| 287 |
+
}
|
| 288 |
+
log.Printf("[批量刷新Token] 准备刷新所有正常账号,共 %d 个", len(accounts))
|
| 289 |
+
} else if len(req.IDs) > 0 {
|
| 290 |
+
// 刷新选中的账号
|
| 291 |
+
err = database.GetDB().Where("id IN ? AND (client_id != '' AND client_secret != '')", req.IDs).Find(&accounts).Error
|
| 292 |
+
if err != nil {
|
| 293 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取选中账号失败: " + err.Error()})
|
| 294 |
+
return
|
| 295 |
+
}
|
| 296 |
+
log.Printf("[批量刷新Token] 准备刷新选中账号,共 %d 个", len(accounts))
|
| 297 |
+
} else {
|
| 298 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "请选择要刷新的账号或选择刷新所有账号"})
|
| 299 |
+
return
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
if len(accounts) == 0 {
|
| 303 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "没有找到可刷新的账号"})
|
| 304 |
+
return
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
// 发送开始消息
|
| 308 |
+
fmt.Fprintf(c.Writer, "data: {\"type\":\"start\",\"total\":%d}\n\n", len(accounts))
|
| 309 |
+
flusher.Flush()
|
| 310 |
+
|
| 311 |
+
successCount := 0
|
| 312 |
+
failCount := 0
|
| 313 |
+
|
| 314 |
+
// 逐个刷新token
|
| 315 |
+
for i, account := range accounts {
|
| 316 |
+
log.Printf("[批量刷新Token] 开始刷新第 %d/%d 个账号: %s (ID:%d)", i+1, len(accounts), account.ClientID, account.ID)
|
| 317 |
+
|
| 318 |
+
// 使用OAuth client credentials刷新token
|
| 319 |
+
if err := service.RefreshAccountToken(&account); err != nil {
|
| 320 |
+
failCount++
|
| 321 |
+
errMsg := fmt.Sprintf("刷新失败: %v", err)
|
| 322 |
+
|
| 323 |
+
// 检查是否是账号锁定错误
|
| 324 |
+
if lockoutErr, ok := err.(*service.AccountLockoutError); ok {
|
| 325 |
+
errMsg = fmt.Sprintf("账号被锁定已自动标记为封禁: %s", lockoutErr.Body)
|
| 326 |
+
log.Printf("[批量刷新Token] 第 %d/%d 个账号被锁定: %s - %s", i+1, len(accounts), account.ClientID, lockoutErr.Body)
|
| 327 |
+
} else {
|
| 328 |
+
log.Printf("[批量刷新Token] 第 %d/%d 个账号刷新失败: %s - %v", i+1, len(accounts), account.ClientID, err)
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
fmt.Fprintf(c.Writer, "data: {\"type\":\"error\",\"index\":%d,\"account_id\":\"%s\",\"message\":\"%s\"}\n\n", i+1, account.ClientID, errMsg)
|
| 332 |
+
flusher.Flush()
|
| 333 |
+
} else {
|
| 334 |
+
successCount++
|
| 335 |
+
log.Printf("[批量刷新Token] 第 %d/%d 个账号刷新成功: %s (ID:%d)", i+1, len(accounts), account.ClientID, account.ID)
|
| 336 |
+
fmt.Fprintf(c.Writer, "data: {\"type\":\"success\",\"index\":%d,\"account_id\":\"%s\",\"email\":\"%s\"}\n\n", i+1, account.ClientID, account.Email)
|
| 337 |
+
flusher.Flush()
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
// 添加延迟避免请求过快
|
| 341 |
+
if i < len(accounts)-1 {
|
| 342 |
+
time.Sleep(200 * time.Millisecond)
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
// 发送完成消息
|
| 347 |
+
log.Printf("[批量刷新Token] 完成: 成功 %d 个, 失败 %d 个", successCount, failCount)
|
| 348 |
+
fmt.Fprintf(c.Writer, "data: {\"type\":\"complete\",\"success\":%d,\"fail\":%d}\n\n", successCount, failCount)
|
| 349 |
+
flusher.Flush()
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
func (h *AccountHandler) Create(c *gin.Context) {
|
| 353 |
+
var req model.AccountRequest
|
| 354 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 355 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 356 |
+
return
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
// 生成模式 - 固定生成1个账号
|
| 360 |
+
if req.GenerateMode {
|
| 361 |
+
// 检查是否提供了 refresh_token 或 access_token
|
| 362 |
+
if req.RefreshToken == "" && req.Token == "" {
|
| 363 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "生成模式需要提供 access_token 或 RefreshToken"})
|
| 364 |
+
return
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
// 优先使用 access_token,如果同时提供了两个字段
|
| 368 |
+
var masterToken string
|
| 369 |
+
|
| 370 |
+
if req.Token != "" {
|
| 371 |
+
// 直接使用提供的 access_token
|
| 372 |
+
masterToken = req.Token
|
| 373 |
+
log.Printf("[生成凭证] 使用提供的 access_token")
|
| 374 |
+
} else {
|
| 375 |
+
// 使用 refresh_token 获取 access_token
|
| 376 |
+
tokenResp, err := service.RefreshAccessToken(req.RefreshToken, req.Proxy)
|
| 377 |
+
if err != nil {
|
| 378 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "RefreshToken 无效: " + err.Error()})
|
| 379 |
+
return
|
| 380 |
+
}
|
| 381 |
+
masterToken = tokenResp.AccessToken
|
| 382 |
+
log.Printf("[生成凭证] 通过 RefreshToken 获取了 access_token")
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
log.Printf("[生成凭证] 开始生成账号凭证")
|
| 386 |
+
|
| 387 |
+
// 生成凭证
|
| 388 |
+
cred, err := service.GenerateCredential(masterToken)
|
| 389 |
+
if err != nil {
|
| 390 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("生成失败: %v", err)})
|
| 391 |
+
return
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
log.Printf("[生成凭证] 凭证生成成功: ClientID=%s", cred.ClientID)
|
| 395 |
+
|
| 396 |
+
// 创建账号
|
| 397 |
+
account := model.Account{
|
| 398 |
+
ClientID: cred.ClientID,
|
| 399 |
+
ClientSecret: cred.Secret,
|
| 400 |
+
Proxy: req.Proxy,
|
| 401 |
+
IsActive: true,
|
| 402 |
+
Status: "normal",
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
// 使用生成的client_id和client_secret获取token
|
| 406 |
+
// 使用OAuth client credentials方式刷新token,使用 https://fe.zencoder.ai/oauth/token
|
| 407 |
+
if _, err := service.RefreshToken(&account); err != nil {
|
| 408 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("认证失败: %v", err)})
|
| 409 |
+
return
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// 解析 Token 获取详细信息
|
| 413 |
+
if payload, err := service.ParseJWT(account.AccessToken); err == nil {
|
| 414 |
+
account.Email = payload.Email
|
| 415 |
+
account.SubscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 416 |
+
|
| 417 |
+
if payload.Expiration > 0 {
|
| 418 |
+
account.TokenExpiry = time.Unix(payload.Expiration, 0)
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
plan := payload.CustomClaims.Plan
|
| 422 |
+
if plan != "" {
|
| 423 |
+
plan = strings.ToUpper(plan[:1]) + plan[1:]
|
| 424 |
+
}
|
| 425 |
+
if plan != "" {
|
| 426 |
+
account.PlanType = model.PlanType(plan)
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
if account.PlanType == "" {
|
| 430 |
+
account.PlanType = model.PlanFree
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
// 检查是否已存在
|
| 434 |
+
var existing model.Account
|
| 435 |
+
var count int64
|
| 436 |
+
database.GetDB().Model(&model.Account{}).Where("client_id = ?", account.ClientID).Count(&count)
|
| 437 |
+
if count > 0 {
|
| 438 |
+
// 获取现有账号
|
| 439 |
+
database.GetDB().Where("client_id = ?", account.ClientID).First(&existing)
|
| 440 |
+
// 更新现有账号
|
| 441 |
+
existing.AccessToken = account.AccessToken
|
| 442 |
+
existing.TokenExpiry = account.TokenExpiry
|
| 443 |
+
existing.PlanType = account.PlanType
|
| 444 |
+
existing.Email = account.Email
|
| 445 |
+
existing.SubscriptionStartDate = account.SubscriptionStartDate
|
| 446 |
+
existing.IsActive = true
|
| 447 |
+
existing.Status = "normal" // 重新激活
|
| 448 |
+
existing.ClientSecret = account.ClientSecret
|
| 449 |
+
if account.Proxy != "" {
|
| 450 |
+
existing.Proxy = account.Proxy
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
if err := database.GetDB().Save(&existing).Error; err != nil {
|
| 454 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("更新失败: %v", err)})
|
| 455 |
+
return
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
log.Printf("[添加账号] 账号更新成功: ClientID=%s, Email=%s, Plan=%s", existing.ClientID, existing.Email, existing.PlanType)
|
| 459 |
+
c.JSON(http.StatusOK, existing)
|
| 460 |
+
} else {
|
| 461 |
+
// 创建新账号
|
| 462 |
+
if err := database.GetDB().Create(&account).Error; err != nil {
|
| 463 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("创建失败: %v", err)})
|
| 464 |
+
return
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
log.Printf("[添加账号] 新账号创建成功: ClientID=%s, Email=%s, Plan=%s", account.ClientID, account.Email, account.PlanType)
|
| 468 |
+
c.JSON(http.StatusCreated, account)
|
| 469 |
+
}
|
| 470 |
+
return
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
// 原有的单个账号添加逻辑 - 现在使用 refresh_token
|
| 474 |
+
account := model.Account{
|
| 475 |
+
Proxy: req.Proxy,
|
| 476 |
+
IsActive: true,
|
| 477 |
+
Status: "normal",
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
// 优先使用 access_token,如果同时提供了两个字段则不使用 refresh_token
|
| 481 |
+
if req.Token != "" && req.RefreshToken != "" {
|
| 482 |
+
log.Printf("[凭证模式] 同时提供了 access_token 和 RefreshToken,优先使用 access_token")
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
if req.Token != "" {
|
| 486 |
+
// JWT Parsing Logic
|
| 487 |
+
payload, err := service.ParseJWT(req.Token)
|
| 488 |
+
if err != nil {
|
| 489 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的Token: " + err.Error()})
|
| 490 |
+
return
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
account.AccessToken = req.Token
|
| 494 |
+
// 优先使用ClientID字段,如果没有则使用Subject
|
| 495 |
+
if payload.ClientID != "" {
|
| 496 |
+
account.ClientID = payload.ClientID
|
| 497 |
+
} else {
|
| 498 |
+
account.ClientID = payload.Subject
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
account.Email = payload.Email
|
| 502 |
+
account.SubscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 503 |
+
|
| 504 |
+
if payload.Expiration > 0 {
|
| 505 |
+
account.TokenExpiry = time.Unix(payload.Expiration, 0)
|
| 506 |
+
} else {
|
| 507 |
+
account.TokenExpiry = time.Now().Add(24 * time.Hour) // 默认24小时
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
// Map PlanType
|
| 511 |
+
plan := payload.CustomClaims.Plan
|
| 512 |
+
|
| 513 |
+
// Simple normalization
|
| 514 |
+
if plan != "" {
|
| 515 |
+
plan = strings.ToUpper(plan[:1]) + plan[1:]
|
| 516 |
+
}
|
| 517 |
+
account.PlanType = model.PlanType(plan)
|
| 518 |
+
if account.PlanType == "" {
|
| 519 |
+
account.PlanType = model.PlanFree
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
// Placeholder for secret since it's required by DB but not in JWT
|
| 523 |
+
account.ClientSecret = "jwt-login"
|
| 524 |
+
} else if req.RefreshToken != "" {
|
| 525 |
+
// 只提供了 refresh_token,使用它来获取 access_token
|
| 526 |
+
tokenResp, err := service.RefreshAccessToken(req.RefreshToken, req.Proxy)
|
| 527 |
+
if err != nil {
|
| 528 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "RefreshToken 无效: " + err.Error()})
|
| 529 |
+
return
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
account.AccessToken = tokenResp.AccessToken
|
| 533 |
+
account.RefreshToken = tokenResp.RefreshToken
|
| 534 |
+
account.TokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 535 |
+
|
| 536 |
+
// 解析 Token 获取详细信息
|
| 537 |
+
if payload, err := service.ParseJWT(tokenResp.AccessToken); err == nil {
|
| 538 |
+
// 设置 Email
|
| 539 |
+
if payload.Email != "" {
|
| 540 |
+
account.Email = payload.Email
|
| 541 |
+
} else if tokenResp.Email != "" {
|
| 542 |
+
account.Email = tokenResp.Email
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
// 设置 ClientID - 优先使用 Email 作为唯一标识符
|
| 546 |
+
if payload.Email != "" {
|
| 547 |
+
account.ClientID = payload.Email
|
| 548 |
+
} else if payload.Subject != "" {
|
| 549 |
+
account.ClientID = payload.Subject
|
| 550 |
+
} else if payload.ClientID != "" {
|
| 551 |
+
account.ClientID = payload.ClientID
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
account.SubscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 555 |
+
|
| 556 |
+
// Map PlanType
|
| 557 |
+
plan := payload.CustomClaims.Plan
|
| 558 |
+
if plan != "" {
|
| 559 |
+
plan = strings.ToUpper(plan[:1]) + plan[1:]
|
| 560 |
+
}
|
| 561 |
+
account.PlanType = model.PlanType(plan)
|
| 562 |
+
if account.PlanType == "" {
|
| 563 |
+
account.PlanType = model.PlanFree
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
log.Printf("[凭证模式-RefreshToken] 解析JWT成功: ClientID=%s, Email=%s, Plan=%s",
|
| 567 |
+
account.ClientID, account.Email, account.PlanType)
|
| 568 |
+
} else {
|
| 569 |
+
log.Printf("[凭证模式-RefreshToken] 解析JWT失败: %v", err)
|
| 570 |
+
// 如果JWT解析失败,使用 tokenResp 中的信息
|
| 571 |
+
if tokenResp.UserID != "" {
|
| 572 |
+
account.ClientID = tokenResp.UserID
|
| 573 |
+
account.Email = tokenResp.UserID
|
| 574 |
+
}
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
// 生成一个占位 ClientSecret
|
| 578 |
+
account.ClientSecret = "refresh-token-login"
|
| 579 |
+
|
| 580 |
+
// 确保 ClientID 不为空
|
| 581 |
+
if account.ClientID == "" {
|
| 582 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "无法获取用户信息,请检查RefreshToken是否有效"})
|
| 583 |
+
return
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
} else {
|
| 587 |
+
// Old Logic
|
| 588 |
+
if req.PlanType == "" {
|
| 589 |
+
req.PlanType = model.PlanFree
|
| 590 |
+
}
|
| 591 |
+
account.ClientID = req.ClientID
|
| 592 |
+
account.ClientSecret = req.ClientSecret
|
| 593 |
+
account.Email = req.Email
|
| 594 |
+
account.PlanType = req.PlanType
|
| 595 |
+
|
| 596 |
+
// 验证Token是否能正确获取
|
| 597 |
+
if _, err := service.RefreshToken(&account); err != nil {
|
| 598 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "认证失败: " + err.Error()})
|
| 599 |
+
return
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
// 解析Token获取详细信息
|
| 603 |
+
if payload, err := service.ParseJWT(account.AccessToken); err == nil {
|
| 604 |
+
if account.Email == "" {
|
| 605 |
+
account.Email = payload.Email
|
| 606 |
+
}
|
| 607 |
+
account.SubscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 608 |
+
|
| 609 |
+
if payload.Expiration > 0 {
|
| 610 |
+
account.TokenExpiry = time.Unix(payload.Expiration, 0)
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
plan := payload.CustomClaims.Plan
|
| 614 |
+
if plan != "" {
|
| 615 |
+
plan = strings.ToUpper(plan[:1]) + plan[1:]
|
| 616 |
+
}
|
| 617 |
+
if plan != "" {
|
| 618 |
+
account.PlanType = model.PlanType(plan)
|
| 619 |
+
}
|
| 620 |
+
}
|
| 621 |
+
if account.PlanType == "" {
|
| 622 |
+
account.PlanType = model.PlanFree
|
| 623 |
+
}
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
// Check if account exists - 使用 Count 避免 record not found 警告
|
| 627 |
+
var existing model.Account
|
| 628 |
+
var count int64
|
| 629 |
+
database.GetDB().Model(&model.Account{}).Where("client_id = ?", account.ClientID).Count(&count)
|
| 630 |
+
if count > 0 {
|
| 631 |
+
// 获取现有账号
|
| 632 |
+
database.GetDB().Where("client_id = ?", account.ClientID).First(&existing)
|
| 633 |
+
// Update existing
|
| 634 |
+
existing.AccessToken = account.AccessToken
|
| 635 |
+
existing.RefreshToken = account.RefreshToken // 更新 refresh_token
|
| 636 |
+
existing.TokenExpiry = account.TokenExpiry
|
| 637 |
+
existing.PlanType = account.PlanType
|
| 638 |
+
existing.Email = account.Email
|
| 639 |
+
existing.SubscriptionStartDate = account.SubscriptionStartDate
|
| 640 |
+
existing.IsActive = true
|
| 641 |
+
existing.Status = "normal"
|
| 642 |
+
if account.Proxy != "" {
|
| 643 |
+
existing.Proxy = account.Proxy
|
| 644 |
+
}
|
| 645 |
+
// If secret was provided manually, update it. If placeholder, keep existing.
|
| 646 |
+
if account.ClientSecret != "jwt-login" && account.ClientSecret != "refresh-token-login" {
|
| 647 |
+
existing.ClientSecret = account.ClientSecret
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
if err := database.GetDB().Save(&existing).Error; err != nil {
|
| 651 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 652 |
+
return
|
| 653 |
+
}
|
| 654 |
+
c.JSON(http.StatusOK, existing)
|
| 655 |
+
return
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
if err := database.GetDB().Create(&account).Error; err != nil {
|
| 659 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 660 |
+
return
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
c.JSON(http.StatusCreated, account)
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
func (h *AccountHandler) Update(c *gin.Context) {
|
| 667 |
+
id := c.Param("id")
|
| 668 |
+
var account model.Account
|
| 669 |
+
if err := database.GetDB().First(&account, id).Error; err != nil {
|
| 670 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "account not found"})
|
| 671 |
+
return
|
| 672 |
+
}
|
| 673 |
+
|
| 674 |
+
var req model.AccountRequest
|
| 675 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 676 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 677 |
+
return
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
account.Email = req.Email
|
| 681 |
+
account.PlanType = req.PlanType
|
| 682 |
+
account.Proxy = req.Proxy
|
| 683 |
+
|
| 684 |
+
if err := database.GetDB().Save(&account).Error; err != nil {
|
| 685 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 686 |
+
return
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
c.JSON(http.StatusOK, account)
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
// BatchDelete 批量删除账号
|
| 693 |
+
func (h *AccountHandler) BatchDelete(c *gin.Context) {
|
| 694 |
+
var req BatchDeleteRequest
|
| 695 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 696 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 697 |
+
return
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
var deletedCount int64
|
| 701 |
+
|
| 702 |
+
if req.DeleteAll {
|
| 703 |
+
// 删除指定分类的所有账号
|
| 704 |
+
if req.Status == "" {
|
| 705 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "delete_all模式需要指定status"})
|
| 706 |
+
return
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
// 执行删除操作
|
| 710 |
+
result := database.GetDB().Where("status = ?", req.Status).Delete(&model.Account{})
|
| 711 |
+
if result.Error != nil {
|
| 712 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
| 713 |
+
return
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
deletedCount = result.RowsAffected
|
| 717 |
+
log.Printf("[批量删除] 删除分类 %s 的所有账号,共删除 %d 个", req.Status, deletedCount)
|
| 718 |
+
|
| 719 |
+
} else {
|
| 720 |
+
// 删除选中的账号
|
| 721 |
+
if len(req.IDs) == 0 {
|
| 722 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "未选择要删除的账号"})
|
| 723 |
+
return
|
| 724 |
+
}
|
| 725 |
+
|
| 726 |
+
// 执行删除操作
|
| 727 |
+
result := database.GetDB().Where("id IN ?", req.IDs).Delete(&model.Account{})
|
| 728 |
+
if result.Error != nil {
|
| 729 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
| 730 |
+
return
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
deletedCount = result.RowsAffected
|
| 734 |
+
log.Printf("[批量删除] 删除选中的 %d 个账号,实际删除 %d 个", len(req.IDs), deletedCount)
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
c.JSON(http.StatusOK, gin.H{
|
| 738 |
+
"message": "批量删除成功",
|
| 739 |
+
"deleted_count": deletedCount,
|
| 740 |
+
})
|
| 741 |
+
}
|
| 742 |
+
|
| 743 |
+
func (h *AccountHandler) Delete(c *gin.Context) {
|
| 744 |
+
id := c.Param("id")
|
| 745 |
+
if err := database.GetDB().Delete(&model.Account{}, id).Error; err != nil {
|
| 746 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 747 |
+
return
|
| 748 |
+
}
|
| 749 |
+
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
func (h *AccountHandler) Toggle(c *gin.Context) {
|
| 753 |
+
id := c.Param("id")
|
| 754 |
+
var account model.Account
|
| 755 |
+
if err := database.GetDB().First(&account, id).Error; err != nil {
|
| 756 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "account not found"})
|
| 757 |
+
return
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
// 切换 Disabled / Normal
|
| 761 |
+
if account.Status == "disabled" || !account.IsActive {
|
| 762 |
+
account.Status = "normal"
|
| 763 |
+
account.IsActive = true
|
| 764 |
+
account.IsCooling = false
|
| 765 |
+
account.ErrorCount = 0
|
| 766 |
+
} else {
|
| 767 |
+
account.Status = "disabled"
|
| 768 |
+
account.IsActive = false
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
database.GetDB().Save(&account)
|
| 772 |
+
|
| 773 |
+
c.JSON(http.StatusOK, account)
|
| 774 |
+
}
|
internal/handler/anthropic.go
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"crypto/rand"
|
| 6 |
+
"encoding/hex"
|
| 7 |
+
"errors"
|
| 8 |
+
"fmt"
|
| 9 |
+
"io"
|
| 10 |
+
"net/http"
|
| 11 |
+
|
| 12 |
+
"zencoder-2api/internal/service"
|
| 13 |
+
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
type AnthropicHandler struct {
|
| 18 |
+
svc *service.AnthropicService
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
func NewAnthropicHandler() *AnthropicHandler {
|
| 22 |
+
return &AnthropicHandler{svc: service.NewAnthropicService()}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// generateTraceID 生成一个随机的 trace ID
|
| 26 |
+
func generateAnthropicTraceID() string {
|
| 27 |
+
b := make([]byte, 16)
|
| 28 |
+
rand.Read(b)
|
| 29 |
+
return hex.EncodeToString(b)
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// Messages 处理 POST /v1/messages
|
| 33 |
+
func (h *AnthropicHandler) Messages(c *gin.Context) {
|
| 34 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 35 |
+
if err != nil {
|
| 36 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 37 |
+
return
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// 传递原始请求头给service层,用于错误日志记录
|
| 41 |
+
ctx := context.WithValue(c.Request.Context(), "originalHeaders", c.Request.Header)
|
| 42 |
+
|
| 43 |
+
if err := h.svc.MessagesProxy(ctx, c.Writer, body); err != nil {
|
| 44 |
+
h.handleError(c, err)
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// handleError 统一处理错误,特别是没有可用账号的错误
|
| 49 |
+
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
| 50 |
+
if errors.Is(err, service.ErrNoAvailableAccount) || errors.Is(err, service.ErrNoPermission) {
|
| 51 |
+
traceID := generateAnthropicTraceID()
|
| 52 |
+
errMsg := fmt.Sprintf("没有可用token(traceid: %s)", traceID)
|
| 53 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": errMsg})
|
| 54 |
+
return
|
| 55 |
+
}
|
| 56 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 57 |
+
}
|
internal/handler/chat.go
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
|
| 6 |
+
"github.com/gin-gonic/gin"
|
| 7 |
+
"zencoder-2api/internal/model"
|
| 8 |
+
"zencoder-2api/internal/service"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
type ChatHandler struct {
|
| 12 |
+
svc *service.APIService
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
func NewChatHandler() *ChatHandler {
|
| 16 |
+
return &ChatHandler{svc: service.NewAPIService()}
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
func (h *ChatHandler) ChatCompletions(c *gin.Context) {
|
| 20 |
+
var req model.ChatCompletionRequest
|
| 21 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 22 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 23 |
+
return
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
if req.Stream {
|
| 27 |
+
h.handleStream(c, &req)
|
| 28 |
+
return
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
resp, err := h.svc.Chat(&req)
|
| 32 |
+
if err != nil {
|
| 33 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 34 |
+
return
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
c.JSON(http.StatusOK, resp)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
func (h *ChatHandler) handleStream(c *gin.Context, req *model.ChatCompletionRequest) {
|
| 41 |
+
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
| 42 |
+
c.Writer.Header().Set("Cache-Control", "no-cache")
|
| 43 |
+
c.Writer.Header().Set("Connection", "keep-alive")
|
| 44 |
+
|
| 45 |
+
if err := h.svc.ChatStream(req, c.Writer); err != nil {
|
| 46 |
+
c.SSEvent("error", err.Error())
|
| 47 |
+
}
|
| 48 |
+
}
|
internal/handler/external.go
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"log"
|
| 6 |
+
"net/http"
|
| 7 |
+
"strings"
|
| 8 |
+
"time"
|
| 9 |
+
|
| 10 |
+
"github.com/gin-gonic/gin"
|
| 11 |
+
"zencoder-2api/internal/database"
|
| 12 |
+
"zencoder-2api/internal/model"
|
| 13 |
+
"zencoder-2api/internal/service"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
type ExternalHandler struct{}
|
| 17 |
+
|
| 18 |
+
func NewExternalHandler() *ExternalHandler {
|
| 19 |
+
return &ExternalHandler{}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
// ExternalTokenRequest 外部API提交token请求结构
|
| 23 |
+
type ExternalTokenRequest struct {
|
| 24 |
+
AccessToken string `json:"access_token"` // OAuth获取的access_token
|
| 25 |
+
RefreshToken string `json:"refresh_token"` // OAuth获取的refresh_token
|
| 26 |
+
Proxy string `json:"proxy"` // 可选的代理设置
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// ExternalTokenResponse 外部API响应结构
|
| 30 |
+
type ExternalTokenResponse struct {
|
| 31 |
+
Success bool `json:"success"`
|
| 32 |
+
Message string `json:"message"`
|
| 33 |
+
Account *model.Account `json:"account,omitempty"`
|
| 34 |
+
Error string `json:"error,omitempty"`
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// SubmitTokens 外部API接口:接收OAuth token信息并生成账号记录
|
| 38 |
+
func (h *ExternalHandler) SubmitTokens(c *gin.Context) {
|
| 39 |
+
var req ExternalTokenRequest
|
| 40 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 41 |
+
c.JSON(http.StatusBadRequest, ExternalTokenResponse{
|
| 42 |
+
Success: false,
|
| 43 |
+
Error: "请求格式错误: " + err.Error(),
|
| 44 |
+
})
|
| 45 |
+
return
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// 验证必要字段
|
| 49 |
+
if req.AccessToken == "" && req.RefreshToken == "" {
|
| 50 |
+
c.JSON(http.StatusBadRequest, ExternalTokenResponse{
|
| 51 |
+
Success: false,
|
| 52 |
+
Error: "必须提供 access_token 或 refresh_token",
|
| 53 |
+
})
|
| 54 |
+
return
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
log.Printf("[外部API] 收到token提交请求,access_token长度: %d, refresh_token长度: %d",
|
| 58 |
+
len(req.AccessToken), len(req.RefreshToken))
|
| 59 |
+
|
| 60 |
+
// 优先使用 access_token,如果同时提供了两个字段
|
| 61 |
+
var masterToken string
|
| 62 |
+
|
| 63 |
+
if req.AccessToken != "" {
|
| 64 |
+
// 直接使用提供的 access_token
|
| 65 |
+
masterToken = req.AccessToken
|
| 66 |
+
log.Printf("[外部API] 使用提供的 access_token")
|
| 67 |
+
} else {
|
| 68 |
+
// 使用 refresh_token 获取 access_token
|
| 69 |
+
tokenResp, err := service.RefreshAccessToken(req.RefreshToken, req.Proxy)
|
| 70 |
+
if err != nil {
|
| 71 |
+
c.JSON(http.StatusBadRequest, ExternalTokenResponse{
|
| 72 |
+
Success: false,
|
| 73 |
+
Error: "RefreshToken 无效: " + err.Error(),
|
| 74 |
+
})
|
| 75 |
+
return
|
| 76 |
+
}
|
| 77 |
+
masterToken = tokenResp.AccessToken
|
| 78 |
+
log.Printf("[外部API] 通过 RefreshToken 获取了 access_token")
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
log.Printf("[外部API] 开始生成账号凭证")
|
| 82 |
+
|
| 83 |
+
// 生成凭证
|
| 84 |
+
cred, err := service.GenerateCredential(masterToken)
|
| 85 |
+
if err != nil {
|
| 86 |
+
c.JSON(http.StatusInternalServerError, ExternalTokenResponse{
|
| 87 |
+
Success: false,
|
| 88 |
+
Error: fmt.Sprintf("生成失败: %v", err),
|
| 89 |
+
})
|
| 90 |
+
return
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
log.Printf("[外部API] 凭证生成成功: ClientID=%s", cred.ClientID)
|
| 94 |
+
|
| 95 |
+
// 创建账号
|
| 96 |
+
account := model.Account{
|
| 97 |
+
ClientID: cred.ClientID,
|
| 98 |
+
ClientSecret: cred.Secret,
|
| 99 |
+
Proxy: req.Proxy,
|
| 100 |
+
IsActive: true,
|
| 101 |
+
Status: "normal",
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// 使用生成的client_id和client_secret获取token,带重试机制
|
| 105 |
+
// 使用OAuth client credentials方式刷新token,使用 https://fe.zencoder.ai/oauth/token
|
| 106 |
+
maxRetries := 3
|
| 107 |
+
retryDelay := 2 * time.Second
|
| 108 |
+
var lastErr error
|
| 109 |
+
|
| 110 |
+
for attempt := 1; attempt <= maxRetries; attempt++ {
|
| 111 |
+
log.Printf("[外部API] 尝试获取token,第 %d/%d 次", attempt, maxRetries)
|
| 112 |
+
|
| 113 |
+
if _, err := service.RefreshToken(&account); err != nil {
|
| 114 |
+
lastErr = err
|
| 115 |
+
log.Printf("[外部API] 第 %d 次获取token失败: %v", attempt, err)
|
| 116 |
+
|
| 117 |
+
if attempt < maxRetries {
|
| 118 |
+
log.Printf("[外部API] 等待 %v 后重试", retryDelay)
|
| 119 |
+
time.Sleep(retryDelay)
|
| 120 |
+
continue
|
| 121 |
+
}
|
| 122 |
+
} else {
|
| 123 |
+
log.Printf("[外部API] 第 %d 次获取token成功", attempt)
|
| 124 |
+
lastErr = nil
|
| 125 |
+
break
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
if lastErr != nil {
|
| 130 |
+
c.JSON(http.StatusBadRequest, ExternalTokenResponse{
|
| 131 |
+
Success: false,
|
| 132 |
+
Error: fmt.Sprintf("认证失败(重试 %d 次后): %v", maxRetries, lastErr),
|
| 133 |
+
})
|
| 134 |
+
return
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// 解析 Token 获取详细信息
|
| 138 |
+
if payload, err := service.ParseJWT(account.AccessToken); err == nil {
|
| 139 |
+
account.Email = payload.Email
|
| 140 |
+
account.SubscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 141 |
+
|
| 142 |
+
if payload.Expiration > 0 {
|
| 143 |
+
account.TokenExpiry = time.Unix(payload.Expiration, 0)
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
plan := payload.CustomClaims.Plan
|
| 147 |
+
if plan != "" {
|
| 148 |
+
plan = strings.ToUpper(plan[:1]) + plan[1:]
|
| 149 |
+
}
|
| 150 |
+
if plan != "" {
|
| 151 |
+
account.PlanType = model.PlanType(plan)
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
if account.PlanType == "" {
|
| 155 |
+
account.PlanType = model.PlanFree
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
// 检查是否已存在
|
| 159 |
+
var existing model.Account
|
| 160 |
+
var count int64
|
| 161 |
+
database.GetDB().Model(&model.Account{}).Where("client_id = ?", account.ClientID).Count(&count)
|
| 162 |
+
if count > 0 {
|
| 163 |
+
// 获取现有账号
|
| 164 |
+
database.GetDB().Where("client_id = ?", account.ClientID).First(&existing)
|
| 165 |
+
// 更新现有账号
|
| 166 |
+
existing.AccessToken = account.AccessToken
|
| 167 |
+
existing.TokenExpiry = account.TokenExpiry
|
| 168 |
+
existing.PlanType = account.PlanType
|
| 169 |
+
existing.Email = account.Email
|
| 170 |
+
existing.SubscriptionStartDate = account.SubscriptionStartDate
|
| 171 |
+
existing.IsActive = true
|
| 172 |
+
existing.Status = "normal" // 重新激活
|
| 173 |
+
existing.ClientSecret = account.ClientSecret
|
| 174 |
+
if account.Proxy != "" {
|
| 175 |
+
existing.Proxy = account.Proxy
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if err := database.GetDB().Save(&existing).Error; err != nil {
|
| 179 |
+
c.JSON(http.StatusInternalServerError, ExternalTokenResponse{
|
| 180 |
+
Success: false,
|
| 181 |
+
Error: fmt.Sprintf("更新失败: %v", err),
|
| 182 |
+
})
|
| 183 |
+
return
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
log.Printf("[外部API] 账号更新成功: ClientID=%s, Email=%s, Plan=%s", existing.ClientID, existing.Email, existing.PlanType)
|
| 187 |
+
c.JSON(http.StatusOK, ExternalTokenResponse{
|
| 188 |
+
Success: true,
|
| 189 |
+
Message: "账号更新成功",
|
| 190 |
+
Account: &existing,
|
| 191 |
+
})
|
| 192 |
+
} else {
|
| 193 |
+
// 创建新账号
|
| 194 |
+
if err := database.GetDB().Create(&account).Error; err != nil {
|
| 195 |
+
c.JSON(http.StatusInternalServerError, ExternalTokenResponse{
|
| 196 |
+
Success: false,
|
| 197 |
+
Error: fmt.Sprintf("创建失败: %v", err),
|
| 198 |
+
})
|
| 199 |
+
return
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
log.Printf("[外部API] 新账号创建成功: ClientID=%s, Email=%s, Plan=%s", account.ClientID, account.Email, account.PlanType)
|
| 203 |
+
c.JSON(http.StatusCreated, ExternalTokenResponse{
|
| 204 |
+
Success: true,
|
| 205 |
+
Message: "账号创建成功",
|
| 206 |
+
Account: &account,
|
| 207 |
+
})
|
| 208 |
+
}
|
| 209 |
+
}
|
internal/handler/gemini.go
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"crypto/rand"
|
| 5 |
+
"encoding/hex"
|
| 6 |
+
"errors"
|
| 7 |
+
"fmt"
|
| 8 |
+
"io"
|
| 9 |
+
"net/http"
|
| 10 |
+
"strings"
|
| 11 |
+
|
| 12 |
+
"github.com/gin-gonic/gin"
|
| 13 |
+
"zencoder-2api/internal/service"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
type GeminiHandler struct {
|
| 17 |
+
svc *service.GeminiService
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
func NewGeminiHandler() *GeminiHandler {
|
| 21 |
+
return &GeminiHandler{svc: service.NewGeminiService()}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// generateTraceID 生成一个随机的 trace ID
|
| 25 |
+
func generateGeminiTraceID() string {
|
| 26 |
+
b := make([]byte, 16)
|
| 27 |
+
rand.Read(b)
|
| 28 |
+
return hex.EncodeToString(b)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// HandleRequest 处理 POST /v1beta/models/*path
|
| 32 |
+
// 路径格式: /model:action 例如 /gemini-3-flash-preview:streamGenerateContent
|
| 33 |
+
func (h *GeminiHandler) HandleRequest(c *gin.Context) {
|
| 34 |
+
path := c.Param("path")
|
| 35 |
+
// 去掉开头的斜杠
|
| 36 |
+
path = strings.TrimPrefix(path, "/")
|
| 37 |
+
|
| 38 |
+
// 解析 model:action
|
| 39 |
+
parts := strings.SplitN(path, ":", 2)
|
| 40 |
+
if len(parts) != 2 {
|
| 41 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path format"})
|
| 42 |
+
return
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
modelName := parts[0]
|
| 46 |
+
action := parts[1]
|
| 47 |
+
|
| 48 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 49 |
+
if err != nil {
|
| 50 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 51 |
+
return
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
switch action {
|
| 55 |
+
case "generateContent":
|
| 56 |
+
if err := h.svc.GenerateContentProxy(c.Request.Context(), c.Writer, modelName, body); err != nil {
|
| 57 |
+
h.handleError(c, err)
|
| 58 |
+
}
|
| 59 |
+
case "streamGenerateContent":
|
| 60 |
+
if err := h.svc.StreamGenerateContentProxy(c.Request.Context(), c.Writer, modelName, body); err != nil {
|
| 61 |
+
h.handleError(c, err)
|
| 62 |
+
}
|
| 63 |
+
default:
|
| 64 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported action: " + action})
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// handleError 统一处理错误,特别是没有可用账号的错误
|
| 69 |
+
func (h *GeminiHandler) handleError(c *gin.Context, err error) {
|
| 70 |
+
if errors.Is(err, service.ErrNoAvailableAccount) || errors.Is(err, service.ErrNoPermission) {
|
| 71 |
+
traceID := generateGeminiTraceID()
|
| 72 |
+
errMsg := fmt.Sprintf("没有可用token(traceid: %s)", traceID)
|
| 73 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": errMsg})
|
| 74 |
+
return
|
| 75 |
+
}
|
| 76 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 77 |
+
}
|
internal/handler/grok.go
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"crypto/rand"
|
| 5 |
+
"encoding/hex"
|
| 6 |
+
"errors"
|
| 7 |
+
"fmt"
|
| 8 |
+
"io"
|
| 9 |
+
"net/http"
|
| 10 |
+
|
| 11 |
+
"zencoder-2api/internal/service"
|
| 12 |
+
|
| 13 |
+
"github.com/gin-gonic/gin"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
type GrokHandler struct {
|
| 17 |
+
svc *service.GrokService
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
func NewGrokHandler() *GrokHandler {
|
| 21 |
+
return &GrokHandler{svc: service.NewGrokService()}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// generateTraceID 生成一个随机的 trace ID
|
| 25 |
+
func generateGrokTraceID() string {
|
| 26 |
+
b := make([]byte, 16)
|
| 27 |
+
rand.Read(b)
|
| 28 |
+
return hex.EncodeToString(b)
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// ChatCompletions 处理 POST /v1/chat/completions (xAI)
|
| 32 |
+
func (h *GrokHandler) ChatCompletions(c *gin.Context) {
|
| 33 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 34 |
+
if err != nil {
|
| 35 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 36 |
+
return
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
if err := h.svc.ChatCompletionsProxy(c.Request.Context(), c.Writer, body); err != nil {
|
| 40 |
+
h.handleError(c, err)
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// handleError 统一处理错误,特别是没有可用账号的错误
|
| 45 |
+
func (h *GrokHandler) handleError(c *gin.Context, err error) {
|
| 46 |
+
if errors.Is(err, service.ErrNoAvailableAccount) || errors.Is(err, service.ErrNoPermission) {
|
| 47 |
+
traceID := generateGrokTraceID()
|
| 48 |
+
errMsg := fmt.Sprintf("没有可用token(traceid: %s)", traceID)
|
| 49 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": errMsg})
|
| 50 |
+
return
|
| 51 |
+
}
|
| 52 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 53 |
+
}
|
internal/handler/oauth.go
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"crypto/rand"
|
| 6 |
+
"crypto/sha256"
|
| 7 |
+
"encoding/base64"
|
| 8 |
+
"encoding/json"
|
| 9 |
+
"fmt"
|
| 10 |
+
"net/http"
|
| 11 |
+
"net/url"
|
| 12 |
+
"sync"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"github.com/gin-gonic/gin"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// PKCESession 存储PKCE会话信息
|
| 19 |
+
type PKCESession struct {
|
| 20 |
+
CodeVerifier string
|
| 21 |
+
CreatedAt time.Time
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// PKCESessionStore 内存中存储PKCE会话
|
| 25 |
+
type PKCESessionStore struct {
|
| 26 |
+
sync.RWMutex
|
| 27 |
+
sessions map[string]*PKCESession
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// 全局PKCE会话存储
|
| 31 |
+
var pkceStore = &PKCESessionStore{
|
| 32 |
+
sessions: make(map[string]*PKCESession),
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// OAuthHandler OAuth相关处理器
|
| 36 |
+
type OAuthHandler struct {
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// NewOAuthHandler 创建OAuth处理器
|
| 40 |
+
func NewOAuthHandler() *OAuthHandler {
|
| 41 |
+
// 启动清理过期会话的定时器
|
| 42 |
+
go cleanupExpiredSessions()
|
| 43 |
+
|
| 44 |
+
return &OAuthHandler{}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// StartOAuthForRT 开始OAuth流程获取RT
|
| 48 |
+
func (h *OAuthHandler) StartOAuthForRT(c *gin.Context) {
|
| 49 |
+
// 生成PKCE参数
|
| 50 |
+
codeVerifier, err := generateCodeVerifier(32)
|
| 51 |
+
if err != nil {
|
| 52 |
+
c.JSON(http.StatusInternalServerError, gin.H{
|
| 53 |
+
"error": "生成PKCE参数失败",
|
| 54 |
+
})
|
| 55 |
+
return
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// 生成code_challenge
|
| 59 |
+
codeChallenge := generateCodeChallenge(codeVerifier)
|
| 60 |
+
|
| 61 |
+
// 生成会话ID
|
| 62 |
+
sessionID, err := generateSessionID()
|
| 63 |
+
if err != nil {
|
| 64 |
+
c.JSON(http.StatusInternalServerError, gin.H{
|
| 65 |
+
"error": "生成会话ID失败",
|
| 66 |
+
})
|
| 67 |
+
return
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
// 存储会话
|
| 71 |
+
pkceStore.Lock()
|
| 72 |
+
pkceStore.sessions[sessionID] = &PKCESession{
|
| 73 |
+
CodeVerifier: codeVerifier,
|
| 74 |
+
CreatedAt: time.Now(),
|
| 75 |
+
}
|
| 76 |
+
pkceStore.Unlock()
|
| 77 |
+
|
| 78 |
+
// 获取回调URL
|
| 79 |
+
scheme := "http"
|
| 80 |
+
if c.Request.TLS != nil {
|
| 81 |
+
scheme = "https"
|
| 82 |
+
}
|
| 83 |
+
host := c.Request.Host
|
| 84 |
+
|
| 85 |
+
callbackURL := fmt.Sprintf("%s://%s/api/oauth/callback-rt?session=%s",
|
| 86 |
+
scheme, host, sessionID)
|
| 87 |
+
|
| 88 |
+
// 构建state参数
|
| 89 |
+
state := map[string]string{
|
| 90 |
+
"redirectUri": callbackURL,
|
| 91 |
+
"codeChallenge": codeChallenge,
|
| 92 |
+
"sessionId": sessionID,
|
| 93 |
+
}
|
| 94 |
+
stateJSON, _ := json.Marshal(state)
|
| 95 |
+
|
| 96 |
+
// 构建授权URL
|
| 97 |
+
params := url.Values{
|
| 98 |
+
"state": {string(stateJSON)},
|
| 99 |
+
"response_type": {"code"},
|
| 100 |
+
"client_id": {"5948a5c5-4b30-4465-a3f2-2136ea53ea0a"},
|
| 101 |
+
"scope": {"openid profile email"},
|
| 102 |
+
"redirect_uri": {"https://auth.zencoder.ai/extension/auth-success"},
|
| 103 |
+
"code_challenge": {codeChallenge},
|
| 104 |
+
"code_challenge_method": {"S256"},
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
authURL := fmt.Sprintf("https://fe.zencoder.ai/oauth/authorize?%s", params.Encode())
|
| 108 |
+
|
| 109 |
+
// 重定向到授权页面
|
| 110 |
+
c.Redirect(http.StatusFound, authURL)
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// CallbackOAuthForRT 处理OAuth回调
|
| 114 |
+
func (h *OAuthHandler) CallbackOAuthForRT(c *gin.Context) {
|
| 115 |
+
code := c.Query("code")
|
| 116 |
+
sessionID := c.Query("session")
|
| 117 |
+
|
| 118 |
+
// 验证参数
|
| 119 |
+
if code == "" || sessionID == "" {
|
| 120 |
+
h.renderCallbackPage(c, false, "", "", "缺少必要参数")
|
| 121 |
+
return
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// 获取会话
|
| 125 |
+
pkceStore.RLock()
|
| 126 |
+
session, exists := pkceStore.sessions[sessionID]
|
| 127 |
+
pkceStore.RUnlock()
|
| 128 |
+
|
| 129 |
+
if !exists {
|
| 130 |
+
h.renderCallbackPage(c, false, "", "", "会话已过期,请重新获取")
|
| 131 |
+
return
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
// 交换token
|
| 135 |
+
tokenResp, err := h.exchangeCodeForToken(code, session.CodeVerifier)
|
| 136 |
+
if err != nil {
|
| 137 |
+
h.renderCallbackPage(c, false, "", "", fmt.Sprintf("获取Token失败: %v", err))
|
| 138 |
+
return
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// 清理会话
|
| 142 |
+
pkceStore.Lock()
|
| 143 |
+
delete(pkceStore.sessions, sessionID)
|
| 144 |
+
pkceStore.Unlock()
|
| 145 |
+
|
| 146 |
+
// 渲染成功页面,传递access token和refresh token
|
| 147 |
+
h.renderCallbackPage(c, true, tokenResp.AccessToken, tokenResp.RefreshToken, "")
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// exchangeCodeForToken 用授权码换取token
|
| 151 |
+
func (h *OAuthHandler) exchangeCodeForToken(code, codeVerifier string) (*OAuthTokenResponse, error) {
|
| 152 |
+
tokenURL := "https://auth.zencoder.ai/api/frontegg/oauth/token"
|
| 153 |
+
|
| 154 |
+
payload := map[string]string{
|
| 155 |
+
"code": code,
|
| 156 |
+
"redirect_uri": "https://auth.zencoder.ai/extension/auth-success",
|
| 157 |
+
"code_verifier": codeVerifier,
|
| 158 |
+
"grant_type": "authorization_code",
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
body, _ := json.Marshal(payload)
|
| 162 |
+
|
| 163 |
+
req, err := http.NewRequest("POST", tokenURL, bytes.NewReader(body))
|
| 164 |
+
if err != nil {
|
| 165 |
+
return nil, err
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// 设置请求头
|
| 169 |
+
req.Header.Set("Content-Type", "application/json")
|
| 170 |
+
req.Header.Set("x-frontegg-sdk", "@frontegg/nextjs@9.2.10")
|
| 171 |
+
req.Header.Set("x-frontegg-framework", "next@15.3.8")
|
| 172 |
+
req.Header.Set("Origin", "https://auth.zencoder.ai")
|
| 173 |
+
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
| 174 |
+
|
| 175 |
+
client := &http.Client{Timeout: 30 * time.Second}
|
| 176 |
+
resp, err := client.Do(req)
|
| 177 |
+
if err != nil {
|
| 178 |
+
return nil, err
|
| 179 |
+
}
|
| 180 |
+
defer resp.Body.Close()
|
| 181 |
+
|
| 182 |
+
if resp.StatusCode != http.StatusOK {
|
| 183 |
+
return nil, fmt.Errorf("token exchange failed with status %d", resp.StatusCode)
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
var tokenResp OAuthTokenResponse
|
| 187 |
+
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
| 188 |
+
return nil, err
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
return &tokenResp, nil
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
// renderCallbackPage 渲染回调页面
|
| 195 |
+
func (h *OAuthHandler) renderCallbackPage(c *gin.Context, success bool, accessToken, refreshToken, errorMsg string) {
|
| 196 |
+
html := `
|
| 197 |
+
<!DOCTYPE html>
|
| 198 |
+
<html lang="zh-CN">
|
| 199 |
+
<head>
|
| 200 |
+
<meta charset="UTF-8">
|
| 201 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 202 |
+
<title>OAuth认证</title>
|
| 203 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 204 |
+
</head>
|
| 205 |
+
<body class="bg-gray-50 dark:bg-gray-900 min-h-screen flex items-center justify-center">
|
| 206 |
+
<div class="max-w-md w-full mx-4">
|
| 207 |
+
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-8">
|
| 208 |
+
`
|
| 209 |
+
|
| 210 |
+
if success {
|
| 211 |
+
html += fmt.Sprintf(`
|
| 212 |
+
<div class="text-center">
|
| 213 |
+
<div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-green-100">
|
| 214 |
+
<svg class="h-6 w-6 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 215 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
|
| 216 |
+
</svg>
|
| 217 |
+
</div>
|
| 218 |
+
<h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证成功!</h2>
|
| 219 |
+
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">正在返回并填充Token...</p>
|
| 220 |
+
</div>
|
| 221 |
+
<script>
|
| 222 |
+
// 发送消息给父窗口
|
| 223 |
+
if (window.opener) {
|
| 224 |
+
window.opener.postMessage({
|
| 225 |
+
type: 'oauth-rt-complete',
|
| 226 |
+
success: true,
|
| 227 |
+
accessToken: '%s',
|
| 228 |
+
refreshToken: '%s'
|
| 229 |
+
}, window.location.origin);
|
| 230 |
+
|
| 231 |
+
// 2秒后关闭窗口
|
| 232 |
+
setTimeout(() => {
|
| 233 |
+
window.close();
|
| 234 |
+
}, 2000);
|
| 235 |
+
}
|
| 236 |
+
</script>
|
| 237 |
+
`, accessToken, refreshToken)
|
| 238 |
+
} else {
|
| 239 |
+
html += fmt.Sprintf(`
|
| 240 |
+
<div class="text-center">
|
| 241 |
+
<div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-red-100">
|
| 242 |
+
<svg class="h-6 w-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 243 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path>
|
| 244 |
+
</svg>
|
| 245 |
+
</div>
|
| 246 |
+
<h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证失败</h2>
|
| 247 |
+
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">%s</p>
|
| 248 |
+
<button onclick="window.close()" class="mt-4 px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition-colors">
|
| 249 |
+
关闭窗口
|
| 250 |
+
</button>
|
| 251 |
+
</div>
|
| 252 |
+
<script>
|
| 253 |
+
// 发送错误消息给父窗口
|
| 254 |
+
if (window.opener) {
|
| 255 |
+
window.opener.postMessage({
|
| 256 |
+
type: 'oauth-rt-complete',
|
| 257 |
+
success: false,
|
| 258 |
+
error: '%s'
|
| 259 |
+
}, window.location.origin);
|
| 260 |
+
}
|
| 261 |
+
</script>
|
| 262 |
+
`, errorMsg, errorMsg)
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
html += `
|
| 266 |
+
</div>
|
| 267 |
+
</div>
|
| 268 |
+
</body>
|
| 269 |
+
</html>
|
| 270 |
+
`
|
| 271 |
+
|
| 272 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 273 |
+
c.String(http.StatusOK, html)
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
// OAuthTokenResponse OAuth token响应
|
| 277 |
+
type OAuthTokenResponse struct {
|
| 278 |
+
AccessToken string `json:"access_token"`
|
| 279 |
+
RefreshToken string `json:"refresh_token"`
|
| 280 |
+
TokenType string `json:"token_type"`
|
| 281 |
+
ExpiresIn int `json:"expires_in"`
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// generateCodeVerifier 生成PKCE code_verifier
|
| 285 |
+
func generateCodeVerifier(length int) (string, error) {
|
| 286 |
+
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
|
| 287 |
+
result := make([]byte, length)
|
| 288 |
+
randomBytes := make([]byte, length)
|
| 289 |
+
|
| 290 |
+
if _, err := rand.Read(randomBytes); err != nil {
|
| 291 |
+
return "", err
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
for i := 0; i < length; i++ {
|
| 295 |
+
result[i] = chars[int(randomBytes[i])%len(chars)]
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
return string(result), nil
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// generateCodeChallenge 生成PKCE code_challenge
|
| 302 |
+
func generateCodeChallenge(codeVerifier string) string {
|
| 303 |
+
hash := sha256.Sum256([]byte(codeVerifier))
|
| 304 |
+
return base64.RawURLEncoding.EncodeToString(hash[:])
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
// generateSessionID 生成会话ID
|
| 308 |
+
func generateSessionID() (string, error) {
|
| 309 |
+
b := make([]byte, 16)
|
| 310 |
+
if _, err := rand.Read(b); err != nil {
|
| 311 |
+
return "", err
|
| 312 |
+
}
|
| 313 |
+
return base64.URLEncoding.EncodeToString(b), nil
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
// cleanupExpiredSessions 清理过期的PKCE会话
|
| 317 |
+
func cleanupExpiredSessions() {
|
| 318 |
+
ticker := time.NewTicker(5 * time.Minute)
|
| 319 |
+
defer ticker.Stop()
|
| 320 |
+
|
| 321 |
+
for range ticker.C {
|
| 322 |
+
pkceStore.Lock()
|
| 323 |
+
now := time.Now()
|
| 324 |
+
for id, session := range pkceStore.sessions {
|
| 325 |
+
if now.Sub(session.CreatedAt) > 10*time.Minute {
|
| 326 |
+
delete(pkceStore.sessions, id)
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
pkceStore.Unlock()
|
| 330 |
+
}
|
| 331 |
+
}
|
internal/handler/openai.go
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"crypto/rand"
|
| 5 |
+
"encoding/hex"
|
| 6 |
+
"encoding/json"
|
| 7 |
+
"errors"
|
| 8 |
+
"fmt"
|
| 9 |
+
"io"
|
| 10 |
+
"net/http"
|
| 11 |
+
|
| 12 |
+
"github.com/gin-gonic/gin"
|
| 13 |
+
"zencoder-2api/internal/model"
|
| 14 |
+
"zencoder-2api/internal/service"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
type OpenAIHandler struct {
|
| 18 |
+
svc *service.OpenAIService
|
| 19 |
+
grokSvc *service.GrokService
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func NewOpenAIHandler() *OpenAIHandler {
|
| 23 |
+
return &OpenAIHandler{
|
| 24 |
+
svc: service.NewOpenAIService(),
|
| 25 |
+
grokSvc: service.NewGrokService(),
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// generateTraceID 生成一个随机的 trace ID
|
| 30 |
+
func generateTraceID() string {
|
| 31 |
+
b := make([]byte, 16)
|
| 32 |
+
rand.Read(b)
|
| 33 |
+
return hex.EncodeToString(b)
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// ChatCompletions 处理 POST /v1/chat/completions
|
| 37 |
+
func (h *OpenAIHandler) ChatCompletions(c *gin.Context) {
|
| 38 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 39 |
+
if err != nil {
|
| 40 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 41 |
+
return
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// 解析模型名以确定使用哪个服务
|
| 45 |
+
var req struct {
|
| 46 |
+
Model string `json:"model"`
|
| 47 |
+
}
|
| 48 |
+
if err := json.Unmarshal(body, &req); err != nil {
|
| 49 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 50 |
+
return
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// 根据模型的 ProviderID 分流
|
| 54 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 55 |
+
if !exists {
|
| 56 |
+
// 模型不存在,返回错误
|
| 57 |
+
h.handleError(c, service.ErrNoAvailableAccount)
|
| 58 |
+
return
|
| 59 |
+
}
|
| 60 |
+
if zenModel.ProviderID == "xai" {
|
| 61 |
+
// Grok 模型使用 xAI 服务
|
| 62 |
+
if err := h.grokSvc.ChatCompletionsProxy(c.Request.Context(), c.Writer, body); err != nil {
|
| 63 |
+
h.handleError(c, err)
|
| 64 |
+
}
|
| 65 |
+
return
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// 其他模型使用 OpenAI 服务
|
| 69 |
+
if err := h.svc.ChatCompletionsProxy(c.Request.Context(), c.Writer, body); err != nil {
|
| 70 |
+
h.handleError(c, err)
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// Responses 处理 POST /v1/responses
|
| 75 |
+
func (h *OpenAIHandler) Responses(c *gin.Context) {
|
| 76 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 77 |
+
if err != nil {
|
| 78 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 79 |
+
return
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if err := h.svc.ResponsesProxy(c.Request.Context(), c.Writer, body); err != nil {
|
| 83 |
+
h.handleError(c, err)
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// handleError 统一处理错误,特别是没有可用账号的错误
|
| 88 |
+
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
| 89 |
+
if errors.Is(err, service.ErrNoAvailableAccount) || errors.Is(err, service.ErrNoPermission) {
|
| 90 |
+
traceID := generateTraceID()
|
| 91 |
+
errMsg := fmt.Sprintf("没有可用token(traceid: %s)", traceID)
|
| 92 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": errMsg})
|
| 93 |
+
return
|
| 94 |
+
}
|
| 95 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 96 |
+
}
|
internal/handler/token.go
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package handler
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
"strconv"
|
| 6 |
+
"strings"
|
| 7 |
+
"time"
|
| 8 |
+
|
| 9 |
+
"github.com/gin-gonic/gin"
|
| 10 |
+
"zencoder-2api/internal/database"
|
| 11 |
+
"zencoder-2api/internal/model"
|
| 12 |
+
"zencoder-2api/internal/service"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
type TokenHandler struct{}
|
| 16 |
+
|
| 17 |
+
func NewTokenHandler() *TokenHandler {
|
| 18 |
+
return &TokenHandler{}
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// ListTokenRecords 获取所有token记录
|
| 22 |
+
func (h *TokenHandler) ListTokenRecords(c *gin.Context) {
|
| 23 |
+
records, err := service.GetAllTokenRecords()
|
| 24 |
+
if err != nil {
|
| 25 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 26 |
+
return
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// 获取每个token的生成任务统计
|
| 30 |
+
var enrichedRecords []map[string]interface{}
|
| 31 |
+
for _, record := range records {
|
| 32 |
+
// 统计该token的任务信息
|
| 33 |
+
var taskStats struct {
|
| 34 |
+
TotalTasks int64 `json:"total_tasks"`
|
| 35 |
+
TotalSuccess int64 `json:"total_success"`
|
| 36 |
+
TotalFail int64 `json:"total_fail"`
|
| 37 |
+
RunningTasks int64 `json:"running_tasks"`
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
db := database.GetDB()
|
| 41 |
+
db.Model(&model.GenerationTask{}).Where("token_record_id = ?", record.ID).Count(&taskStats.TotalTasks)
|
| 42 |
+
db.Model(&model.GenerationTask{}).Where("token_record_id = ?", record.ID).
|
| 43 |
+
Select("COALESCE(SUM(success_count), 0)").Scan(&taskStats.TotalSuccess)
|
| 44 |
+
db.Model(&model.GenerationTask{}).Where("token_record_id = ?", record.ID).
|
| 45 |
+
Select("COALESCE(SUM(fail_count), 0)").Scan(&taskStats.TotalFail)
|
| 46 |
+
db.Model(&model.GenerationTask{}).Where("token_record_id = ? AND status = ?", record.ID, "running").
|
| 47 |
+
Count(&taskStats.RunningTasks)
|
| 48 |
+
|
| 49 |
+
// 解析JWT获取用户信息
|
| 50 |
+
var email string
|
| 51 |
+
var planType string
|
| 52 |
+
var subscriptionStartDate time.Time
|
| 53 |
+
if record.Token != "" {
|
| 54 |
+
if payload, err := service.ParseJWT(record.Token); err == nil {
|
| 55 |
+
email = payload.Email
|
| 56 |
+
planType = payload.CustomClaims.Plan
|
| 57 |
+
if planType != "" {
|
| 58 |
+
planType = strings.ToUpper(planType[:1]) + planType[1:]
|
| 59 |
+
}
|
| 60 |
+
// 获取订阅开始时间
|
| 61 |
+
subscriptionStartDate = service.GetSubscriptionDate(payload)
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
enrichedRecord := map[string]interface{}{
|
| 66 |
+
"id": record.ID,
|
| 67 |
+
"description": record.Description,
|
| 68 |
+
"generated_count": record.GeneratedCount,
|
| 69 |
+
"last_generated_at": record.LastGeneratedAt,
|
| 70 |
+
"auto_generate": record.AutoGenerate,
|
| 71 |
+
"threshold": record.Threshold,
|
| 72 |
+
"generate_batch": record.GenerateBatch,
|
| 73 |
+
"is_active": record.IsActive,
|
| 74 |
+
"created_at": record.CreatedAt,
|
| 75 |
+
"updated_at": record.UpdatedAt,
|
| 76 |
+
"token_expiry": record.TokenExpiry,
|
| 77 |
+
"status": record.Status,
|
| 78 |
+
"ban_reason": record.BanReason,
|
| 79 |
+
"email": email,
|
| 80 |
+
"plan_type": planType,
|
| 81 |
+
"subscription_start_date": subscriptionStartDate,
|
| 82 |
+
"has_refresh_token": record.RefreshToken != "",
|
| 83 |
+
"total_tasks": taskStats.TotalTasks,
|
| 84 |
+
"total_success": taskStats.TotalSuccess,
|
| 85 |
+
"total_fail": taskStats.TotalFail,
|
| 86 |
+
"running_tasks": taskStats.RunningTasks,
|
| 87 |
+
}
|
| 88 |
+
enrichedRecords = append(enrichedRecords, enrichedRecord)
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
c.JSON(http.StatusOK, gin.H{
|
| 92 |
+
"items": enrichedRecords,
|
| 93 |
+
"total": len(enrichedRecords),
|
| 94 |
+
})
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
// UpdateTokenRecord 更新token记录配置
|
| 98 |
+
func (h *TokenHandler) UpdateTokenRecord(c *gin.Context) {
|
| 99 |
+
id := c.Param("id")
|
| 100 |
+
tokenID, err := strconv.ParseUint(id, 10, 32)
|
| 101 |
+
if err != nil {
|
| 102 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
| 103 |
+
return
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
var req struct {
|
| 107 |
+
AutoGenerate *bool `json:"auto_generate"`
|
| 108 |
+
Threshold *int `json:"threshold"`
|
| 109 |
+
GenerateBatch *int `json:"generate_batch"`
|
| 110 |
+
IsActive *bool `json:"is_active"`
|
| 111 |
+
Description string `json:"description"`
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 115 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 116 |
+
return
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
updates := make(map[string]interface{})
|
| 120 |
+
if req.AutoGenerate != nil {
|
| 121 |
+
updates["auto_generate"] = *req.AutoGenerate
|
| 122 |
+
}
|
| 123 |
+
if req.Threshold != nil {
|
| 124 |
+
updates["threshold"] = *req.Threshold
|
| 125 |
+
}
|
| 126 |
+
if req.GenerateBatch != nil {
|
| 127 |
+
updates["generate_batch"] = *req.GenerateBatch
|
| 128 |
+
}
|
| 129 |
+
if req.IsActive != nil {
|
| 130 |
+
updates["is_active"] = *req.IsActive
|
| 131 |
+
}
|
| 132 |
+
if req.Description != "" {
|
| 133 |
+
updates["description"] = req.Description
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
if err := service.UpdateTokenRecord(uint(tokenID), updates); err != nil {
|
| 137 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 138 |
+
return
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
c.JSON(http.StatusOK, gin.H{"message": "updated"})
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// GetGenerationTasks 获取生成任务历史
|
| 145 |
+
func (h *TokenHandler) GetGenerationTasks(c *gin.Context) {
|
| 146 |
+
tokenRecordID := c.Query("token_record_id")
|
| 147 |
+
var tokenID uint
|
| 148 |
+
if tokenRecordID != "" {
|
| 149 |
+
id, err := strconv.ParseUint(tokenRecordID, 10, 32)
|
| 150 |
+
if err != nil {
|
| 151 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid token_record_id"})
|
| 152 |
+
return
|
| 153 |
+
}
|
| 154 |
+
tokenID = uint(id)
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
tasks, err := service.GetGenerationTasks(tokenID)
|
| 158 |
+
if err != nil {
|
| 159 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 160 |
+
return
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
c.JSON(http.StatusOK, gin.H{
|
| 164 |
+
"items": tasks,
|
| 165 |
+
"total": len(tasks),
|
| 166 |
+
})
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// TriggerGeneration 手动触发生成
|
| 170 |
+
func (h *TokenHandler) TriggerGeneration(c *gin.Context) {
|
| 171 |
+
id := c.Param("id")
|
| 172 |
+
tokenID, err := strconv.ParseUint(id, 10, 32)
|
| 173 |
+
if err != nil {
|
| 174 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
| 175 |
+
return
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if err := service.ManualTriggerGeneration(uint(tokenID)); err != nil {
|
| 179 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 180 |
+
return
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
c.JSON(http.StatusOK, gin.H{"message": "生成任务已触发"})
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// GetPoolStatus 获取号池状态
|
| 187 |
+
func (h *TokenHandler) GetPoolStatus(c *gin.Context) {
|
| 188 |
+
db := database.GetDB()
|
| 189 |
+
|
| 190 |
+
var stats struct {
|
| 191 |
+
TotalAccounts int64 `json:"total_accounts"`
|
| 192 |
+
NormalAccounts int64 `json:"normal_accounts"`
|
| 193 |
+
CoolingAccounts int64 `json:"cooling_accounts"`
|
| 194 |
+
BannedAccounts int64 `json:"banned_accounts"`
|
| 195 |
+
ErrorAccounts int64 `json:"error_accounts"`
|
| 196 |
+
DisabledAccounts int64 `json:"disabled_accounts"`
|
| 197 |
+
ActiveTokens int64 `json:"active_tokens"`
|
| 198 |
+
RunningTasks int64 `json:"running_tasks"`
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// 统计账号状态
|
| 202 |
+
db.Model(&model.Account{}).Count(&stats.TotalAccounts)
|
| 203 |
+
db.Model(&model.Account{}).Where("status = ?", "normal").Count(&stats.NormalAccounts)
|
| 204 |
+
db.Model(&model.Account{}).Where("status = ?", "cooling").Count(&stats.CoolingAccounts)
|
| 205 |
+
db.Model(&model.Account{}).Where("status = ?", "banned").Count(&stats.BannedAccounts)
|
| 206 |
+
db.Model(&model.Account{}).Where("status = ?", "error").Count(&stats.ErrorAccounts)
|
| 207 |
+
db.Model(&model.Account{}).Where("status = ?", "disabled").Count(&stats.DisabledAccounts)
|
| 208 |
+
|
| 209 |
+
// 统计激活的token
|
| 210 |
+
db.Model(&model.TokenRecord{}).Where("is_active = ?", true).Count(&stats.ActiveTokens)
|
| 211 |
+
|
| 212 |
+
// 统计运行中的任务
|
| 213 |
+
db.Model(&model.GenerationTask{}).Where("status = ?", "running").Count(&stats.RunningTasks)
|
| 214 |
+
|
| 215 |
+
c.JSON(http.StatusOK, stats)
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// DeleteTokenRecord 删除token记录
|
| 219 |
+
func (h *TokenHandler) DeleteTokenRecord(c *gin.Context) {
|
| 220 |
+
id := c.Param("id")
|
| 221 |
+
tokenID, err := strconv.ParseUint(id, 10, 32)
|
| 222 |
+
if err != nil {
|
| 223 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
| 224 |
+
return
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// 开启事务,确保删除操作的原子性
|
| 228 |
+
db := database.GetDB()
|
| 229 |
+
tx := db.Begin()
|
| 230 |
+
|
| 231 |
+
// 先删除所有关联的生成任务历史记录
|
| 232 |
+
if err := tx.Where("token_record_id = ?", tokenID).Delete(&model.GenerationTask{}).Error; err != nil {
|
| 233 |
+
tx.Rollback()
|
| 234 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除关联任务失败: " + err.Error()})
|
| 235 |
+
return
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
// 删除token记录本身
|
| 239 |
+
if err := tx.Delete(&model.TokenRecord{}, tokenID).Error; err != nil {
|
| 240 |
+
tx.Rollback()
|
| 241 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除token记录失败: " + err.Error()})
|
| 242 |
+
return
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// 提交事务
|
| 246 |
+
if err := tx.Commit().Error; err != nil {
|
| 247 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "提交事务失败: " + err.Error()})
|
| 248 |
+
return
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
c.JSON(http.StatusOK, gin.H{"message": "token及其所有历史记录已删除"})
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// RefreshTokenRecord 刷新token记录
|
| 255 |
+
func (h *TokenHandler) RefreshTokenRecord(c *gin.Context) {
|
| 256 |
+
id := c.Param("id")
|
| 257 |
+
tokenID, err := strconv.ParseUint(id, 10, 32)
|
| 258 |
+
if err != nil {
|
| 259 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
| 260 |
+
return
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
// 调用service层的刷新函数
|
| 264 |
+
if err := service.RefreshTokenAndAccounts(uint(tokenID)); err != nil {
|
| 265 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
| 266 |
+
return
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
c.JSON(http.StatusOK, gin.H{"message": "Token刷新成功,相关账号刷新已启动"})
|
| 270 |
+
}
|
internal/middleware/auth.go
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package middleware
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
"os"
|
| 6 |
+
"strings"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"zencoder-2api/internal/service"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
// LoggerMiddleware 为每个请求创建 logger 并在结束时 flush
|
| 13 |
+
func LoggerMiddleware() gin.HandlerFunc {
|
| 14 |
+
return func(c *gin.Context) {
|
| 15 |
+
logger := service.NewRequestLogger()
|
| 16 |
+
ctx := service.WithLogger(c.Request.Context(), logger)
|
| 17 |
+
c.Request = c.Request.WithContext(ctx)
|
| 18 |
+
|
| 19 |
+
c.Next()
|
| 20 |
+
|
| 21 |
+
// 请求结束时 flush 日志
|
| 22 |
+
logger.Flush()
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
func AuthMiddleware() gin.HandlerFunc {
|
| 27 |
+
// 从环境变量获取全局 Token
|
| 28 |
+
token := os.Getenv("AUTH_TOKEN")
|
| 29 |
+
|
| 30 |
+
return func(c *gin.Context) {
|
| 31 |
+
// 如果没有配置全局 Token,则跳过鉴权
|
| 32 |
+
if token == "" {
|
| 33 |
+
c.Next()
|
| 34 |
+
return
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// 1. 检查 OpenAI 格式: Authorization: Bearer <token>
|
| 38 |
+
authHeader := c.GetHeader("Authorization")
|
| 39 |
+
if authHeader != "" {
|
| 40 |
+
parts := strings.SplitN(authHeader, " ", 2)
|
| 41 |
+
if len(parts) == 2 && parts[0] == "Bearer" && parts[1] == token {
|
| 42 |
+
c.Next()
|
| 43 |
+
return
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// 2. 检查 Anthropic 格式: x-api-key: <token>
|
| 48 |
+
if c.GetHeader("x-api-key") == token {
|
| 49 |
+
c.Next()
|
| 50 |
+
return
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// 3. 检查 Gemini 格式: x-goog-api-key: <token> 或 query param key=<token>
|
| 54 |
+
if c.GetHeader("x-goog-api-key") == token {
|
| 55 |
+
c.Next()
|
| 56 |
+
return
|
| 57 |
+
}
|
| 58 |
+
if c.Query("key") == token {
|
| 59 |
+
c.Next()
|
| 60 |
+
return
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
// 鉴权失败
|
| 64 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
| 65 |
+
"error": gin.H{
|
| 66 |
+
"message": "Invalid authentication token",
|
| 67 |
+
"type": "authentication_error",
|
| 68 |
+
},
|
| 69 |
+
})
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// AdminAuthMiddleware 后台管理密码验证中间件
|
| 74 |
+
func AdminAuthMiddleware() gin.HandlerFunc {
|
| 75 |
+
// 从环境变量获取后台管理密码
|
| 76 |
+
adminPassword := os.Getenv("ADMIN_PASSWORD")
|
| 77 |
+
|
| 78 |
+
return func(c *gin.Context) {
|
| 79 |
+
// 如果没有配置管理密码,则跳过鉴权
|
| 80 |
+
if adminPassword == "" {
|
| 81 |
+
c.Next()
|
| 82 |
+
return
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// 检查请求头中的管理密码
|
| 86 |
+
// 支持多种格式:
|
| 87 |
+
// 1. Authorization: Bearer <password>
|
| 88 |
+
// 2. X-Admin-Password: <password>
|
| 89 |
+
// 3. Admin-Password: <password>
|
| 90 |
+
|
| 91 |
+
var providedPassword string
|
| 92 |
+
|
| 93 |
+
// 检查 Authorization: Bearer <password>
|
| 94 |
+
authHeader := c.GetHeader("Authorization")
|
| 95 |
+
if authHeader != "" {
|
| 96 |
+
parts := strings.SplitN(authHeader, " ", 2)
|
| 97 |
+
if len(parts) == 2 && parts[0] == "Bearer" {
|
| 98 |
+
providedPassword = parts[1]
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// 检查 X-Admin-Password
|
| 103 |
+
if providedPassword == "" {
|
| 104 |
+
providedPassword = c.GetHeader("X-Admin-Password")
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// 检查 Admin-Password
|
| 108 |
+
if providedPassword == "" {
|
| 109 |
+
providedPassword = c.GetHeader("Admin-Password")
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// 验证密码
|
| 113 |
+
if providedPassword == adminPassword {
|
| 114 |
+
c.Next()
|
| 115 |
+
return
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// 鉴权失败
|
| 119 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
| 120 |
+
"error": gin.H{
|
| 121 |
+
"message": "Invalid admin password",
|
| 122 |
+
"type": "authentication_error",
|
| 123 |
+
},
|
| 124 |
+
})
|
| 125 |
+
}
|
| 126 |
+
}
|
internal/model/account.go
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"time"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
type PlanType string
|
| 8 |
+
|
| 9 |
+
const (
|
| 10 |
+
PlanFree PlanType = "Free"
|
| 11 |
+
PlanStarter PlanType = "Starter"
|
| 12 |
+
PlanCore PlanType = "Core"
|
| 13 |
+
PlanAdvanced PlanType = "Advanced"
|
| 14 |
+
PlanMax PlanType = "Max"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
// 每日积分限制
|
| 18 |
+
var PlanLimits = map[PlanType]int{
|
| 19 |
+
PlanFree: 30,
|
| 20 |
+
PlanStarter: 280,
|
| 21 |
+
PlanCore: 750,
|
| 22 |
+
PlanAdvanced: 1900,
|
| 23 |
+
PlanMax: 4200,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
type Account struct {
|
| 27 |
+
ID uint `json:"id" gorm:"primaryKey"`
|
| 28 |
+
ClientID string `json:"client_id" gorm:"uniqueIndex;not null"`
|
| 29 |
+
ClientSecret string `json:"-" gorm:"not null"` // 隐藏不传出
|
| 30 |
+
Email string `json:"email" gorm:"index"`
|
| 31 |
+
Category string `json:"category" gorm:"default:'normal';index"` // Deprecated: Use Status instead
|
| 32 |
+
Status string `json:"status" gorm:"default:'normal';index"` // normal, cooling, banned, error, disabled
|
| 33 |
+
PlanType PlanType `json:"plan_type" gorm:"default:'Free'"`
|
| 34 |
+
Proxy string `json:"proxy"`
|
| 35 |
+
AccessToken string `json:"-" gorm:"type:text"`
|
| 36 |
+
RefreshToken string `json:"-" gorm:"type:text"` // 用于刷新 AccessToken
|
| 37 |
+
TokenExpiry time.Time `json:"token_expiry"` // 传出token过期时间
|
| 38 |
+
CreditRefreshTime time.Time `json:"credit_refresh_time"` // 积分刷新时间(来自Zen-Pricing-Period-End)
|
| 39 |
+
IsActive bool `json:"is_active" gorm:"default:true"`
|
| 40 |
+
IsCooling bool `json:"is_cooling" gorm:"default:false"`
|
| 41 |
+
CoolingUntil time.Time `json:"cooling_until"` // 冷却结束时间
|
| 42 |
+
BanReason string `json:"ban_reason"` // 封禁/冷却原因
|
| 43 |
+
RateLimitHits int `json:"rate_limit_hits" gorm:"default:0"` // 429 错误次数
|
| 44 |
+
DailyUsed float64 `json:"daily_used" gorm:"default:0"`
|
| 45 |
+
TotalUsed float64 `json:"total_used" gorm:"default:0"`
|
| 46 |
+
LastResetDate string `json:"last_reset_date"`
|
| 47 |
+
SubscriptionStartDate time.Time `json:"subscription_start_date"`
|
| 48 |
+
LastUsed time.Time `json:"last_used"`
|
| 49 |
+
ErrorCount int `json:"error_count" gorm:"default:0"`
|
| 50 |
+
CreatedAt time.Time `json:"created_at"`
|
| 51 |
+
UpdatedAt time.Time `json:"updated_at"`
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
type AccountRequest struct {
|
| 55 |
+
ClientID string `json:"client_id"`
|
| 56 |
+
ClientSecret string `json:"client_secret"`
|
| 57 |
+
RefreshToken string `json:"refresh_token"` // Refresh token for authentication
|
| 58 |
+
Token string `json:"token"` // Deprecated: Use RefreshToken instead
|
| 59 |
+
Email string `json:"email"`
|
| 60 |
+
PlanType PlanType `json:"plan_type"`
|
| 61 |
+
Proxy string `json:"proxy"`
|
| 62 |
+
// Batch generation fields
|
| 63 |
+
GenerateMode bool `json:"generate_mode"` // true for batch generation mode
|
| 64 |
+
GenerateCount int `json:"generate_count"` // number of credentials to generate
|
| 65 |
+
}
|
internal/model/debug.go
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
// DebugLogModelMapping 输出模型映射查找日志
|
| 8 |
+
func DebugLogModelMapping(requestModel string, zenModel ZenModel, found bool) {
|
| 9 |
+
if found {
|
| 10 |
+
log.Printf("[DEBUG] [ModelMapping] ✓ 找到模型映射: request=%s → id=%s, model=%s, provider=%s, multiplier=%.1f",
|
| 11 |
+
requestModel, zenModel.ID, zenModel.Model, zenModel.ProviderID, zenModel.Multiplier)
|
| 12 |
+
if zenModel.Parameters != nil {
|
| 13 |
+
if zenModel.Parameters.Thinking != nil {
|
| 14 |
+
log.Printf("[DEBUG] [ModelMapping] └─ thinking: type=%s, budgetTokens=%d",
|
| 15 |
+
zenModel.Parameters.Thinking.Type, zenModel.Parameters.Thinking.BudgetTokens)
|
| 16 |
+
}
|
| 17 |
+
if zenModel.Parameters.ExtraHeaders != nil {
|
| 18 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 19 |
+
log.Printf("[DEBUG] [ModelMapping] └─ extraHeader: %s=%s", k, v)
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
if zenModel.Parameters.ForceStreaming != nil && *zenModel.Parameters.ForceStreaming {
|
| 23 |
+
log.Printf("[DEBUG] [ModelMapping] └─ forceStreaming: true")
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
} else {
|
| 27 |
+
log.Printf("[DEBUG] [ModelMapping] ✗ 未找到模型映射: request=%s, 使用默认配置", requestModel)
|
| 28 |
+
}
|
| 29 |
+
}
|
internal/model/openai.go
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
// OpenAI compatible request/response structures
|
| 4 |
+
|
| 5 |
+
type ChatCompletionRequest struct {
|
| 6 |
+
Model string `json:"model"`
|
| 7 |
+
Messages []ChatMessage `json:"messages"`
|
| 8 |
+
MaxTokens int `json:"max_tokens,omitempty"`
|
| 9 |
+
Temperature float64 `json:"temperature,omitempty"`
|
| 10 |
+
TopP float64 `json:"top_p,omitempty"`
|
| 11 |
+
Stream bool `json:"stream,omitempty"`
|
| 12 |
+
Stop []string `json:"stop,omitempty"`
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
type ChatMessage struct {
|
| 16 |
+
Role string `json:"role"`
|
| 17 |
+
Content string `json:"content"`
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
type ChatCompletionResponse struct {
|
| 21 |
+
ID string `json:"id"`
|
| 22 |
+
Object string `json:"object"`
|
| 23 |
+
Created int64 `json:"created"`
|
| 24 |
+
Model string `json:"model"`
|
| 25 |
+
Choices []Choice `json:"choices"`
|
| 26 |
+
Usage Usage `json:"usage"`
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
type Choice struct {
|
| 30 |
+
Index int `json:"index"`
|
| 31 |
+
Message ChatMessage `json:"message"`
|
| 32 |
+
FinishReason string `json:"finish_reason"`
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
type Usage struct {
|
| 36 |
+
PromptTokens int `json:"prompt_tokens"`
|
| 37 |
+
CompletionTokens int `json:"completion_tokens"`
|
| 38 |
+
TotalTokens int `json:"total_tokens"`
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
type StreamChoice struct {
|
| 42 |
+
Index int `json:"index"`
|
| 43 |
+
Delta ChatMessage `json:"delta"`
|
| 44 |
+
FinishReason *string `json:"finish_reason"`
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
type ChatCompletionChunk struct {
|
| 48 |
+
ID string `json:"id"`
|
| 49 |
+
Object string `json:"object"`
|
| 50 |
+
Created int64 `json:"created"`
|
| 51 |
+
Model string `json:"model"`
|
| 52 |
+
Choices []StreamChoice `json:"choices"`
|
| 53 |
+
}
|
internal/model/token_record.go
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"time"
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
// TokenRecord 记录生成账号时使用的token
|
| 8 |
+
type TokenRecord struct {
|
| 9 |
+
ID uint `json:"id" gorm:"primaryKey"`
|
| 10 |
+
Token string `json:"token" gorm:"type:text"` // 当前的access token(通过refresh_token生成)
|
| 11 |
+
RefreshToken string `json:"refresh_token" gorm:"type:text"` // 用于刷新token的refresh_token,可以为空
|
| 12 |
+
TokenExpiry time.Time `json:"token_expiry"` // access token过期时间
|
| 13 |
+
Description string `json:"description"` // token描述
|
| 14 |
+
Email string `json:"email"` // 账号邮箱(从JWT解析)
|
| 15 |
+
PlanType string `json:"plan_type"` // 订阅等级(从JWT解析)
|
| 16 |
+
SubscriptionStartDate time.Time `json:"subscription_start_date"` // 订阅开始时间(从JWT解析)
|
| 17 |
+
GeneratedCount int `json:"generated_count" gorm:"default:0"` // 已生成账号总数
|
| 18 |
+
LastGeneratedAt time.Time `json:"last_generated_at"` // 最后生成时间
|
| 19 |
+
AutoGenerate bool `json:"auto_generate" gorm:"default:true"` // 是否自动生成
|
| 20 |
+
Threshold int `json:"threshold" gorm:"default:10"` // 触发自动生成的阈值
|
| 21 |
+
GenerateBatch int `json:"generate_batch" gorm:"default:30"` // 每批生成数量
|
| 22 |
+
IsActive bool `json:"is_active" gorm:"default:true"` // 是否激活
|
| 23 |
+
Status string `json:"status" gorm:"default:'active'"` // token状态: active, banned, expired, disabled
|
| 24 |
+
BanReason string `json:"ban_reason"` // 封禁原因
|
| 25 |
+
HasRefreshToken bool `json:"has_refresh_token" gorm:"default:false"` // 是否有refresh_token
|
| 26 |
+
TotalSuccess int `json:"total_success" gorm:"default:0"` // 总成功数
|
| 27 |
+
TotalFail int `json:"total_fail" gorm:"default:0"` // 总失败数
|
| 28 |
+
TotalTasks int `json:"total_tasks" gorm:"default:0"` // 总任务数
|
| 29 |
+
RunningTasks int `json:"running_tasks" gorm:"-"` // 运行中的任务数(不存储在数据库)
|
| 30 |
+
CreatedAt time.Time `json:"created_at"`
|
| 31 |
+
UpdatedAt time.Time `json:"updated_at"`
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// GenerationTask 生成任务记录
|
| 35 |
+
type GenerationTask struct {
|
| 36 |
+
ID uint `json:"id" gorm:"primaryKey"`
|
| 37 |
+
TokenRecordID uint `json:"token_record_id" gorm:"index;not null"`
|
| 38 |
+
Token string `json:"-" gorm:"type:text"` // 实际使用的token
|
| 39 |
+
BatchSize int `json:"batch_size"` // 批次大小
|
| 40 |
+
SuccessCount int `json:"success_count" gorm:"default:0"` // 成功数量
|
| 41 |
+
FailCount int `json:"fail_count" gorm:"default:0"` // 失败数量
|
| 42 |
+
Status string `json:"status" gorm:"default:'pending'"` // pending, running, completed, failed
|
| 43 |
+
StartedAt time.Time `json:"started_at"`
|
| 44 |
+
CompletedAt time.Time `json:"completed_at"`
|
| 45 |
+
ErrorMessage string `json:"error_message"`
|
| 46 |
+
CreatedAt time.Time `json:"created_at"`
|
| 47 |
+
UpdatedAt time.Time `json:"updated_at"`
|
| 48 |
+
}
|
internal/model/zenmodel.go
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
// ThinkingConfig thinking模式配置
|
| 4 |
+
type ThinkingConfig struct {
|
| 5 |
+
Type string `json:"type"`
|
| 6 |
+
BudgetTokens int `json:"budgetTokens"`
|
| 7 |
+
Signature string `json:"signature,omitempty"`
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
// ReasoningConfig OpenAI reasoning配置
|
| 11 |
+
type ReasoningConfig struct {
|
| 12 |
+
Effort string `json:"effort"`
|
| 13 |
+
Summary string `json:"summary,omitempty"`
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
// TextConfig OpenAI text配置
|
| 17 |
+
type TextConfig struct {
|
| 18 |
+
Verbosity string `json:"verbosity"`
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// ModelParameters 模型参数配置
|
| 22 |
+
type ModelParameters struct {
|
| 23 |
+
Temperature *float64 `json:"temperature,omitempty"`
|
| 24 |
+
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
| 25 |
+
Reasoning *ReasoningConfig `json:"reasoning,omitempty"`
|
| 26 |
+
Text *TextConfig `json:"text,omitempty"`
|
| 27 |
+
ExtraHeaders map[string]string `json:"extraHeaders,omitempty"`
|
| 28 |
+
ForceStreaming *bool `json:"forceStreaming,omitempty"`
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
type ZenModel struct {
|
| 32 |
+
ID string `json:"id"`
|
| 33 |
+
DisplayName string `json:"displayName"`
|
| 34 |
+
Model string `json:"model"`
|
| 35 |
+
Multiplier float64 `json:"multiplier"`
|
| 36 |
+
ProviderID string `json:"providerId"`
|
| 37 |
+
Parameters *ModelParameters `json:"parameters,omitempty"`
|
| 38 |
+
IsHidden bool `json:"isHidden"`
|
| 39 |
+
PremiumOnly bool `json:"premiumOnly"` // 仅Advanced/Max可用
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// 辅助变量
|
| 43 |
+
var (
|
| 44 |
+
temp0 = 0.0
|
| 45 |
+
temp1 = 1.0
|
| 46 |
+
forceStream = true
|
| 47 |
+
|
| 48 |
+
// Thinking模式参数
|
| 49 |
+
thinkingParams = &ModelParameters{
|
| 50 |
+
Temperature: &temp1,
|
| 51 |
+
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 4096},
|
| 52 |
+
ExtraHeaders: map[string]string{
|
| 53 |
+
"anthropic-beta": "interleaved-thinking-2025-05-14",
|
| 54 |
+
},
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// OpenAI reasoning参数
|
| 58 |
+
openaiParams = &ModelParameters{
|
| 59 |
+
Temperature: &temp1,
|
| 60 |
+
Reasoning: &ReasoningConfig{Effort: "medium", Summary: "auto"},
|
| 61 |
+
Text: &TextConfig{Verbosity: "medium"},
|
| 62 |
+
}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
// 模型映射表
|
| 66 |
+
var ZenModels = map[string]ZenModel{
|
| 67 |
+
// Anthropic Models - Thinking模式(通过ID访问)
|
| 68 |
+
"claude-haiku-4-5-20251001-thinking": {
|
| 69 |
+
ID: "haiku-4-5-think", DisplayName: "Haiku 4.5 Parallel Thinking",
|
| 70 |
+
Model: "claude-haiku-4-5-20251001", Multiplier: 1, ProviderID: "anthropic",
|
| 71 |
+
Parameters: thinkingParams,
|
| 72 |
+
},
|
| 73 |
+
"claude-sonnet-4-20250514-thinking": {
|
| 74 |
+
ID: "sonnet-4-think", DisplayName: "Sonnet 4 Parallel Thinking",
|
| 75 |
+
Model: "claude-sonnet-4-20250514", Multiplier: 3, ProviderID: "anthropic",
|
| 76 |
+
Parameters: thinkingParams,
|
| 77 |
+
IsHidden: true,
|
| 78 |
+
},
|
| 79 |
+
"claude-sonnet-4-5-20250929-thinking": {
|
| 80 |
+
ID: "sonnet-4-5-think", DisplayName: "Sonnet 4.5 Parallel Thinking",
|
| 81 |
+
Model: "claude-sonnet-4-5-20250929", Multiplier: 3, ProviderID: "anthropic",
|
| 82 |
+
Parameters: thinkingParams,
|
| 83 |
+
},
|
| 84 |
+
"claude-opus-4-1-20250805-thinking": {
|
| 85 |
+
ID: "opus-4-think", DisplayName: "Opus 4.1 Parallel Thinking",
|
| 86 |
+
Model: "claude-opus-4-1-20250805", Multiplier: 15, ProviderID: "anthropic",
|
| 87 |
+
PremiumOnly: true,
|
| 88 |
+
Parameters: &ModelParameters{
|
| 89 |
+
Temperature: &temp1,
|
| 90 |
+
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 4096},
|
| 91 |
+
ExtraHeaders: map[string]string{"anthropic-beta": "interleaved-thinking-2025-05-14"},
|
| 92 |
+
ForceStreaming: &forceStream,
|
| 93 |
+
},
|
| 94 |
+
IsHidden: true,
|
| 95 |
+
},
|
| 96 |
+
"claude-opus-4-5-20251101-thinking": {
|
| 97 |
+
ID: "opus-4-5-think", DisplayName: "Opus 4.5 Parallel Thinking",
|
| 98 |
+
Model: "claude-opus-4-5-20251101", Multiplier: 5, ProviderID: "anthropic",
|
| 99 |
+
PremiumOnly: true,
|
| 100 |
+
Parameters: &ModelParameters{
|
| 101 |
+
Temperature: &temp1,
|
| 102 |
+
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 4096},
|
| 103 |
+
ExtraHeaders: map[string]string{"anthropic-beta": "interleaved-thinking-2025-05-14"},
|
| 104 |
+
ForceStreaming: &forceStream,
|
| 105 |
+
},
|
| 106 |
+
},
|
| 107 |
+
// Anthropic Models - 标准模式(不带 Thinking)
|
| 108 |
+
"claude-sonnet-4-20250514": {
|
| 109 |
+
ID: "sonnet-4", DisplayName: "Sonnet 4",
|
| 110 |
+
Model: "claude-sonnet-4-20250514", Multiplier: 2, ProviderID: "anthropic",
|
| 111 |
+
},
|
| 112 |
+
"claude-sonnet-4-5-20250929": {
|
| 113 |
+
ID: "sonnet-4-5", DisplayName: "Sonnet 4.5",
|
| 114 |
+
Model: "claude-sonnet-4-5-20250929", Multiplier: 2, ProviderID: "anthropic",
|
| 115 |
+
},
|
| 116 |
+
"claude-opus-4-1-20250805": {
|
| 117 |
+
ID: "opus-4", DisplayName: "Opus 4.1",
|
| 118 |
+
Model: "claude-opus-4-1-20250805", Multiplier: 10, ProviderID: "anthropic",
|
| 119 |
+
PremiumOnly: true,
|
| 120 |
+
Parameters: &ModelParameters{ForceStreaming: &forceStream},
|
| 121 |
+
},
|
| 122 |
+
"claude-opus-4-5-20251101": { //非原生实现
|
| 123 |
+
ID: "opus-4-5-think", DisplayName: "Opus 4.5 Parallel Thinking",
|
| 124 |
+
Model: "claude-opus-4-5-20251101", Multiplier: 5, ProviderID: "anthropic",
|
| 125 |
+
PremiumOnly: true,
|
| 126 |
+
Parameters: &ModelParameters{
|
| 127 |
+
Temperature: &temp1,
|
| 128 |
+
Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 4096},
|
| 129 |
+
ExtraHeaders: map[string]string{"anthropic-beta": "interleaved-thinking-2025-05-14"},
|
| 130 |
+
ForceStreaming: &forceStream,
|
| 131 |
+
},
|
| 132 |
+
},
|
| 133 |
+
"claude-haiku-4-5-20251001": { //非原生实现
|
| 134 |
+
ID: "haiku-4-5-think", DisplayName: "Haiku 4.5 Parallel Thinking",
|
| 135 |
+
Model: "claude-haiku-4-5-20251001", Multiplier: 1, ProviderID: "anthropic",
|
| 136 |
+
Parameters: thinkingParams,
|
| 137 |
+
},
|
| 138 |
+
// Gemini Models
|
| 139 |
+
"gemini-3-pro-preview": {
|
| 140 |
+
ID: "gemini-3-pro-preview", DisplayName: "Gemini Pro 3.0",
|
| 141 |
+
Model: "gemini-3-pro-preview", Multiplier: 2, ProviderID: "gemini",
|
| 142 |
+
Parameters: &ModelParameters{Temperature: &temp1},
|
| 143 |
+
},
|
| 144 |
+
"gemini-3-flash-preview": {
|
| 145 |
+
ID: "gemini-3-flash-preview", DisplayName: "Gemini Flash 3.0",
|
| 146 |
+
Model: "gemini-3-flash-preview", Multiplier: 1, ProviderID: "gemini",
|
| 147 |
+
Parameters: &ModelParameters{Temperature: &temp1},
|
| 148 |
+
IsHidden: true,
|
| 149 |
+
},
|
| 150 |
+
|
| 151 |
+
// OpenAI Models
|
| 152 |
+
"gpt-5.1-codex-mini": {
|
| 153 |
+
ID: "gpt-5-1-codex-mini", DisplayName: "GPT-5.1 Codex mini",
|
| 154 |
+
Model: "gpt-5.1-codex-mini", Multiplier: 0.5, ProviderID: "openai",
|
| 155 |
+
Parameters: openaiParams,
|
| 156 |
+
},
|
| 157 |
+
"gpt-5.1-codex": {
|
| 158 |
+
ID: "gpt-5-1-codex-medium", DisplayName: "GPT-5.1 Codex",
|
| 159 |
+
Model: "gpt-5.1-codex", Multiplier: 1, ProviderID: "openai",
|
| 160 |
+
Parameters: openaiParams,
|
| 161 |
+
IsHidden: true,
|
| 162 |
+
},
|
| 163 |
+
"gpt-5.1-codex-max": {
|
| 164 |
+
ID: "gpt-5-1-codex-max", DisplayName: "GPT-5.1 Codex Max",
|
| 165 |
+
Model: "gpt-5.1-codex-max", Multiplier: 1.5, ProviderID: "openai",
|
| 166 |
+
Parameters: openaiParams,
|
| 167 |
+
},
|
| 168 |
+
"gpt-5.2-codex": {
|
| 169 |
+
ID: "gpt-5-2-codex", DisplayName: "GPT-5.2 Codex",
|
| 170 |
+
Model: "gpt-5.2-codex", Multiplier: 2, ProviderID: "openai",
|
| 171 |
+
Parameters: openaiParams,
|
| 172 |
+
},
|
| 173 |
+
"gpt-5-2025-08-07": {
|
| 174 |
+
ID: "gpt-5-medium", DisplayName: "GPT-5",
|
| 175 |
+
Model: "gpt-5-2025-08-07", Multiplier: 1, ProviderID: "openai",
|
| 176 |
+
Parameters: openaiParams,
|
| 177 |
+
IsHidden: true,
|
| 178 |
+
},
|
| 179 |
+
"gpt-5-codex": {
|
| 180 |
+
ID: "gpt-5-codex-medium", DisplayName: "GPT-5-Codex",
|
| 181 |
+
Model: "gpt-5-codex", Multiplier: 1, ProviderID: "openai",
|
| 182 |
+
Parameters: openaiParams,
|
| 183 |
+
IsHidden: true,
|
| 184 |
+
},
|
| 185 |
+
|
| 186 |
+
// xAI Models
|
| 187 |
+
"grok-code-fast-1": {
|
| 188 |
+
ID: "grok-code-fast", DisplayName: "Grok Code Fast 1",
|
| 189 |
+
Model: "grok-code-fast-1", Multiplier: 0.25, ProviderID: "xai",
|
| 190 |
+
Parameters: &ModelParameters{Temperature: &temp0},
|
| 191 |
+
},
|
| 192 |
+
|
| 193 |
+
// Utility Models
|
| 194 |
+
"gpt-5-nano-2025-08-07": {
|
| 195 |
+
ID: "generate-name-v2", DisplayName: "Cheap model for generating names",
|
| 196 |
+
Model: "gpt-5-nano-2025-08-07", Multiplier: 0, ProviderID: "openai",
|
| 197 |
+
Parameters: &ModelParameters{
|
| 198 |
+
Reasoning: &ReasoningConfig{Effort: "minimal"},
|
| 199 |
+
},
|
| 200 |
+
},
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
// GetZenModel 获取模型配置,如果不存在则返回空模型和false
|
| 204 |
+
func GetZenModel(modelID string) (ZenModel, bool) {
|
| 205 |
+
if m, ok := ZenModels[modelID]; ok {
|
| 206 |
+
return m, true
|
| 207 |
+
}
|
| 208 |
+
// 模型不存在,返回空模型和false
|
| 209 |
+
return ZenModel{}, false
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// CanUseModel 检查订阅类型是否可以使用指定模型
|
| 213 |
+
func CanUseModel(planType PlanType, modelID string) bool {
|
| 214 |
+
zenModel, _ := GetZenModel(modelID)
|
| 215 |
+
|
| 216 |
+
// Advanced和Max可以使用所有模型
|
| 217 |
+
if planType == PlanAdvanced || planType == PlanMax {
|
| 218 |
+
return true
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// 其他订阅类型不能使用PremiumOnly模型
|
| 222 |
+
return !zenModel.PremiumOnly
|
| 223 |
+
}
|
internal/service/anthropic.go
ADDED
|
@@ -0,0 +1,1602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"bytes"
|
| 6 |
+
"context"
|
| 7 |
+
"encoding/json"
|
| 8 |
+
"fmt"
|
| 9 |
+
"io"
|
| 10 |
+
"log"
|
| 11 |
+
"math/rand"
|
| 12 |
+
"net/http"
|
| 13 |
+
"strings"
|
| 14 |
+
"time"
|
| 15 |
+
|
| 16 |
+
"zencoder-2api/internal/model"
|
| 17 |
+
"zencoder-2api/internal/service/provider"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
// sanitizeRequestBody 清理请求体中的敏感信息,保留结构但替换内容
|
| 21 |
+
func sanitizeRequestBody(body []byte) string {
|
| 22 |
+
var reqMap map[string]interface{}
|
| 23 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 24 |
+
return string(body) // 如果解析失败,返回原始内容
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// 处理messages数组
|
| 28 |
+
if messages, ok := reqMap["messages"].([]interface{}); ok {
|
| 29 |
+
for i, msg := range messages {
|
| 30 |
+
if msgMap, ok := msg.(map[string]interface{}); ok {
|
| 31 |
+
// 处理content字段
|
| 32 |
+
if content, exists := msgMap["content"]; exists {
|
| 33 |
+
// content可能是字符串或数组
|
| 34 |
+
switch c := content.(type) {
|
| 35 |
+
case string:
|
| 36 |
+
// 如果是字符串,直接替换
|
| 37 |
+
msgMap["content"] = "Content omitted"
|
| 38 |
+
case []interface{}:
|
| 39 |
+
// 如果是数组(结构化内容),保留结构但替换文本
|
| 40 |
+
for j, block := range c {
|
| 41 |
+
if blockMap, ok := block.(map[string]interface{}); ok {
|
| 42 |
+
// 保留type字段
|
| 43 |
+
if blockType, hasType := blockMap["type"]; hasType {
|
| 44 |
+
// 根据type处理不同的内容块
|
| 45 |
+
switch blockType {
|
| 46 |
+
case "text":
|
| 47 |
+
// 替换text内容
|
| 48 |
+
blockMap["text"] = "Content omitted"
|
| 49 |
+
case "thinking", "redacted_thinking":
|
| 50 |
+
// thinking块:替换thinking内容
|
| 51 |
+
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
| 52 |
+
blockMap["thinking"] = "Content omitted"
|
| 53 |
+
}
|
| 54 |
+
// 保留signature字段不变
|
| 55 |
+
case "image":
|
| 56 |
+
// 图片块:清理source内容
|
| 57 |
+
if source, hasSource := blockMap["source"]; hasSource {
|
| 58 |
+
if sourceMap, ok := source.(map[string]interface{}); ok {
|
| 59 |
+
// 保留类型但清理数据
|
| 60 |
+
if _, hasData := sourceMap["data"]; hasData {
|
| 61 |
+
sourceMap["data"] = "Image data omitted"
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
case "tool_use":
|
| 66 |
+
// 工具使用块:清理input内容
|
| 67 |
+
if _, hasInput := blockMap["input"]; hasInput {
|
| 68 |
+
blockMap["input"] = map[string]interface{}{
|
| 69 |
+
"note": "Tool input omitted",
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
case "tool_result":
|
| 73 |
+
// 工具结果块:清理content内容
|
| 74 |
+
if _, hasContent := blockMap["content"]; hasContent {
|
| 75 |
+
blockMap["content"] = "Tool result omitted"
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
c[j] = blockMap
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
msgMap["content"] = c
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
messages[i] = msgMap
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
reqMap["messages"] = messages
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// 处理tools字段 - 改为空数组
|
| 92 |
+
if _, hasTools := reqMap["tools"]; hasTools {
|
| 93 |
+
reqMap["tools"] = []interface{}{}
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// 处理system字段 - 替换为固定文本
|
| 97 |
+
if _, hasSystem := reqMap["system"]; hasSystem {
|
| 98 |
+
reqMap["system"] = "System prompt omitted"
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// 序列化为JSON字符串
|
| 102 |
+
sanitized, _ := json.MarshalIndent(reqMap, "", " ")
|
| 103 |
+
return string(sanitized)
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// logRequestDetails 记录请求详细信息
|
| 107 |
+
func logRequestDetails(prefix string, headers http.Header, body []byte) {
|
| 108 |
+
log.Printf("%s 请求详情:", prefix)
|
| 109 |
+
|
| 110 |
+
// 记录请求头
|
| 111 |
+
log.Printf("%s 请求头:", prefix)
|
| 112 |
+
for k, v := range headers {
|
| 113 |
+
// 过滤敏感请求头
|
| 114 |
+
if strings.Contains(strings.ToLower(k), "auth") ||
|
| 115 |
+
strings.Contains(strings.ToLower(k), "key") ||
|
| 116 |
+
strings.Contains(strings.ToLower(k), "token") {
|
| 117 |
+
log.Printf(" %s: [REDACTED]", k)
|
| 118 |
+
} else {
|
| 119 |
+
log.Printf(" %s: %s", k, strings.Join(v, ", "))
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// 记录请求体(已清理敏感信息)
|
| 124 |
+
log.Printf("%s 请求体 (已清理):", prefix)
|
| 125 |
+
log.Printf("%s", sanitizeRequestBody(body))
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
const AnthropicBaseURL = "https://api.zencoder.ai/anthropic"
|
| 129 |
+
|
| 130 |
+
type AnthropicService struct{}
|
| 131 |
+
|
| 132 |
+
func NewAnthropicService() *AnthropicService {
|
| 133 |
+
return &AnthropicService{}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// Messages 处理/v1/messages请求,直接透传到Anthropic API
|
| 137 |
+
func (s *AnthropicService) Messages(ctx context.Context, body []byte, isStream bool) (*http.Response, error) {
|
| 138 |
+
var req struct {
|
| 139 |
+
Model string `json:"model"`
|
| 140 |
+
MaxTokens float64 `json:"max_tokens,omitempty"`
|
| 141 |
+
Thinking map[string]interface{} `json:"thinking,omitempty"`
|
| 142 |
+
}
|
| 143 |
+
if err := json.Unmarshal(body, &req); err != nil {
|
| 144 |
+
return nil, fmt.Errorf("invalid request body: %w", err)
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// 记录请求的模型和thinking状态
|
| 148 |
+
thinkingStatus := "disabled"
|
| 149 |
+
if req.Thinking != nil {
|
| 150 |
+
if enabled, ok := req.Thinking["enabled"].(bool); ok && enabled {
|
| 151 |
+
thinkingStatus = "enabled"
|
| 152 |
+
} else if thinkingType, ok := req.Thinking["type"].(string); ok && thinkingType == "enabled" {
|
| 153 |
+
thinkingStatus = "enabled"
|
| 154 |
+
}
|
| 155 |
+
// 如果有thinking配置且有budget_tokens,也记录
|
| 156 |
+
if budget, ok := req.Thinking["budget_tokens"].(float64); ok && budget > 0 {
|
| 157 |
+
thinkingStatus = fmt.Sprintf("enabled(budget=%g)", budget)
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
// 只在非限速测试时输出请求信息
|
| 161 |
+
if IsDebugMode() && !strings.Contains(req.Model, "test") {
|
| 162 |
+
log.Printf("[Anthropic] 请求 - Model: %s, Thinking: %s", req.Model, thinkingStatus)
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// 检查是否需要映射到对应的thinking模型
|
| 166 |
+
originalModel := req.Model
|
| 167 |
+
if req.Thinking != nil {
|
| 168 |
+
// 检查是否开启了thinking
|
| 169 |
+
thinkingEnabled := false
|
| 170 |
+
if enabled, ok := req.Thinking["enabled"].(bool); ok && enabled {
|
| 171 |
+
thinkingEnabled = true
|
| 172 |
+
} else if thinkingType, ok := req.Thinking["type"].(string); ok && thinkingType == "enabled" {
|
| 173 |
+
thinkingEnabled = true
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
if thinkingEnabled {
|
| 177 |
+
// 检查是否存在对应的thinking模型
|
| 178 |
+
thinkingModelID := req.Model + "-thinking"
|
| 179 |
+
if _, exists := model.GetZenModel(thinkingModelID); exists {
|
| 180 |
+
req.Model = thinkingModelID
|
| 181 |
+
DebugLog(ctx, "[Anthropic] 映射到thinking模型: %s -> %s", originalModel, req.Model)
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// 检查模型是否存在于模型字典中
|
| 187 |
+
_, exists := model.GetZenModel(req.Model)
|
| 188 |
+
if !exists {
|
| 189 |
+
DebugLog(ctx, "[Anthropic] 模型不存在: %s", req.Model)
|
| 190 |
+
return nil, ErrNoAvailableAccount
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
DebugLogRequest(ctx, "Anthropic", "/v1/messages", req.Model)
|
| 194 |
+
|
| 195 |
+
// 处理max_tokens和thinking.budget_tokens的关系
|
| 196 |
+
// 如果用户传入了thinking配置,检查并调整max_tokens
|
| 197 |
+
if req.Thinking != nil {
|
| 198 |
+
budgetTokens := 0.0
|
| 199 |
+
if budget, ok := req.Thinking["budget_tokens"].(float64); ok {
|
| 200 |
+
budgetTokens = budget
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
// 如果max_tokens小于等于budget_tokens,调整max_tokens
|
| 204 |
+
if budgetTokens > 0 && req.MaxTokens > 0 && req.MaxTokens <= budgetTokens {
|
| 205 |
+
// 按用户要求:max_tokens = max_tokens + budget_tokens
|
| 206 |
+
newMaxTokens := req.MaxTokens + budgetTokens
|
| 207 |
+
|
| 208 |
+
// 修改原始请求体中的max_tokens
|
| 209 |
+
var reqMap map[string]interface{}
|
| 210 |
+
if err := json.Unmarshal(body, &reqMap); err == nil {
|
| 211 |
+
reqMap["max_tokens"] = newMaxTokens
|
| 212 |
+
if modifiedBody, err := json.Marshal(reqMap); err == nil {
|
| 213 |
+
body = modifiedBody
|
| 214 |
+
DebugLog(ctx, "[Anthropic] 调整max_tokens: %.0f -> %.0f (原值+budget_tokens)", req.MaxTokens, newMaxTokens)
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
var lastErr error
|
| 221 |
+
for i := 0; i < MaxRetries; i++ {
|
| 222 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 223 |
+
if err != nil {
|
| 224 |
+
DebugLogRequestEnd(ctx, "Anthropic", false, err)
|
| 225 |
+
return nil, err
|
| 226 |
+
}
|
| 227 |
+
DebugLogAccountSelected(ctx, "Anthropic", account.ID, account.Email)
|
| 228 |
+
|
| 229 |
+
resp, err := s.doRequest(ctx, account, req.Model, body)
|
| 230 |
+
if err != nil {
|
| 231 |
+
// 请求失败,释放账号
|
| 232 |
+
ReleaseAccount(account)
|
| 233 |
+
// MarkAccountError(account)
|
| 234 |
+
lastErr = err
|
| 235 |
+
DebugLogRetry(ctx, "Anthropic", i+1, account.ID, err)
|
| 236 |
+
continue
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
// 只在调试模式下且非限速测试时输出详细响应信息
|
| 240 |
+
if IsDebugMode() && !strings.Contains(req.Model, "test") {
|
| 241 |
+
DebugLogResponseReceived(ctx, "Anthropic", resp.StatusCode)
|
| 242 |
+
|
| 243 |
+
// 只输出积分信息,不输出所有响应头
|
| 244 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 245 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 246 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 247 |
+
DebugLog(ctx, "[Anthropic] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 248 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 249 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 250 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
if resp.StatusCode >= 400 {
|
| 255 |
+
// 读取错误响应内容
|
| 256 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 257 |
+
resp.Body.Close()
|
| 258 |
+
|
| 259 |
+
// 检查是否是官方API直接抛出的错误(413、400、429)
|
| 260 |
+
// 这些错误不是token池问题,应直接返回给客户端
|
| 261 |
+
if resp.StatusCode == 413 || resp.StatusCode == 400 || resp.StatusCode == 429 {
|
| 262 |
+
// 对于400错误,根据错误类型决定日志级别
|
| 263 |
+
if resp.StatusCode == 400 {
|
| 264 |
+
// 解析thinking状态用于日志
|
| 265 |
+
thinkingStatus := "disabled"
|
| 266 |
+
if req.Thinking != nil {
|
| 267 |
+
if enabled, ok := req.Thinking["enabled"].(bool); ok && enabled {
|
| 268 |
+
thinkingStatus = "enabled"
|
| 269 |
+
} else if thinkingType, ok := req.Thinking["type"].(string); ok && thinkingType == "enabled" {
|
| 270 |
+
thinkingStatus = "enabled"
|
| 271 |
+
}
|
| 272 |
+
// 如果有thinking配置且有budget_tokens,也记录
|
| 273 |
+
if budget, ok := req.Thinking["budget_tokens"].(float64); ok && budget > 0 {
|
| 274 |
+
thinkingStatus = fmt.Sprintf("enabled(budget=%g)", budget)
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
// 尝试解析错误类型
|
| 279 |
+
var errResp struct {
|
| 280 |
+
Error struct {
|
| 281 |
+
Type string `json:"type"`
|
| 282 |
+
Message string `json:"message"`
|
| 283 |
+
} `json:"error"`
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
isKnownError := false
|
| 287 |
+
isPromptTooLongError := false
|
| 288 |
+
if err := json.Unmarshal(errBody, &errResp); err == nil && errResp.Error.Type != "" {
|
| 289 |
+
// 检查是否是已知的错误类型
|
| 290 |
+
knownErrors := []string{
|
| 291 |
+
"prompt is too long",
|
| 292 |
+
"max_tokens",
|
| 293 |
+
"invalid_request_error",
|
| 294 |
+
"authentication_error",
|
| 295 |
+
"permission_error",
|
| 296 |
+
"rate_limit_error",
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
errorMessage := strings.ToLower(errResp.Error.Message)
|
| 300 |
+
for _, known := range knownErrors {
|
| 301 |
+
if strings.Contains(errorMessage, known) || errResp.Error.Type == known {
|
| 302 |
+
isKnownError = true
|
| 303 |
+
if known == "prompt is too long" || strings.Contains(errorMessage, "prompt is too long") {
|
| 304 |
+
isPromptTooLongError = true
|
| 305 |
+
}
|
| 306 |
+
break
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
if isKnownError {
|
| 311 |
+
// 已知错误,只输出简单日志,包含请求模型ID和thinking状态
|
| 312 |
+
log.Printf("[Anthropic] 400错误: %s - %s (Model: %s, Thinking: %s)", errResp.Error.Type, errResp.Error.Message, req.Model, thinkingStatus)
|
| 313 |
+
|
| 314 |
+
// 对于非"prompt is too long"错误,在DEBUG模式下输出详细信息
|
| 315 |
+
if !isPromptTooLongError && IsDebugMode() {
|
| 316 |
+
if originalHeaders, ok := ctx.Value("originalHeaders").(http.Header); ok {
|
| 317 |
+
logRequestDetails("[Anthropic] 原始客户端", originalHeaders, body)
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
} else {
|
| 321 |
+
// 未知错误,输出详细日志用于调试,包含请求模型ID和thinking状态
|
| 322 |
+
log.Printf("[Anthropic] 400未知错误: %s (Model: %s, Thinking: %s)", string(errBody), req.Model, thinkingStatus)
|
| 323 |
+
if IsDebugMode() {
|
| 324 |
+
// DEBUG模式下输出原始请求信息
|
| 325 |
+
if originalHeaders, ok := ctx.Value("originalHeaders").(http.Header); ok {
|
| 326 |
+
logRequestDetails("[Anthropic] 原始客户端", originalHeaders, body)
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
}
|
| 330 |
+
} else {
|
| 331 |
+
// 解析失败,输出完整错误用于调试,包含请求模型ID和thinking状态
|
| 332 |
+
log.Printf("[Anthropic] 400错误(无法解析): %s (Model: %s, Thinking: %s)", string(errBody), req.Model, thinkingStatus)
|
| 333 |
+
if IsDebugMode() {
|
| 334 |
+
// DEBUG模式下输出原始请求信息
|
| 335 |
+
if originalHeaders, ok := ctx.Value("originalHeaders").(http.Header); ok {
|
| 336 |
+
logRequestDetails("[Anthropic] 原始客户端", originalHeaders, body)
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
}
|
| 340 |
+
} else if resp.StatusCode == 429 {
|
| 341 |
+
// 简化429错误日志输出
|
| 342 |
+
s.classifyAndLog429Error(string(errBody), account.ID, account.Email)
|
| 343 |
+
|
| 344 |
+
// 检查是否是Claude官方的429错误
|
| 345 |
+
isClaudeOfficialError := s.isClaudeOfficial429Error(string(errBody))
|
| 346 |
+
|
| 347 |
+
// 尝试使用代理池重试
|
| 348 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, req.Model, body)
|
| 349 |
+
if proxyErr == nil && proxyResp != nil {
|
| 350 |
+
// 代理重试成功
|
| 351 |
+
ReleaseAccount(account)
|
| 352 |
+
return proxyResp, nil
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
if proxyErr != nil {
|
| 356 |
+
log.Printf("[Anthropic] 代理重试失败 账号ID:%d %s", account.ID, account.Email)
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
// 只有Claude官方的429错误才返回原始响应,其他429错误返回通用错误
|
| 360 |
+
if isClaudeOfficialError {
|
| 361 |
+
// Claude官方429错误,返回原始响应
|
| 362 |
+
ReleaseAccount(account)
|
| 363 |
+
return &http.Response{
|
| 364 |
+
StatusCode: resp.StatusCode,
|
| 365 |
+
Header: resp.Header,
|
| 366 |
+
Body: io.NopCloser(bytes.NewReader(errBody)),
|
| 367 |
+
}, nil
|
| 368 |
+
} else {
|
| 369 |
+
// 非Claude官方429错误,不返回原始响应,继续重试其他账号
|
| 370 |
+
ReleaseAccount(account)
|
| 371 |
+
lastErr = fmt.Errorf("non-official 429 error")
|
| 372 |
+
if IsDebugMode() {
|
| 373 |
+
DebugLogRetry(ctx, "Anthropic", i+1, account.ID, lastErr)
|
| 374 |
+
}
|
| 375 |
+
continue
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
// 对于其他官方API错误(400、413):
|
| 379 |
+
// 1. 释放账号
|
| 380 |
+
// 2. 不计算账号错误次数
|
| 381 |
+
// 3. 直接返回原始响应
|
| 382 |
+
ReleaseAccount(account)
|
| 383 |
+
return &http.Response{
|
| 384 |
+
StatusCode: resp.StatusCode,
|
| 385 |
+
Header: resp.Header,
|
| 386 |
+
Body: io.NopCloser(bytes.NewReader(errBody)),
|
| 387 |
+
}, nil
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
// 503和529错误:上游API错误,不是token问题
|
| 391 |
+
if resp.StatusCode == 503 || resp.StatusCode == 529 {
|
| 392 |
+
// 只记录简单的错误日志
|
| 393 |
+
log.Printf("错误响应 [%d]: %s", resp.StatusCode, string(errBody))
|
| 394 |
+
// 释放账号,不计算错误次数,返回通用错误
|
| 395 |
+
ReleaseAccount(account)
|
| 396 |
+
return nil, ErrNoAvailableAccount
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
// 500错误处理
|
| 400 |
+
if resp.StatusCode == 500 {
|
| 401 |
+
// 检查是否是限速问题
|
| 402 |
+
if strings.Contains(string(errBody), "Rate limit tracking problem") {
|
| 403 |
+
log.Printf("[Anthropic] 限速跟踪问题,尝试使用代理重试")
|
| 404 |
+
|
| 405 |
+
// 尝试使用代理池重试
|
| 406 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, req.Model, body)
|
| 407 |
+
if proxyErr == nil && proxyResp != nil {
|
| 408 |
+
// 代理重试成功
|
| 409 |
+
ReleaseAccount(account)
|
| 410 |
+
return proxyResp, nil
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
log.Printf("[Anthropic] 代理重试失败: %v", proxyErr)
|
| 414 |
+
|
| 415 |
+
// 代理重试失败,继续原有逻���:冻结账号5-10秒随机时间
|
| 416 |
+
freezeTime := 5 + rand.Intn(6) // 5-10秒随机
|
| 417 |
+
|
| 418 |
+
// 非调试模式下只输出简单信息
|
| 419 |
+
if !IsDebugMode() {
|
| 420 |
+
log.Printf("[Anthropic] 限速错误,冻结账号 ID:%d %s %d秒,重试 #%d", account.ID, account.Email, freezeTime, i+1)
|
| 421 |
+
} else {
|
| 422 |
+
log.Printf("[Anthropic] 检测到限速错误,冻结账号 ID:%d %s %d秒", account.ID, account.Email, freezeTime)
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
// 冻结账号并释放(不计算错误次数,这是临时限速问题)
|
| 426 |
+
FreezeAccount(account, time.Duration(freezeTime)*time.Second) // 这个函数内部会释放账号
|
| 427 |
+
|
| 428 |
+
// 设置错误并继续重试其他账号
|
| 429 |
+
lastErr = fmt.Errorf("rate limit tracking problem")
|
| 430 |
+
|
| 431 |
+
// 只在调试模式下输出详细重试日志
|
| 432 |
+
if IsDebugMode() {
|
| 433 |
+
DebugLogRetry(ctx, "Anthropic", i+1, account.ID, lastErr)
|
| 434 |
+
}
|
| 435 |
+
continue
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
// 其他500错误,释放账号并直接返回
|
| 439 |
+
ReleaseAccount(account)
|
| 440 |
+
return &http.Response{
|
| 441 |
+
StatusCode: resp.StatusCode,
|
| 442 |
+
Header: resp.Header,
|
| 443 |
+
Body: io.NopCloser(bytes.NewReader(errBody)),
|
| 444 |
+
}, nil
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
// 其他错误,释放账号并继续重试
|
| 448 |
+
ReleaseAccount(account)
|
| 449 |
+
// MarkAccountError(account)
|
| 450 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 451 |
+
|
| 452 |
+
// 只在调试模式下输出详细错误信息
|
| 453 |
+
if IsDebugMode() {
|
| 454 |
+
DebugLogErrorResponse(ctx, "Anthropic", resp.StatusCode, string(errBody))
|
| 455 |
+
DebugLogRetry(ctx, "Anthropic", i+1, account.ID, lastErr)
|
| 456 |
+
} else {
|
| 457 |
+
// 非调试模式下只输出简单的重试信息
|
| 458 |
+
log.Printf("[Anthropic] API错误 %d,重试 #%d", resp.StatusCode, i+1)
|
| 459 |
+
}
|
| 460 |
+
continue
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
// 请求成功,释放账号
|
| 464 |
+
ReleaseAccount(account)
|
| 465 |
+
|
| 466 |
+
ResetAccountError(account)
|
| 467 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 468 |
+
if !exists {
|
| 469 |
+
// 模型不存在,使用默认倍率
|
| 470 |
+
UpdateAccountCreditsFromResponse(account, resp, 1.0)
|
| 471 |
+
} else {
|
| 472 |
+
// 使用统一的积分更新函数,自动处理响应头中的积分信息
|
| 473 |
+
UpdateAccountCreditsFromResponse(account, resp, zenModel.Multiplier)
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
DebugLogRequestEnd(ctx, "Anthropic", true, nil)
|
| 477 |
+
return resp, nil
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
// 只在调试模式下输出详细的请求结束日志
|
| 481 |
+
if IsDebugMode() {
|
| 482 |
+
DebugLogRequestEnd(ctx, "Anthropic", false, lastErr)
|
| 483 |
+
} else {
|
| 484 |
+
// 非调试模式下只输出简单的失败信息
|
| 485 |
+
log.Printf("[Anthropic] 所有重试失败: %v", lastErr)
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
// 检查是否是网络连接错误,如果是则返回统一的错误信息,避免暴露内部网络详情
|
| 489 |
+
if lastErr != nil {
|
| 490 |
+
errStr := lastErr.Error()
|
| 491 |
+
// 检查常见的网络连接错误
|
| 492 |
+
if strings.Contains(errStr, "dial tcp") ||
|
| 493 |
+
strings.Contains(errStr, "connection refused") ||
|
| 494 |
+
strings.Contains(errStr, "no such host") ||
|
| 495 |
+
strings.Contains(errStr, "cannot assign requested address") ||
|
| 496 |
+
strings.Contains(errStr, "timeout") ||
|
| 497 |
+
strings.Contains(errStr, "network is unreachable") {
|
| 498 |
+
return nil, ErrNoAvailableAccount
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
func (s *AnthropicService) doRequest(ctx context.Context, account *model.Account, modelID string, body []byte) (*http.Response, error) {
|
| 506 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 507 |
+
if !exists {
|
| 508 |
+
// 模型不存在,返回错误
|
| 509 |
+
return nil, ErrNoAvailableAccount
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
// 注意:已移除模型替换逻辑,直接使用原始请求体
|
| 513 |
+
modifiedBody := body
|
| 514 |
+
|
| 515 |
+
// 对于需要 thinking 的模型,强制添加 thinking 配置
|
| 516 |
+
var err error
|
| 517 |
+
modifiedBody, err = s.ensureThinkingConfig(modifiedBody, modelID)
|
| 518 |
+
if err != nil {
|
| 519 |
+
return nil, fmt.Errorf("failed to ensure thinking config: %w", err)
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
// 根据模型要求调整参数(温度、top_p等)
|
| 523 |
+
modifiedBody, err = s.adjustParametersForModel(modifiedBody, modelID)
|
| 524 |
+
if err != nil {
|
| 525 |
+
return nil, fmt.Errorf("failed to adjust parameters: %w", err)
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
// 注意:已移除模型重定向逻辑,直接使用用户请求的模型名
|
| 529 |
+
DebugLogActualModel(ctx, "Anthropic", modelID, modelID)
|
| 530 |
+
|
| 531 |
+
reqURL := AnthropicBaseURL + "/v1/messages"
|
| 532 |
+
DebugLogRequestSent(ctx, "Anthropic", reqURL)
|
| 533 |
+
|
| 534 |
+
resp, err := s.makeRequest(ctx, modifiedBody, account, zenModel)
|
| 535 |
+
if err != nil {
|
| 536 |
+
return nil, err
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
// 检查是否是400错误,需要特殊处理
|
| 540 |
+
if resp.StatusCode == 400 {
|
| 541 |
+
bodyBytes, readErr := io.ReadAll(resp.Body)
|
| 542 |
+
resp.Body.Close()
|
| 543 |
+
|
| 544 |
+
if readErr == nil {
|
| 545 |
+
errorBody := string(bodyBytes)
|
| 546 |
+
|
| 547 |
+
// 检查是否是thinking格式错误,但不再进行模型切换
|
| 548 |
+
if s.isThinkingFormatError(errorBody) {
|
| 549 |
+
log.Printf("[Anthropic] thinking格式错误: %s", errorBody)
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
// 检查是否是thinking signature过期错误
|
| 553 |
+
if s.isThinkingSignatureError(errorBody) {
|
| 554 |
+
// 解析当前请求的模型和thinking状态
|
| 555 |
+
var reqInfo struct {
|
| 556 |
+
Model string `json:"model"`
|
| 557 |
+
Thinking map[string]interface{} `json:"thinking,omitempty"`
|
| 558 |
+
}
|
| 559 |
+
json.Unmarshal(modifiedBody, &reqInfo)
|
| 560 |
+
|
| 561 |
+
thinkingStatus := "disabled"
|
| 562 |
+
if reqInfo.Thinking != nil {
|
| 563 |
+
if enabled, ok := reqInfo.Thinking["enabled"].(bool); ok && enabled {
|
| 564 |
+
thinkingStatus = "enabled"
|
| 565 |
+
} else if thinkingType, ok := reqInfo.Thinking["type"].(string); ok && thinkingType == "enabled" {
|
| 566 |
+
thinkingStatus = "enabled"
|
| 567 |
+
}
|
| 568 |
+
if budget, ok := reqInfo.Thinking["budget_tokens"].(float64); ok && budget > 0 {
|
| 569 |
+
thinkingStatus = fmt.Sprintf("enabled(budget=%g)", budget)
|
| 570 |
+
}
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
if IsDebugMode() {
|
| 574 |
+
log.Printf("[Anthropic] thinking signature过期,尝试转换assistant消息为user消息重试")
|
| 575 |
+
} else {
|
| 576 |
+
log.Printf("[Anthropic] thinking signature过期,尝试转换assistant消息为user消息重试 model:%s thinking:%s", reqInfo.Model, thinkingStatus)
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
// 转换请求体:将assistant消息转换为user消息
|
| 580 |
+
fixedBody, fixErr := s.convertAssistantMessagesToUser(modifiedBody)
|
| 581 |
+
if fixErr == nil {
|
| 582 |
+
return s.makeRequest(ctx, fixedBody, account, zenModel)
|
| 583 |
+
} else {
|
| 584 |
+
log.Printf("[Anthropic] 转换assistant消息失败: %v", fixErr)
|
| 585 |
+
}
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
// 检查是否是参数冲突错误(temperature 和 top_p 不能同时指定)
|
| 589 |
+
if s.isParameterConflictError(errorBody) {
|
| 590 |
+
DebugLogRequestSent(ctx, "Anthropic", "Retrying with only temperature parameter")
|
| 591 |
+
|
| 592 |
+
// 移除 top_p 参数,只保留 temperature
|
| 593 |
+
fixedBody, fixErr := s.removeTopP(modifiedBody)
|
| 594 |
+
if fixErr == nil {
|
| 595 |
+
return s.makeRequest(ctx, fixedBody, account, zenModel)
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
// 检查是否是温度参数错误
|
| 600 |
+
if s.isTemperatureError(errorBody) {
|
| 601 |
+
DebugLogRequestSent(ctx, "Anthropic", "Retrying with temperature=1.0")
|
| 602 |
+
|
| 603 |
+
// 强制设置温度为1.0并重试
|
| 604 |
+
fixedBody, fixErr := s.forceTemperature(modifiedBody, 1.0)
|
| 605 |
+
if fixErr == nil {
|
| 606 |
+
return s.makeRequest(ctx, fixedBody, account, zenModel)
|
| 607 |
+
}
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
// 如果不是thinking相关的可修复错误,返回原始响应
|
| 613 |
+
return &http.Response{
|
| 614 |
+
StatusCode: resp.StatusCode,
|
| 615 |
+
Header: resp.Header,
|
| 616 |
+
Body: io.NopCloser(bytes.NewReader(bodyBytes)),
|
| 617 |
+
}, nil
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
return resp, nil
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
func (s *AnthropicService) makeRequest(ctx context.Context, body []byte, account *model.Account, zenModel model.ZenModel) (*http.Response, error) {
|
| 624 |
+
httpReq, err := http.NewRequest("POST", AnthropicBaseURL+"/v1/messages", bytes.NewReader(body))
|
| 625 |
+
if err != nil {
|
| 626 |
+
return nil, err
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
// 设置Zencoder自定义请求头
|
| 630 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 631 |
+
|
| 632 |
+
// Anthropic特有请求头
|
| 633 |
+
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
| 634 |
+
|
| 635 |
+
// 添加模型配置的额外请求头
|
| 636 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 637 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 638 |
+
httpReq.Header.Set(k, v)
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
// 只在非限速测试且调试模式下记录请求头
|
| 643 |
+
if IsDebugMode() {
|
| 644 |
+
// 检查请求体中的模型以判断是否为限速测试
|
| 645 |
+
var reqCheck struct {
|
| 646 |
+
Model string `json:"model"`
|
| 647 |
+
}
|
| 648 |
+
if json.Unmarshal(body, &reqCheck) == nil && !strings.Contains(reqCheck.Model, "test") {
|
| 649 |
+
DebugLogRequestHeaders(ctx, "Anthropic", httpReq.Header)
|
| 650 |
+
}
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
httpClient := provider.NewHTTPClient(account.Proxy, 0)
|
| 654 |
+
resp, err := httpClient.Do(httpReq)
|
| 655 |
+
if err != nil {
|
| 656 |
+
return nil, err
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
// 不输出响应头调试信息以减少日志量
|
| 660 |
+
|
| 661 |
+
// 如果是400错误,记录详细的请求信息
|
| 662 |
+
if resp.StatusCode == 400 {
|
| 663 |
+
// 读取错误响应内容
|
| 664 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 665 |
+
resp.Body.Close()
|
| 666 |
+
|
| 667 |
+
// 检查是否是"prompt is too long"错误
|
| 668 |
+
isPromptTooLongError := false
|
| 669 |
+
// 检查是否是thinking格式错误(将在doRequest中处理并重试)
|
| 670 |
+
isThinkingFormatError := false
|
| 671 |
+
// 检查是否是thinking signature过期错误(将在doRequest中处理并重试)
|
| 672 |
+
isThinkingSignatureError := false
|
| 673 |
+
var errResp struct {
|
| 674 |
+
Error struct {
|
| 675 |
+
Type string `json:"type"`
|
| 676 |
+
Message string `json:"message"`
|
| 677 |
+
} `json:"error"`
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
if err := json.Unmarshal(errBody, &errResp); err == nil {
|
| 681 |
+
errorMessage := strings.ToLower(errResp.Error.Message)
|
| 682 |
+
if strings.Contains(errorMessage, "prompt is too long") {
|
| 683 |
+
isPromptTooLongError = true
|
| 684 |
+
// 对于prompt过长错误,只输出简单的错误信息
|
| 685 |
+
log.Printf("[Anthropic] 400错误: %s - %s", errResp.Error.Type, errResp.Error.Message)
|
| 686 |
+
}
|
| 687 |
+
// 检查是否是thinking格式错误
|
| 688 |
+
if strings.Contains(errResp.Error.Message, "When `thinking` is enabled") ||
|
| 689 |
+
strings.Contains(errResp.Error.Message, "Expected `thinking` or `redacted_thinking`") {
|
| 690 |
+
isThinkingFormatError = true
|
| 691 |
+
// 输出详细的thinking格式错误信息
|
| 692 |
+
log.Printf("[Anthropic] thinking格式错误详情: %s", errResp.Error.Message)
|
| 693 |
+
log.Printf("[Anthropic] 发送给zencoder的请求体:")
|
| 694 |
+
log.Printf("%s", sanitizeRequestBody(body))
|
| 695 |
+
}
|
| 696 |
+
// 检查是否是thinking signature过期错误
|
| 697 |
+
if strings.Contains(errResp.Error.Message, "Invalid `signature` in `thinking` block") {
|
| 698 |
+
isThinkingSignatureError = true
|
| 699 |
+
// 对于thinking signature过期错误,只输出简单信息,详细处理留给doRequest
|
| 700 |
+
}
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
// 只在非调试模式且非已知可重试错误时才输出详细debug信息
|
| 704 |
+
// thinking相关错误会在doRequest中处理,如果重试成功就不需要输出debug日志
|
| 705 |
+
shouldOutputDetails := !isPromptTooLongError && !isThinkingFormatError && !isThinkingSignatureError
|
| 706 |
+
if shouldOutputDetails {
|
| 707 |
+
log.Printf("[Anthropic] API返回400错误: %s", string(errBody))
|
| 708 |
+
// 只在调试模式下输出详细的请求信息
|
| 709 |
+
if IsDebugMode() {
|
| 710 |
+
logRequestDetails("[Anthropic] 实际API", httpReq.Header, body)
|
| 711 |
+
}
|
| 712 |
+
} else if isThinkingSignatureError && IsDebugMode() {
|
| 713 |
+
// thinking signature错误只在调试模式下输出简单信息
|
| 714 |
+
log.Printf("[Anthropic] API返回400错误: %s", string(errBody))
|
| 715 |
+
logRequestDetails("[Anthropic] 实际API", httpReq.Header, body)
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
// 重新构建响应,因为body已经被读取
|
| 719 |
+
resp.Body = io.NopCloser(bytes.NewReader(errBody))
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
return resp, nil
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
// isThinkingFormatError 检查是否是thinking格式相关的错误
|
| 726 |
+
func (s *AnthropicService) isThinkingFormatError(errorBody string) bool {
|
| 727 |
+
return strings.Contains(errorBody, "When `thinking` is enabled, a final `assistant` message must start with a thinking block") ||
|
| 728 |
+
strings.Contains(errorBody, "Expected `thinking` or `redacted_thinking`") ||
|
| 729 |
+
strings.Contains(errorBody, "To avoid this requirement, disable `thinking`")
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
// isThinkingSignatureError 检查是否是thinking signature过期错误
|
| 733 |
+
func (s *AnthropicService) isThinkingSignatureError(errorBody string) bool {
|
| 734 |
+
return strings.Contains(errorBody, "Invalid `signature` in `thinking` block") ||
|
| 735 |
+
strings.Contains(errorBody, "invalid_request_error") && strings.Contains(errorBody, "signature")
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
// isTemperatureError 检查是否是温度参数相关的错误
|
| 739 |
+
func (s *AnthropicService) isTemperatureError(errorBody string) bool {
|
| 740 |
+
return strings.Contains(errorBody, "requires temperature=1.0") ||
|
| 741 |
+
strings.Contains(errorBody, "Parallel Thinking' requires temperature")
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
// isParameterConflictError 检查是否是参数冲突错误
|
| 745 |
+
func (s *AnthropicService) isParameterConflictError(errorBody string) bool {
|
| 746 |
+
return strings.Contains(errorBody, "`temperature` and `top_p` cannot both be specified")
|
| 747 |
+
}
|
| 748 |
+
|
| 749 |
+
// isClaudeOfficial429Error 检查是否是Claude官方的429限流错误
|
| 750 |
+
func (s *AnthropicService) isClaudeOfficial429Error(errorBody string) bool {
|
| 751 |
+
// 尝试解析错误响应
|
| 752 |
+
var errResp struct {
|
| 753 |
+
Type string `json:"type"`
|
| 754 |
+
Error struct {
|
| 755 |
+
Type string `json:"type"`
|
| 756 |
+
Message string `json:"message"`
|
| 757 |
+
} `json:"error"`
|
| 758 |
+
RequestID string `json:"request_id"`
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
// 如果能解析成功且符合Claude官方格式
|
| 762 |
+
if err := json.Unmarshal([]byte(errorBody), &errResp); err == nil {
|
| 763 |
+
// Claude官方错误特征:
|
| 764 |
+
// 1. type = "error"
|
| 765 |
+
// 2. error.type = "rate_limit_error"
|
| 766 |
+
// 3. 错误消息包含anthropic.com或claude.com域名
|
| 767 |
+
if errResp.Type == "error" &&
|
| 768 |
+
errResp.Error.Type == "rate_limit_error" &&
|
| 769 |
+
(strings.Contains(errResp.Error.Message, "anthropic.com") ||
|
| 770 |
+
strings.Contains(errResp.Error.Message, "claude.com") ||
|
| 771 |
+
strings.Contains(errResp.Error.Message, "docs.claude.com")) {
|
| 772 |
+
return true
|
| 773 |
+
}
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
// 检查是否是非Claude官方的错误格式(如Google API格式)
|
| 777 |
+
var nonClaudeErr struct {
|
| 778 |
+
Error struct {
|
| 779 |
+
Code int `json:"code"`
|
| 780 |
+
Message string `json:"message"`
|
| 781 |
+
Status string `json:"status"`
|
| 782 |
+
} `json:"error"`
|
| 783 |
+
}
|
| 784 |
+
|
| 785 |
+
if err := json.Unmarshal([]byte(errorBody), &nonClaudeErr); err == nil {
|
| 786 |
+
// 非Claude官方错误特征:有code和status字段
|
| 787 |
+
if nonClaudeErr.Error.Code == 429 &&
|
| 788 |
+
nonClaudeErr.Error.Status == "RESOURCE_EXHAUSTED" {
|
| 789 |
+
return false
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
// 默认情况下,如果无法确定,保守处理:不返回原始响应
|
| 794 |
+
return false
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
// classifyAndLog429Error 分类并记录429错误的简化日志
|
| 798 |
+
func (s *AnthropicService) classifyAndLog429Error(errorBody string, accountID uint, email string) {
|
| 799 |
+
// 尝试解析Claude官方错误
|
| 800 |
+
var claudeErr struct {
|
| 801 |
+
Type string `json:"type"`
|
| 802 |
+
Error struct {
|
| 803 |
+
Type string `json:"type"`
|
| 804 |
+
Message string `json:"message"`
|
| 805 |
+
} `json:"error"`
|
| 806 |
+
}
|
| 807 |
+
|
| 808 |
+
if err := json.Unmarshal([]byte(errorBody), &claudeErr); err == nil {
|
| 809 |
+
if claudeErr.Type == "error" && claudeErr.Error.Type == "rate_limit_error" {
|
| 810 |
+
// Claude官方限流错误
|
| 811 |
+
log.Printf("[Anthropic] Claude rate_limit_error 账号ID:%d %s", accountID, email)
|
| 812 |
+
return
|
| 813 |
+
}
|
| 814 |
+
}
|
| 815 |
+
|
| 816 |
+
// 尝试解析GCP错误
|
| 817 |
+
var gcpErr struct {
|
| 818 |
+
Error struct {
|
| 819 |
+
Code int `json:"code"`
|
| 820 |
+
Message string `json:"message"`
|
| 821 |
+
Status string `json:"status"`
|
| 822 |
+
} `json:"error"`
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
if err := json.Unmarshal([]byte(errorBody), &gcpErr); err == nil {
|
| 826 |
+
if gcpErr.Error.Code == 429 && gcpErr.Error.Status == "RESOURCE_EXHAUSTED" {
|
| 827 |
+
// GCP限流错误
|
| 828 |
+
log.Printf("[Anthropic] GCP RESOURCE_EXHAUSTED 账号ID:%d %s", accountID, email)
|
| 829 |
+
return
|
| 830 |
+
}
|
| 831 |
+
}
|
| 832 |
+
|
| 833 |
+
// 其他未识别的429错误
|
| 834 |
+
log.Printf("[Anthropic] 429限流错误 账号ID:%d %s", accountID, email)
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
// MessagesProxy 直接代理请求和响应
|
| 838 |
+
func (s *AnthropicService) MessagesProxy(ctx context.Context, w http.ResponseWriter, body []byte) error {
|
| 839 |
+
var req struct {
|
| 840 |
+
Model string `json:"model"`
|
| 841 |
+
Stream bool `json:"stream"`
|
| 842 |
+
}
|
| 843 |
+
// 忽略错误,Messages方法会再次解析
|
| 844 |
+
_ = json.Unmarshal(body, &req)
|
| 845 |
+
|
| 846 |
+
resp, err := s.Messages(ctx, body, false)
|
| 847 |
+
if err != nil {
|
| 848 |
+
return err
|
| 849 |
+
}
|
| 850 |
+
defer resp.Body.Close()
|
| 851 |
+
|
| 852 |
+
// 判断是否需要过滤thinking内容
|
| 853 |
+
// 规则:如果用户调用的是非thinking版本,但平台强制开启了thinking,则需要过滤
|
| 854 |
+
needsFiltering := false
|
| 855 |
+
|
| 856 |
+
// 获取模型配置
|
| 857 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 858 |
+
|
| 859 |
+
// 如果模型配置中有thinking参数(平台强制thinking)
|
| 860 |
+
if exists && zenModel.Parameters != nil && zenModel.Parameters.Thinking != nil {
|
| 861 |
+
// 检查用户是否明确请求了thinking版本
|
| 862 |
+
// 如果模型ID不包含 "thinking" 后缀,说明用户要的是非thinking版本
|
| 863 |
+
if !strings.HasSuffix(req.Model, "-thinking") {
|
| 864 |
+
needsFiltering = true
|
| 865 |
+
}
|
| 866 |
+
}
|
| 867 |
+
|
| 868 |
+
if needsFiltering {
|
| 869 |
+
if req.Stream {
|
| 870 |
+
return s.streamFilteredResponse(w, resp)
|
| 871 |
+
}
|
| 872 |
+
return s.handleNonStreamFilteredResponse(w, resp)
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
return StreamResponse(w, resp)
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
func (s *AnthropicService) handleNonStreamFilteredResponse(w http.ResponseWriter, resp *http.Response) error {
|
| 879 |
+
// 读取全部响应体
|
| 880 |
+
bodyBytes, err := io.ReadAll(resp.Body)
|
| 881 |
+
if err != nil {
|
| 882 |
+
return err
|
| 883 |
+
}
|
| 884 |
+
|
| 885 |
+
// 复制响应头
|
| 886 |
+
for k, v := range resp.Header {
|
| 887 |
+
// 过滤掉 Content-Length 和 Content-Encoding
|
| 888 |
+
if k != "Content-Length" && k != "Content-Encoding" {
|
| 889 |
+
for _, vv := range v {
|
| 890 |
+
w.Header().Add(k, vv)
|
| 891 |
+
}
|
| 892 |
+
}
|
| 893 |
+
}
|
| 894 |
+
w.WriteHeader(resp.StatusCode)
|
| 895 |
+
|
| 896 |
+
// 尝试解析响应
|
| 897 |
+
var raw map[string]interface{}
|
| 898 |
+
if err := json.Unmarshal(bodyBytes, &raw); err != nil {
|
| 899 |
+
w.Write(bodyBytes)
|
| 900 |
+
return nil
|
| 901 |
+
}
|
| 902 |
+
|
| 903 |
+
// 过滤 content 中的 thinking block
|
| 904 |
+
if content, ok := raw["content"].([]interface{}); ok {
|
| 905 |
+
var newContent []interface{}
|
| 906 |
+
for _, block := range content {
|
| 907 |
+
if b, ok := block.(map[string]interface{}); ok {
|
| 908 |
+
if typeStr, ok := b["type"].(string); ok && (typeStr == "thinking" || typeStr == "thought") {
|
| 909 |
+
continue
|
| 910 |
+
}
|
| 911 |
+
}
|
| 912 |
+
newContent = append(newContent, block)
|
| 913 |
+
}
|
| 914 |
+
raw["content"] = newContent
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
return json.NewEncoder(w).Encode(raw)
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
// adjustTemperatureForModel 根据模型要求调整温度参数
|
| 921 |
+
func (s *AnthropicService) adjustTemperatureForModel(body []byte, modelID string) ([]byte, error) {
|
| 922 |
+
// 获取模型配置
|
| 923 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 924 |
+
|
| 925 |
+
// 检查模型配置中是否有特定的温度要求
|
| 926 |
+
if exists && zenModel.Parameters != nil && zenModel.Parameters.Temperature != nil {
|
| 927 |
+
return s.forceTemperature(body, *zenModel.Parameters.Temperature)
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
return body, nil
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
// forceTemperature 强制设置温度参数
|
| 934 |
+
func (s *AnthropicService) forceTemperature(body []byte, temperature float64) ([]byte, error) {
|
| 935 |
+
// 解析请求体
|
| 936 |
+
var reqMap map[string]interface{}
|
| 937 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 938 |
+
return body, nil // 如果解析失败,返回原始body
|
| 939 |
+
}
|
| 940 |
+
|
| 941 |
+
// 强制设置 temperature
|
| 942 |
+
reqMap["temperature"] = temperature
|
| 943 |
+
|
| 944 |
+
// 如果同时存在 top_p,移除它(某些模型不允许同时指定)
|
| 945 |
+
delete(reqMap, "top_p")
|
| 946 |
+
|
| 947 |
+
// 重新序列化
|
| 948 |
+
return json.Marshal(reqMap)
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
// removeTopP 移除 top_p 参数,避免与 temperature 冲突
|
| 952 |
+
func (s *AnthropicService) removeTopP(body []byte) ([]byte, error) {
|
| 953 |
+
// 解析请求体
|
| 954 |
+
var reqMap map[string]interface{}
|
| 955 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 956 |
+
return body, nil // 如果解析失败,返回原始body
|
| 957 |
+
}
|
| 958 |
+
|
| 959 |
+
// 移除 top_p 参数
|
| 960 |
+
delete(reqMap, "top_p")
|
| 961 |
+
|
| 962 |
+
// 重新序列化
|
| 963 |
+
return json.Marshal(reqMap)
|
| 964 |
+
}
|
| 965 |
+
|
| 966 |
+
// hasMatchingToolResult 检查消息中是否包含指定tool_use_id的tool_result
|
| 967 |
+
func hasMatchingToolResult(msg map[string]interface{}, toolUseID interface{}) bool {
|
| 968 |
+
if msg == nil || toolUseID == nil {
|
| 969 |
+
return false
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
toolUseIDStr, ok := toolUseID.(string)
|
| 973 |
+
if !ok {
|
| 974 |
+
return false
|
| 975 |
+
}
|
| 976 |
+
|
| 977 |
+
content, ok := msg["content"].([]interface{})
|
| 978 |
+
if !ok {
|
| 979 |
+
return false
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
for _, block := range content {
|
| 983 |
+
if b, ok := block.(map[string]interface{}); ok {
|
| 984 |
+
if b["type"] == "tool_result" {
|
| 985 |
+
if id, ok := b["tool_use_id"].(string); ok && id == toolUseIDStr {
|
| 986 |
+
return true
|
| 987 |
+
}
|
| 988 |
+
}
|
| 989 |
+
}
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
return false
|
| 993 |
+
}
|
| 994 |
+
|
| 995 |
+
// ensureThinkingConfig 确保需要 thinking 的模型有正确的配置
|
| 996 |
+
func (s *AnthropicService) ensureThinkingConfig(body []byte, modelID string) ([]byte, error) {
|
| 997 |
+
// 获取模型配置
|
| 998 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 999 |
+
|
| 1000 |
+
// 检查模型配置中是否包含thinking参数
|
| 1001 |
+
needsThinking := false
|
| 1002 |
+
var modelBudgetTokens int
|
| 1003 |
+
if exists && zenModel.Parameters != nil && zenModel.Parameters.Thinking != nil {
|
| 1004 |
+
needsThinking = true
|
| 1005 |
+
modelBudgetTokens = zenModel.Parameters.Thinking.BudgetTokens
|
| 1006 |
+
if modelBudgetTokens == 0 {
|
| 1007 |
+
modelBudgetTokens = 4096 // 默认值
|
| 1008 |
+
}
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
+
if !needsThinking {
|
| 1012 |
+
return body, nil
|
| 1013 |
+
}
|
| 1014 |
+
|
| 1015 |
+
// 解析请求体
|
| 1016 |
+
var reqMap map[string]interface{}
|
| 1017 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 1018 |
+
return body, nil
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
// 检查用户是否明确不想要thinking模式
|
| 1022 |
+
userDisablesThinking := false
|
| 1023 |
+
if existingThinking, ok := reqMap["thinking"].(map[string]interface{}); ok {
|
| 1024 |
+
if thinkingType, ok := existingThinking["type"].(string); ok && thinkingType == "disabled" {
|
| 1025 |
+
userDisablesThinking = true
|
| 1026 |
+
}
|
| 1027 |
+
if enabled, ok := existingThinking["enabled"].(bool); ok && !enabled {
|
| 1028 |
+
userDisablesThinking = true
|
| 1029 |
+
}
|
| 1030 |
+
} else {
|
| 1031 |
+
// 如果没有thinking配置,检查是否是非thinking版本的模型调用
|
| 1032 |
+
// 例如 claude-haiku-4-5-20251001 而不是 claude-haiku-4-5-20251001-thinking
|
| 1033 |
+
if !strings.HasSuffix(modelID, "-thinking") {
|
| 1034 |
+
userDisablesThinking = true
|
| 1035 |
+
}
|
| 1036 |
+
}
|
| 1037 |
+
|
| 1038 |
+
// 如果用户不想要thinking但模型强制thinking,转换assistant消息为user消息
|
| 1039 |
+
if userDisablesThinking {
|
| 1040 |
+
if IsDebugMode() {
|
| 1041 |
+
log.Printf("[Anthropic] 用户不想要thinking模式,但模型强制thinking,转换assistant消息为user消息")
|
| 1042 |
+
}
|
| 1043 |
+
if messages, ok := reqMap["messages"].([]interface{}); ok {
|
| 1044 |
+
for i, msg := range messages {
|
| 1045 |
+
if msgMap, ok := msg.(map[string]interface{}); ok {
|
| 1046 |
+
if role, ok := msgMap["role"].(string); ok && role == "assistant" {
|
| 1047 |
+
// 转换thinking内容为text并改变角色为user
|
| 1048 |
+
if err := s.convertAssistantToUserMessage(msgMap); err != nil {
|
| 1049 |
+
log.Printf("[Anthropic] 转换assistant消息为user消息失败: %v", err)
|
| 1050 |
+
}
|
| 1051 |
+
}
|
| 1052 |
+
messages[i] = msgMap
|
| 1053 |
+
}
|
| 1054 |
+
}
|
| 1055 |
+
reqMap["messages"] = messages
|
| 1056 |
+
}
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
// 注意:即使有tool_choice,某些模型仍然需要thinking配置
|
| 1060 |
+
// 因此不再因为tool_choice的存在而跳过thinking配置
|
| 1061 |
+
|
| 1062 |
+
// 检查请求体中是否已有thinking配置
|
| 1063 |
+
if existingThinking, ok := reqMap["thinking"].(map[string]interface{}); ok {
|
| 1064 |
+
// 如果已有thinking配置,确保budget_tokens与模型配置一致
|
| 1065 |
+
if _, hasBudget := existingThinking["budget_tokens"]; hasBudget {
|
| 1066 |
+
// 强制使用模型配置中的budget_tokens值
|
| 1067 |
+
existingThinking["budget_tokens"] = modelBudgetTokens
|
| 1068 |
+
if IsDebugMode() {
|
| 1069 |
+
log.Printf("[Anthropic] 调整thinking.budget_tokens为模型配置值: %d", modelBudgetTokens)
|
| 1070 |
+
}
|
| 1071 |
+
} else {
|
| 1072 |
+
// 如果没有budget_tokens,添加
|
| 1073 |
+
existingThinking["budget_tokens"] = modelBudgetTokens
|
| 1074 |
+
}
|
| 1075 |
+
// 确保type字段正确
|
| 1076 |
+
if _, hasType := existingThinking["type"]; !hasType {
|
| 1077 |
+
existingThinking["type"] = "enabled"
|
| 1078 |
+
} else {
|
| 1079 |
+
// 强制启用thinking(因为模型要求)
|
| 1080 |
+
existingThinking["type"] = "enabled"
|
| 1081 |
+
}
|
| 1082 |
+
reqMap["thinking"] = existingThinking
|
| 1083 |
+
} else {
|
| 1084 |
+
// 添加 thinking 配置 - 使用模型配置中的值
|
| 1085 |
+
reqMap["thinking"] = map[string]interface{}{
|
| 1086 |
+
"type": "enabled",
|
| 1087 |
+
"budget_tokens": modelBudgetTokens,
|
| 1088 |
+
}
|
| 1089 |
+
if IsDebugMode() {
|
| 1090 |
+
log.Printf("[Anthropic] 添加thinking配置,budget_tokens: %d", modelBudgetTokens)
|
| 1091 |
+
}
|
| 1092 |
+
if IsDebugMode() {
|
| 1093 |
+
log.Printf("[Anthropic] 原始请求体 (处理前):")
|
| 1094 |
+
log.Printf("%s", sanitizeRequestBody(body))
|
| 1095 |
+
}
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
// 当启用 thinking 时,必须设置 temperature = 1.0
|
| 1099 |
+
reqMap["temperature"] = 1.0
|
| 1100 |
+
// 移除 top_p 以避免冲突
|
| 1101 |
+
delete(reqMap, "top_p")
|
| 1102 |
+
|
| 1103 |
+
// 注意:不再尝试为assistant消息添加thinking块,因为signature信息无法正确生成
|
| 1104 |
+
// 如果模型要求thinking模式但用户消息不符合格式,让API返回错误由上层处理
|
| 1105 |
+
|
| 1106 |
+
// 重新序列化
|
| 1107 |
+
modifiedBody, err := json.Marshal(reqMap)
|
| 1108 |
+
if err != nil {
|
| 1109 |
+
return body, err
|
| 1110 |
+
}
|
| 1111 |
+
|
| 1112 |
+
// 输出处理后的请求体日志
|
| 1113 |
+
if IsDebugMode() {
|
| 1114 |
+
log.Printf("[Anthropic] 处理后的请求体 (发送给实际API):")
|
| 1115 |
+
log.Printf("%s", sanitizeRequestBody(modifiedBody))
|
| 1116 |
+
}
|
| 1117 |
+
|
| 1118 |
+
return modifiedBody, nil
|
| 1119 |
+
}
|
| 1120 |
+
|
| 1121 |
+
// 已移除fixAssistantMessageForThinking函数,因为signature信息无法正确生成
|
| 1122 |
+
|
| 1123 |
+
// convertThinkingToText 将thinking内容转换为普通文本格式(当用户不想要thinking模式时)
|
| 1124 |
+
func (s *AnthropicService) convertThinkingToText(msgMap map[string]interface{}) error {
|
| 1125 |
+
content, ok := msgMap["content"]
|
| 1126 |
+
if !ok {
|
| 1127 |
+
return nil
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
switch c := content.(type) {
|
| 1131 |
+
case []interface{}:
|
| 1132 |
+
var newContent []interface{}
|
| 1133 |
+
for _, block := range c {
|
| 1134 |
+
if blockMap, ok := block.(map[string]interface{}); ok {
|
| 1135 |
+
blockType, _ := blockMap["type"].(string)
|
| 1136 |
+
if blockType == "thinking" || blockType == "redacted_thinking" {
|
| 1137 |
+
// 将thinking块转换为text块
|
| 1138 |
+
if thinkingText, ok := blockMap["thinking"].(string); ok {
|
| 1139 |
+
newContent = append(newContent, map[string]interface{}{
|
| 1140 |
+
"type": "text",
|
| 1141 |
+
"text": "[thinking] " + thinkingText,
|
| 1142 |
+
})
|
| 1143 |
+
}
|
| 1144 |
+
} else {
|
| 1145 |
+
// 保留其他类型的块
|
| 1146 |
+
newContent = append(newContent, block)
|
| 1147 |
+
}
|
| 1148 |
+
}
|
| 1149 |
+
}
|
| 1150 |
+
msgMap["content"] = newContent
|
| 1151 |
+
if IsDebugMode() {
|
| 1152 |
+
log.Printf("[Anthropic] 将thinking块转换为普通文本格式")
|
| 1153 |
+
}
|
| 1154 |
+
}
|
| 1155 |
+
|
| 1156 |
+
return nil
|
| 1157 |
+
}
|
| 1158 |
+
|
| 1159 |
+
// convertAssistantToUserMessage 将assistant消息转换为user消息,避免thinking格式要求
|
| 1160 |
+
// 使用range循环逐个处理块,保留缓存信息,不合并消息
|
| 1161 |
+
func (s *AnthropicService) convertAssistantToUserMessage(msgMap map[string]interface{}) error {
|
| 1162 |
+
content, ok := msgMap["content"]
|
| 1163 |
+
if !ok {
|
| 1164 |
+
return nil
|
| 1165 |
+
}
|
| 1166 |
+
|
| 1167 |
+
// 将角色从assistant改为user
|
| 1168 |
+
msgMap["role"] = "user"
|
| 1169 |
+
|
| 1170 |
+
switch c := content.(type) {
|
| 1171 |
+
case string:
|
| 1172 |
+
// 如果是字符串content,保持不变,只改角色
|
| 1173 |
+
if IsDebugMode() {
|
| 1174 |
+
log.Printf("[Anthropic] 将assistant字符串消息转换为user消息")
|
| 1175 |
+
}
|
| 1176 |
+
case []interface{}:
|
| 1177 |
+
// 使用range循环逐个处理每个块,保留结构和缓存信息
|
| 1178 |
+
for i, block := range c {
|
| 1179 |
+
if blockMap, ok := block.(map[string]interface{}); ok {
|
| 1180 |
+
blockType, _ := blockMap["type"].(string)
|
| 1181 |
+
|
| 1182 |
+
// 保留原有的缓存控制信息
|
| 1183 |
+
var cacheControl interface{}
|
| 1184 |
+
if cache, hasCacheControl := blockMap["cache_control"]; hasCacheControl {
|
| 1185 |
+
cacheControl = cache
|
| 1186 |
+
}
|
| 1187 |
+
|
| 1188 |
+
switch blockType {
|
| 1189 |
+
case "thinking", "redacted_thinking":
|
| 1190 |
+
// 将thinking块转换为text块,保留缓存信息
|
| 1191 |
+
if thinkingText, ok := blockMap["thinking"].(string); ok {
|
| 1192 |
+
newBlock := map[string]interface{}{
|
| 1193 |
+
"type": "text",
|
| 1194 |
+
"text": "[thinking] " + thinkingText,
|
| 1195 |
+
}
|
| 1196 |
+
if cacheControl != nil {
|
| 1197 |
+
newBlock["cache_control"] = cacheControl
|
| 1198 |
+
}
|
| 1199 |
+
c[i] = newBlock
|
| 1200 |
+
}
|
| 1201 |
+
case "tool_use":
|
| 1202 |
+
// 将tool_use块转换为text描述,保留缓存信息
|
| 1203 |
+
toolName, _ := blockMap["name"].(string)
|
| 1204 |
+
toolId, _ := blockMap["id"].(string)
|
| 1205 |
+
newBlock := map[string]interface{}{
|
| 1206 |
+
"type": "text",
|
| 1207 |
+
"text": fmt.Sprintf("[tool_use] %s (ID: %s)", toolName, toolId),
|
| 1208 |
+
}
|
| 1209 |
+
if cacheControl != nil {
|
| 1210 |
+
newBlock["cache_control"] = cacheControl
|
| 1211 |
+
}
|
| 1212 |
+
c[i] = newBlock
|
| 1213 |
+
case "tool_result":
|
| 1214 |
+
// 将tool_result块转换为text描述,保留缓存信息
|
| 1215 |
+
toolUseId, _ := blockMap["tool_use_id"].(string)
|
| 1216 |
+
isError, _ := blockMap["is_error"].(bool)
|
| 1217 |
+
var resultText string
|
| 1218 |
+
if isError {
|
| 1219 |
+
resultText = fmt.Sprintf("[tool_error] (ID: %s)", toolUseId)
|
| 1220 |
+
} else {
|
| 1221 |
+
resultText = fmt.Sprintf("[tool_result] (ID: %s)", toolUseId)
|
| 1222 |
+
}
|
| 1223 |
+
newBlock := map[string]interface{}{
|
| 1224 |
+
"type": "text",
|
| 1225 |
+
"text": resultText,
|
| 1226 |
+
}
|
| 1227 |
+
if cacheControl != nil {
|
| 1228 |
+
newBlock["cache_control"] = cacheControl
|
| 1229 |
+
}
|
| 1230 |
+
c[i] = newBlock
|
| 1231 |
+
default:
|
| 1232 |
+
// text块和其他类型的块保持不变,包括缓存信息
|
| 1233 |
+
// 不需要修改,保持原样
|
| 1234 |
+
}
|
| 1235 |
+
}
|
| 1236 |
+
// 非map类型的块也保持不变
|
| 1237 |
+
}
|
| 1238 |
+
|
| 1239 |
+
msgMap["content"] = c
|
| 1240 |
+
if IsDebugMode() {
|
| 1241 |
+
log.Printf("[Anthropic] 将assistant消息转换为user消息,逐个处理内容块并保留缓存信息")
|
| 1242 |
+
}
|
| 1243 |
+
}
|
| 1244 |
+
|
| 1245 |
+
return nil
|
| 1246 |
+
}
|
| 1247 |
+
|
| 1248 |
+
// convertAssistantMessagesToUser 将请求体中的所有assistant消息转换为user消息
|
| 1249 |
+
func (s *AnthropicService) convertAssistantMessagesToUser(body []byte) ([]byte, error) {
|
| 1250 |
+
// 解析请求体
|
| 1251 |
+
var reqMap map[string]interface{}
|
| 1252 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 1253 |
+
return body, err
|
| 1254 |
+
}
|
| 1255 |
+
|
| 1256 |
+
// 处理messages数组,同时处理工具调用关系
|
| 1257 |
+
if messages, ok := reqMap["messages"].([]interface{}); ok {
|
| 1258 |
+
for i, msg := range messages {
|
| 1259 |
+
if msgMap, ok := msg.(map[string]interface{}); ok {
|
| 1260 |
+
// 无论是assistant还是user消息,都要检查并转换工具相关块
|
| 1261 |
+
if role, ok := msgMap["role"].(string); ok {
|
| 1262 |
+
if role == "assistant" {
|
| 1263 |
+
// 转换assistant消息为user消息
|
| 1264 |
+
if err := s.convertAssistantToUserMessage(msgMap); err != nil {
|
| 1265 |
+
log.Printf("[Anthropic] 转换第%d个assistant消息失败: %v", i, err)
|
| 1266 |
+
continue
|
| 1267 |
+
}
|
| 1268 |
+
} else if role == "user" {
|
| 1269 |
+
// 对于user消息,也要确保tool_result被正确处理
|
| 1270 |
+
if err := s.convertToolBlocksToText(msgMap); err != nil {
|
| 1271 |
+
log.Printf("[Anthropic] 转换第%d个user消息中的工具块失败: %v", i, err)
|
| 1272 |
+
continue
|
| 1273 |
+
}
|
| 1274 |
+
}
|
| 1275 |
+
messages[i] = msgMap
|
| 1276 |
+
}
|
| 1277 |
+
}
|
| 1278 |
+
}
|
| 1279 |
+
reqMap["messages"] = messages
|
| 1280 |
+
}
|
| 1281 |
+
|
| 1282 |
+
// 重新序列化
|
| 1283 |
+
modifiedBody, err := json.Marshal(reqMap)
|
| 1284 |
+
if err != nil {
|
| 1285 |
+
return body, err
|
| 1286 |
+
}
|
| 1287 |
+
|
| 1288 |
+
if IsDebugMode() {
|
| 1289 |
+
log.Printf("[Anthropic] 已转换所有工具调用消息,处理后的请求体:")
|
| 1290 |
+
log.Printf("%s", sanitizeRequestBody(modifiedBody))
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
return modifiedBody, nil
|
| 1294 |
+
}
|
| 1295 |
+
|
| 1296 |
+
// convertToolBlocksToText 将消息中的所有工具相关块转换为文本
|
| 1297 |
+
func (s *AnthropicService) convertToolBlocksToText(msgMap map[string]interface{}) error {
|
| 1298 |
+
content, ok := msgMap["content"]
|
| 1299 |
+
if !ok {
|
| 1300 |
+
return nil
|
| 1301 |
+
}
|
| 1302 |
+
|
| 1303 |
+
switch c := content.(type) {
|
| 1304 |
+
case []interface{}:
|
| 1305 |
+
// 使用range循环逐个处理每个块,将工具相关块转换为文本
|
| 1306 |
+
for i, block := range c {
|
| 1307 |
+
if blockMap, ok := block.(map[string]interface{}); ok {
|
| 1308 |
+
blockType, _ := blockMap["type"].(string)
|
| 1309 |
+
|
| 1310 |
+
// 保留原有的缓存控制信息
|
| 1311 |
+
var cacheControl interface{}
|
| 1312 |
+
if cache, hasCacheControl := blockMap["cache_control"]; hasCacheControl {
|
| 1313 |
+
cacheControl = cache
|
| 1314 |
+
}
|
| 1315 |
+
|
| 1316 |
+
switch blockType {
|
| 1317 |
+
case "tool_use":
|
| 1318 |
+
// 将tool_use块转换为text块
|
| 1319 |
+
toolName, _ := blockMap["name"].(string)
|
| 1320 |
+
toolId, _ := blockMap["id"].(string)
|
| 1321 |
+
newBlock := map[string]interface{}{
|
| 1322 |
+
"type": "text",
|
| 1323 |
+
"text": fmt.Sprintf("[tool_use] %s (ID: %s)", toolName, toolId),
|
| 1324 |
+
}
|
| 1325 |
+
if cacheControl != nil {
|
| 1326 |
+
newBlock["cache_control"] = cacheControl
|
| 1327 |
+
}
|
| 1328 |
+
c[i] = newBlock
|
| 1329 |
+
case "tool_result":
|
| 1330 |
+
// 将tool_result块转换为text块
|
| 1331 |
+
toolUseId, _ := blockMap["tool_use_id"].(string)
|
| 1332 |
+
isError, _ := blockMap["is_error"].(bool)
|
| 1333 |
+
var resultText string
|
| 1334 |
+
if isError {
|
| 1335 |
+
resultText = fmt.Sprintf("[tool_error] (ID: %s)", toolUseId)
|
| 1336 |
+
} else {
|
| 1337 |
+
resultText = fmt.Sprintf("[tool_result] (ID: %s)", toolUseId)
|
| 1338 |
+
}
|
| 1339 |
+
newBlock := map[string]interface{}{
|
| 1340 |
+
"type": "text",
|
| 1341 |
+
"text": resultText,
|
| 1342 |
+
}
|
| 1343 |
+
if cacheControl != nil {
|
| 1344 |
+
newBlock["cache_control"] = cacheControl
|
| 1345 |
+
}
|
| 1346 |
+
c[i] = newBlock
|
| 1347 |
+
default:
|
| 1348 |
+
// text块和其他类型的块保持不变
|
| 1349 |
+
// 不需要修改,保持原样
|
| 1350 |
+
}
|
| 1351 |
+
}
|
| 1352 |
+
}
|
| 1353 |
+
|
| 1354 |
+
msgMap["content"] = c
|
| 1355 |
+
if IsDebugMode() {
|
| 1356 |
+
log.Printf("[Anthropic] 已将消息中的工具块转换为文本格式")
|
| 1357 |
+
}
|
| 1358 |
+
}
|
| 1359 |
+
|
| 1360 |
+
return nil
|
| 1361 |
+
}
|
| 1362 |
+
|
| 1363 |
+
// adjustParametersForModel 根据模型要求调整参数,避免冲突
|
| 1364 |
+
func (s *AnthropicService) adjustParametersForModel(body []byte, modelID string) ([]byte, error) {
|
| 1365 |
+
// 对于 claude-opus-4-5-20251101 等模型,不能同时有 temperature 和 top_p
|
| 1366 |
+
modelsNoTopP := []string{
|
| 1367 |
+
"claude-opus-4-5-20251101",
|
| 1368 |
+
"claude-opus-4-1-20250805",
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
for _, model := range modelsNoTopP {
|
| 1372 |
+
if modelID == model {
|
| 1373 |
+
body, _ = s.removeTopP(body)
|
| 1374 |
+
break
|
| 1375 |
+
}
|
| 1376 |
+
}
|
| 1377 |
+
|
| 1378 |
+
// 继续处理温度参数
|
| 1379 |
+
return s.adjustTemperatureForModel(body, modelID)
|
| 1380 |
+
}
|
| 1381 |
+
|
| 1382 |
+
func (s *AnthropicService) streamFilteredResponse(w http.ResponseWriter, resp *http.Response) error {
|
| 1383 |
+
// 复制响应头
|
| 1384 |
+
for k, v := range resp.Header {
|
| 1385 |
+
if k != "Content-Encoding" && k != "Content-Length" {
|
| 1386 |
+
for _, vv := range v {
|
| 1387 |
+
w.Header().Add(k, vv)
|
| 1388 |
+
}
|
| 1389 |
+
}
|
| 1390 |
+
}
|
| 1391 |
+
w.WriteHeader(resp.StatusCode)
|
| 1392 |
+
|
| 1393 |
+
flusher, ok := w.(http.Flusher)
|
| 1394 |
+
if !ok {
|
| 1395 |
+
_, err := io.Copy(w, resp.Body)
|
| 1396 |
+
return err
|
| 1397 |
+
}
|
| 1398 |
+
|
| 1399 |
+
reader := bufio.NewReader(resp.Body)
|
| 1400 |
+
isThinking := false // 标记当前是否处于 thinking block 中
|
| 1401 |
+
|
| 1402 |
+
for {
|
| 1403 |
+
line, err := reader.ReadString('\n')
|
| 1404 |
+
if err != nil {
|
| 1405 |
+
if err == io.EOF {
|
| 1406 |
+
return nil
|
| 1407 |
+
}
|
| 1408 |
+
return err
|
| 1409 |
+
}
|
| 1410 |
+
|
| 1411 |
+
trimmedLine := strings.TrimSpace(line)
|
| 1412 |
+
if trimmedLine == "" {
|
| 1413 |
+
fmt.Fprintf(w, "\n")
|
| 1414 |
+
flusher.Flush()
|
| 1415 |
+
continue
|
| 1416 |
+
}
|
| 1417 |
+
|
| 1418 |
+
if strings.HasPrefix(trimmedLine, "event:") {
|
| 1419 |
+
// 读取下一行 data
|
| 1420 |
+
dataLine, err := reader.ReadString('\n')
|
| 1421 |
+
if err != nil {
|
| 1422 |
+
return err
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
// 解析 event 类型
|
| 1426 |
+
event := strings.TrimSpace(strings.TrimPrefix(trimmedLine, "event:"))
|
| 1427 |
+
data := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:"))
|
| 1428 |
+
|
| 1429 |
+
var shouldFilter bool
|
| 1430 |
+
|
| 1431 |
+
if event == "content_block_start" {
|
| 1432 |
+
var payload struct {
|
| 1433 |
+
ContentBlock struct {
|
| 1434 |
+
Type string `json:"type"`
|
| 1435 |
+
} `json:"content_block"`
|
| 1436 |
+
}
|
| 1437 |
+
if json.Unmarshal([]byte(data), &payload) == nil {
|
| 1438 |
+
if payload.ContentBlock.Type == "thinking" || payload.ContentBlock.Type == "thought" {
|
| 1439 |
+
isThinking = true
|
| 1440 |
+
shouldFilter = true
|
| 1441 |
+
}
|
| 1442 |
+
|
| 1443 |
+
}
|
| 1444 |
+
} else if event == "content_block_delta" {
|
| 1445 |
+
if isThinking {
|
| 1446 |
+
shouldFilter = true
|
| 1447 |
+
}
|
| 1448 |
+
} else if event == "content_block_stop" {
|
| 1449 |
+
if isThinking {
|
| 1450 |
+
shouldFilter = true
|
| 1451 |
+
isThinking = false
|
| 1452 |
+
}
|
| 1453 |
+
}
|
| 1454 |
+
|
| 1455 |
+
if !shouldFilter {
|
| 1456 |
+
fmt.Fprint(w, line) // event: ...
|
| 1457 |
+
fmt.Fprint(w, dataLine) // data: ...
|
| 1458 |
+
flusher.Flush()
|
| 1459 |
+
}
|
| 1460 |
+
} else {
|
| 1461 |
+
// 其他格式(如 ping),直接透传
|
| 1462 |
+
fmt.Fprint(w, line)
|
| 1463 |
+
flusher.Flush()
|
| 1464 |
+
}
|
| 1465 |
+
}
|
| 1466 |
+
}
|
| 1467 |
+
|
| 1468 |
+
// retryWithProxy 使用代理池重试请求
|
| 1469 |
+
func (s *AnthropicService) retryWithProxy(ctx context.Context, account *model.Account, modelID string, body []byte) (*http.Response, error) {
|
| 1470 |
+
// 获取模型配置
|
| 1471 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 1472 |
+
if !exists {
|
| 1473 |
+
return nil, fmt.Errorf("模型配置不存在: %s", modelID)
|
| 1474 |
+
}
|
| 1475 |
+
|
| 1476 |
+
// 预处理请求体 - 确保包含所需的thinking配置和参数调整
|
| 1477 |
+
processedBody, err := s.preprocessRequestBody(body, modelID, zenModel)
|
| 1478 |
+
if err != nil {
|
| 1479 |
+
log.Printf("[Anthropic] 代理重试请求体预处理失败: %v", err)
|
| 1480 |
+
// 如果预处理失败,使用原始body
|
| 1481 |
+
processedBody = body
|
| 1482 |
+
}
|
| 1483 |
+
|
| 1484 |
+
proxyPool := provider.GetProxyPool()
|
| 1485 |
+
if !proxyPool.HasProxies() {
|
| 1486 |
+
return nil, fmt.Errorf("没有可用的代理")
|
| 1487 |
+
}
|
| 1488 |
+
|
| 1489 |
+
maxRetries := 3
|
| 1490 |
+
for i := 0; i < maxRetries; i++ {
|
| 1491 |
+
// 获取随机代理
|
| 1492 |
+
proxyURL := proxyPool.GetRandomProxy()
|
| 1493 |
+
if proxyURL == "" {
|
| 1494 |
+
continue
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
log.Printf("[Anthropic] 尝试代理 %s (重试 %d/%d)", proxyURL, i+1, maxRetries)
|
| 1498 |
+
|
| 1499 |
+
// 创建使用代理的HTTP客户端
|
| 1500 |
+
proxyClient, err := provider.NewHTTPClientWithProxy(proxyURL, 0)
|
| 1501 |
+
if err != nil {
|
| 1502 |
+
log.Printf("[Anthropic] 创建代理客户端失败: %v", err)
|
| 1503 |
+
continue
|
| 1504 |
+
}
|
| 1505 |
+
|
| 1506 |
+
// 创建新请求
|
| 1507 |
+
httpReq, err := http.NewRequest("POST", AnthropicBaseURL+"/v1/messages", bytes.NewReader(processedBody))
|
| 1508 |
+
if err != nil {
|
| 1509 |
+
log.Printf("[Anthropic] 创建请求失败: %v", err)
|
| 1510 |
+
continue
|
| 1511 |
+
}
|
| 1512 |
+
|
| 1513 |
+
// 设置请求头
|
| 1514 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 1515 |
+
httpReq.Header.Set("anthropic-version", "2023-06-01")
|
| 1516 |
+
|
| 1517 |
+
// 添加模型配置的额外请求头
|
| 1518 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 1519 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 1520 |
+
httpReq.Header.Set(k, v)
|
| 1521 |
+
}
|
| 1522 |
+
}
|
| 1523 |
+
|
| 1524 |
+
// 只在非限速测试且调试模式下记录代理请求详情
|
| 1525 |
+
var reqCheck struct {
|
| 1526 |
+
Model string `json:"model"`
|
| 1527 |
+
}
|
| 1528 |
+
if IsDebugMode() && json.Unmarshal(body, &reqCheck) == nil && !strings.Contains(reqCheck.Model, "test") {
|
| 1529 |
+
log.Printf("[Anthropic] 代理请求详情 - URL: %s", httpReq.URL.String())
|
| 1530 |
+
logRequestDetails("[Anthropic] 代理请求", httpReq.Header, processedBody)
|
| 1531 |
+
}
|
| 1532 |
+
|
| 1533 |
+
// 执行请求
|
| 1534 |
+
resp, err := proxyClient.Do(httpReq)
|
| 1535 |
+
if err != nil {
|
| 1536 |
+
log.Printf("[Anthropic] 代理请求失败: %v", err)
|
| 1537 |
+
continue
|
| 1538 |
+
}
|
| 1539 |
+
|
| 1540 |
+
// 检查响应状态
|
| 1541 |
+
if resp.StatusCode == 429 {
|
| 1542 |
+
// 仍然是429,尝试下一个代理
|
| 1543 |
+
resp.Body.Close()
|
| 1544 |
+
log.Printf("[Anthropic] 代理 %s 仍返回429,尝试下一个", proxyURL)
|
| 1545 |
+
continue
|
| 1546 |
+
}
|
| 1547 |
+
|
| 1548 |
+
if resp.StatusCode >= 400 {
|
| 1549 |
+
// 其他错误,记录并尝试下一个代理
|
| 1550 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 1551 |
+
resp.Body.Close()
|
| 1552 |
+
|
| 1553 |
+
// 解析thinking状态
|
| 1554 |
+
thinkingStatus := "disabled"
|
| 1555 |
+
var reqCheck struct {
|
| 1556 |
+
Thinking map[string]interface{} `json:"thinking,omitempty"`
|
| 1557 |
+
}
|
| 1558 |
+
json.Unmarshal(body, &reqCheck)
|
| 1559 |
+
if reqCheck.Thinking != nil {
|
| 1560 |
+
if enabled, ok := reqCheck.Thinking["enabled"].(bool); ok && enabled {
|
| 1561 |
+
thinkingStatus = "enabled"
|
| 1562 |
+
} else if thinkingType, ok := reqCheck.Thinking["type"].(string); ok && thinkingType == "enabled" {
|
| 1563 |
+
thinkingStatus = "enabled"
|
| 1564 |
+
}
|
| 1565 |
+
// 如果有thinking配置且有budget_tokens,也记录
|
| 1566 |
+
if budget, ok := reqCheck.Thinking["budget_tokens"].(float64); ok && budget > 0 {
|
| 1567 |
+
thinkingStatus = fmt.Sprintf("enabled(budget=%g)", budget)
|
| 1568 |
+
}
|
| 1569 |
+
}
|
| 1570 |
+
|
| 1571 |
+
log.Printf("[Anthropic] 代理 %s 返回错误 %d: %s (Model: %s, Thinking: %s)", proxyURL, resp.StatusCode, string(errBody), modelID, thinkingStatus)
|
| 1572 |
+
continue
|
| 1573 |
+
}
|
| 1574 |
+
|
| 1575 |
+
// 成功
|
| 1576 |
+
log.Printf("[Anthropic] 代理 %s 请求成功", proxyURL)
|
| 1577 |
+
return resp, nil
|
| 1578 |
+
}
|
| 1579 |
+
|
| 1580 |
+
return nil, fmt.Errorf("所有代理重试均失败")
|
| 1581 |
+
}
|
| 1582 |
+
|
| 1583 |
+
// preprocessRequestBody 预处理请求体,应用所有必要的配置和调整
|
| 1584 |
+
func (s *AnthropicService) preprocessRequestBody(body []byte, modelID string, zenModel model.ZenModel) ([]byte, error) {
|
| 1585 |
+
// 注意:已移除模型替换逻辑,直接使用原始请求体
|
| 1586 |
+
modifiedBody := body
|
| 1587 |
+
|
| 1588 |
+
// 2. 确保thinking配置
|
| 1589 |
+
var err error
|
| 1590 |
+
modifiedBody, err = s.ensureThinkingConfig(modifiedBody, modelID)
|
| 1591 |
+
if err != nil {
|
| 1592 |
+
return modifiedBody, fmt.Errorf("确保thinking配置失败: %w", err)
|
| 1593 |
+
}
|
| 1594 |
+
|
| 1595 |
+
// 3. 根据模型调整参数
|
| 1596 |
+
modifiedBody, err = s.adjustParametersForModel(modifiedBody, modelID)
|
| 1597 |
+
if err != nil {
|
| 1598 |
+
return modifiedBody, fmt.Errorf("调整模型参数失败: %w", err)
|
| 1599 |
+
}
|
| 1600 |
+
|
| 1601 |
+
return modifiedBody, nil
|
| 1602 |
+
}
|
internal/service/api.go
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"net/http"
|
| 6 |
+
|
| 7 |
+
"zencoder-2api/internal/model"
|
| 8 |
+
"zencoder-2api/internal/service/provider"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
type APIService struct {
|
| 12 |
+
manager *provider.Manager
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
func NewAPIService() *APIService {
|
| 16 |
+
return &APIService{
|
| 17 |
+
manager: provider.GetManager(),
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
func (s *APIService) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 22 |
+
// 检查模型是否存在于模型字典中
|
| 23 |
+
_, exists := model.GetZenModel(req.Model)
|
| 24 |
+
if !exists {
|
| 25 |
+
return nil, ErrNoAvailableAccount
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
var lastErr error
|
| 29 |
+
|
| 30 |
+
for i := 0; i < MaxRetries; i++ {
|
| 31 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 32 |
+
if err != nil {
|
| 33 |
+
return nil, err
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
resp, err := s.doChat(account, req)
|
| 37 |
+
if err != nil {
|
| 38 |
+
MarkAccountError(account)
|
| 39 |
+
lastErr = err
|
| 40 |
+
continue
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
ResetAccountError(account)
|
| 44 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 45 |
+
if !exists {
|
| 46 |
+
// 模型不存在,使用默认倍率
|
| 47 |
+
UseCredit(account, 1.0)
|
| 48 |
+
} else {
|
| 49 |
+
// API服务没有HTTP响应,只能使用模型倍率
|
| 50 |
+
UseCredit(account, zenModel.Multiplier)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
return resp, nil
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
func (s *APIService) doChat(account *model.Account, req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 60 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 61 |
+
if !exists {
|
| 62 |
+
return nil, ErrNoAvailableAccount
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
cfg := s.buildConfig(account, zenModel)
|
| 66 |
+
p, err := s.manager.GetProvider(account.ID, zenModel, cfg)
|
| 67 |
+
if err != nil {
|
| 68 |
+
return nil, err
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// 注意:已移除模型重定向逻辑,直接使用用户请求的模型名
|
| 72 |
+
|
| 73 |
+
return p.Chat(req)
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
func (s *APIService) buildConfig(account *model.Account, zenModel model.ZenModel) provider.Config {
|
| 77 |
+
cfg := provider.Config{
|
| 78 |
+
APIKey: account.AccessToken,
|
| 79 |
+
Proxy: account.Proxy,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// 设置额外请求头
|
| 83 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 84 |
+
cfg.ExtraHeaders = zenModel.Parameters.ExtraHeaders
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
return cfg
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
func (s *APIService) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 91 |
+
// 检查模型是否存在于模型字典中
|
| 92 |
+
_, exists := model.GetZenModel(req.Model)
|
| 93 |
+
if !exists {
|
| 94 |
+
return ErrNoAvailableAccount
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
var lastErr error
|
| 98 |
+
|
| 99 |
+
for i := 0; i < MaxRetries; i++ {
|
| 100 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 101 |
+
if err != nil {
|
| 102 |
+
return err
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
err = s.doChatStream(account, req, writer)
|
| 106 |
+
if err != nil {
|
| 107 |
+
MarkAccountError(account)
|
| 108 |
+
lastErr = err
|
| 109 |
+
continue
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
ResetAccountError(account)
|
| 113 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 114 |
+
if !exists {
|
| 115 |
+
// 模型不存在,使用默认倍率
|
| 116 |
+
UseCredit(account, 1.0)
|
| 117 |
+
} else {
|
| 118 |
+
// 流式响应,使用模型倍率
|
| 119 |
+
UseCredit(account, zenModel.Multiplier)
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
return nil
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
return fmt.Errorf("all retries failed: %w", lastErr)
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
func (s *APIService) doChatStream(account *model.Account, req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 129 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 130 |
+
if !exists {
|
| 131 |
+
return ErrNoAvailableAccount
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
cfg := s.buildConfig(account, zenModel)
|
| 135 |
+
p, err := s.manager.GetProvider(account.ID, zenModel, cfg)
|
| 136 |
+
if err != nil {
|
| 137 |
+
return err
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// 注意:已移除模型重定向逻辑,直接使用用户请求的模型名
|
| 141 |
+
|
| 142 |
+
return p.ChatStream(req, writer)
|
| 143 |
+
}
|
internal/service/autogen.go
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"errors"
|
| 5 |
+
"fmt"
|
| 6 |
+
"log"
|
| 7 |
+
"strings"
|
| 8 |
+
"sync"
|
| 9 |
+
"time"
|
| 10 |
+
"zencoder-2api/internal/database"
|
| 11 |
+
"zencoder-2api/internal/model"
|
| 12 |
+
|
| 13 |
+
"gorm.io/gorm"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
type AutoGenerationService struct {
|
| 17 |
+
mu sync.Mutex
|
| 18 |
+
lastTriggered map[uint]time.Time // tokenID -> last triggered time
|
| 19 |
+
isGenerating map[uint]bool // tokenID -> is generating
|
| 20 |
+
debounceTime time.Duration // 防抖时间
|
| 21 |
+
generationDelay time.Duration // 生成任务间隔时间
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
var autoGenService *AutoGenerationService
|
| 25 |
+
|
| 26 |
+
func InitAutoGenerationService() {
|
| 27 |
+
autoGenService = &AutoGenerationService{
|
| 28 |
+
lastTriggered: make(map[uint]time.Time),
|
| 29 |
+
isGenerating: make(map[uint]bool),
|
| 30 |
+
debounceTime: 5 * time.Minute, // 5分钟防抖
|
| 31 |
+
generationDelay: 1 * time.Hour, // 生成任务间隔1小时
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// 启动监控协程
|
| 35 |
+
go autoGenService.startMonitoring()
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// SaveGenerationToken 保存生成模式使用的token
|
| 39 |
+
func SaveGenerationToken(token string, description string) error {
|
| 40 |
+
db := database.GetDB()
|
| 41 |
+
|
| 42 |
+
// 检查是否已存在
|
| 43 |
+
var existing model.TokenRecord
|
| 44 |
+
if err := db.Where("token = ?", token).First(&existing).Error; err == nil {
|
| 45 |
+
// 更新最后生成时间
|
| 46 |
+
existing.LastGeneratedAt = time.Now()
|
| 47 |
+
existing.GeneratedCount += 1
|
| 48 |
+
return db.Save(&existing).Error
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// 创建新记录
|
| 52 |
+
record := model.TokenRecord{
|
| 53 |
+
Token: token,
|
| 54 |
+
Description: description,
|
| 55 |
+
GeneratedCount: 1,
|
| 56 |
+
LastGeneratedAt: time.Now(),
|
| 57 |
+
AutoGenerate: true,
|
| 58 |
+
Threshold: 10,
|
| 59 |
+
GenerateBatch: 30,
|
| 60 |
+
IsActive: true,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
return db.Create(&record).Error
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// SaveGenerationTokenWithRefresh 保存生成模式使用的 refresh_token
|
| 67 |
+
func SaveGenerationTokenWithRefresh(refreshToken string, accessToken string, description string, expiresIn int) error {
|
| 68 |
+
db := database.GetDB()
|
| 69 |
+
|
| 70 |
+
// 计算过期时间
|
| 71 |
+
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
| 72 |
+
|
| 73 |
+
// 解析JWT获取用户信息,特别是邮箱
|
| 74 |
+
var email, planType string
|
| 75 |
+
var subscriptionDate time.Time
|
| 76 |
+
|
| 77 |
+
if accessToken != "" {
|
| 78 |
+
if payload, err := ParseJWT(accessToken); err == nil {
|
| 79 |
+
email = payload.Email
|
| 80 |
+
planType = payload.CustomClaims.Plan
|
| 81 |
+
if planType != "" {
|
| 82 |
+
planType = strings.ToUpper(planType[:1]) + planType[1:]
|
| 83 |
+
}
|
| 84 |
+
subscriptionDate = GetSubscriptionDate(payload)
|
| 85 |
+
log.Printf("[SaveGenerationToken] 解析JWT成功: Email=%s, Plan=%s, SubStart=%s",
|
| 86 |
+
email, planType, subscriptionDate.Format("2006-01-02"))
|
| 87 |
+
} else {
|
| 88 |
+
log.Printf("[SaveGenerationToken] 解析JWT失败: %v", err)
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// 如果有邮箱,按邮箱查找;否则按refresh_token查找
|
| 93 |
+
var existing model.TokenRecord
|
| 94 |
+
var err error
|
| 95 |
+
|
| 96 |
+
if email != "" {
|
| 97 |
+
// 优先按邮箱查找,实现相同邮箱的记录合并
|
| 98 |
+
err = db.Where("email = ?", email).First(&existing).Error
|
| 99 |
+
} else {
|
| 100 |
+
// 没有邮箱时,按refresh_token查找
|
| 101 |
+
err = db.Where("refresh_token = ?", refreshToken).First(&existing).Error
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
if err == nil {
|
| 105 |
+
// 更新现有记录
|
| 106 |
+
updates := map[string]interface{}{
|
| 107 |
+
"token": accessToken,
|
| 108 |
+
"refresh_token": refreshToken,
|
| 109 |
+
"token_expiry": expiresAt,
|
| 110 |
+
"description": description,
|
| 111 |
+
"updated_at": time.Now(),
|
| 112 |
+
"plan_type": planType,
|
| 113 |
+
"subscription_start_date": subscriptionDate,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// 如果之前没有refresh_token,标记为有
|
| 117 |
+
if existing.RefreshToken == "" {
|
| 118 |
+
updates["has_refresh_token"] = true
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
return db.Model(&existing).Updates(updates).Error
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// 创建新记录
|
| 125 |
+
record := model.TokenRecord{
|
| 126 |
+
Token: accessToken,
|
| 127 |
+
RefreshToken: refreshToken,
|
| 128 |
+
TokenExpiry: expiresAt,
|
| 129 |
+
Description: description,
|
| 130 |
+
Email: email,
|
| 131 |
+
PlanType: planType,
|
| 132 |
+
SubscriptionStartDate: subscriptionDate,
|
| 133 |
+
HasRefreshToken: true,
|
| 134 |
+
CreatedAt: time.Now(),
|
| 135 |
+
UpdatedAt: time.Now(),
|
| 136 |
+
AutoGenerate: true,
|
| 137 |
+
Threshold: 10,
|
| 138 |
+
GenerateBatch: 30,
|
| 139 |
+
IsActive: true,
|
| 140 |
+
GeneratedCount: 0,
|
| 141 |
+
TotalSuccess: 0,
|
| 142 |
+
TotalFail: 0,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
if err := db.Create(&record).Error; err != nil {
|
| 146 |
+
return fmt.Errorf("failed to save generation token: %w", err)
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
return nil
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// GetActiveTokenRecords 获取所有活跃的token记录
|
| 153 |
+
func GetActiveTokenRecords() ([]model.TokenRecord, error) {
|
| 154 |
+
var records []model.TokenRecord
|
| 155 |
+
err := database.GetDB().Where("is_active = ?", true).Find(&records).Error
|
| 156 |
+
return records, err
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// GetAllTokenRecords 获取所有token记录
|
| 160 |
+
func GetAllTokenRecords() ([]model.TokenRecord, error) {
|
| 161 |
+
var records []model.TokenRecord
|
| 162 |
+
err := database.GetDB().Order("created_at DESC").Find(&records).Error
|
| 163 |
+
return records, err
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// GetGenerationTasks 获取生成任务历史
|
| 167 |
+
func GetGenerationTasks(tokenRecordID uint) ([]model.GenerationTask, error) {
|
| 168 |
+
var tasks []model.GenerationTask
|
| 169 |
+
query := database.GetDB().Order("created_at DESC")
|
| 170 |
+
if tokenRecordID > 0 {
|
| 171 |
+
query = query.Where("token_record_id = ?", tokenRecordID)
|
| 172 |
+
}
|
| 173 |
+
err := query.Find(&tasks).Error
|
| 174 |
+
return tasks, err
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// UpdateTokenRecord 更新token记录设置
|
| 178 |
+
func UpdateTokenRecord(id uint, updates map[string]interface{}) error {
|
| 179 |
+
return database.GetDB().Model(&model.TokenRecord{}).Where("id = ?", id).Updates(updates).Error
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// 监控账号池并触发自动生成
|
| 183 |
+
func (s *AutoGenerationService) startMonitoring() {
|
| 184 |
+
ticker := time.NewTicker(1 * time.Minute) // 每分钟检查一次
|
| 185 |
+
defer ticker.Stop()
|
| 186 |
+
|
| 187 |
+
for range ticker.C {
|
| 188 |
+
s.checkAndTriggerGeneration()
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
// 检查并触发生成
|
| 193 |
+
func (s *AutoGenerationService) checkAndTriggerGeneration() {
|
| 194 |
+
// 获取所有活跃的token记录
|
| 195 |
+
records, err := GetActiveTokenRecords()
|
| 196 |
+
if err != nil {
|
| 197 |
+
log.Printf("[AutoGen] 获取token记录失败: %v", err)
|
| 198 |
+
return
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// 计算当前可用账号数量
|
| 202 |
+
var activeAccountCount int64
|
| 203 |
+
database.GetDB().Model(&model.Account{}).
|
| 204 |
+
Where("status = ?", "normal").
|
| 205 |
+
Where("token_expiry > ?", time.Now()).
|
| 206 |
+
Count(&activeAccountCount)
|
| 207 |
+
|
| 208 |
+
log.Printf("[AutoGen] 当前活跃账号数量: %d", activeAccountCount)
|
| 209 |
+
|
| 210 |
+
// 检查每个token记录的阈值
|
| 211 |
+
for _, record := range records {
|
| 212 |
+
if !record.AutoGenerate || !record.IsActive {
|
| 213 |
+
continue
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// 检查是否达到阈值
|
| 217 |
+
if int(activeAccountCount) <= record.Threshold {
|
| 218 |
+
s.triggerGeneration(record)
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
// 触发生成任务(带防抖)
|
| 224 |
+
func (s *AutoGenerationService) triggerGeneration(record model.TokenRecord) {
|
| 225 |
+
s.mu.Lock()
|
| 226 |
+
defer s.mu.Unlock()
|
| 227 |
+
|
| 228 |
+
// 检查是否正在生成
|
| 229 |
+
if s.isGenerating[record.ID] {
|
| 230 |
+
log.Printf("[AutoGen] Token %d 正在生成中,跳过", record.ID)
|
| 231 |
+
return
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// 检查防抖时间
|
| 235 |
+
if lastTime, ok := s.lastTriggered[record.ID]; ok {
|
| 236 |
+
if time.Since(lastTime) < s.debounceTime {
|
| 237 |
+
log.Printf("[AutoGen] Token %d 防抖中,距上次触发 %v", record.ID, time.Since(lastTime))
|
| 238 |
+
return
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
// 检查生成间隔
|
| 243 |
+
if !record.LastGeneratedAt.IsZero() && time.Since(record.LastGeneratedAt) < s.generationDelay {
|
| 244 |
+
log.Printf("[AutoGen] Token %d 未达到生成间隔时间,距上次生成 %v", record.ID, time.Since(record.LastGeneratedAt))
|
| 245 |
+
return
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// 标记开始生成
|
| 249 |
+
s.isGenerating[record.ID] = true
|
| 250 |
+
s.lastTriggered[record.ID] = time.Now()
|
| 251 |
+
|
| 252 |
+
// 异步执行生成任务
|
| 253 |
+
go s.executeGeneration(record)
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
// 执行生成任务
|
| 257 |
+
func (s *AutoGenerationService) executeGeneration(record model.TokenRecord) {
|
| 258 |
+
defer func() {
|
| 259 |
+
s.mu.Lock()
|
| 260 |
+
s.isGenerating[record.ID] = false
|
| 261 |
+
s.mu.Unlock()
|
| 262 |
+
}()
|
| 263 |
+
|
| 264 |
+
log.Printf("[AutoGen] 开始自动生成任务 - Token %d, 批次大小: %d", record.ID, record.GenerateBatch)
|
| 265 |
+
|
| 266 |
+
// 检查token记录状态
|
| 267 |
+
if record.Status != "active" {
|
| 268 |
+
log.Printf("[AutoGen] Token记录 %d 状态异常 (%s),跳过生成任务", record.ID, record.Status)
|
| 269 |
+
return
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
// 检查token是否需要刷新
|
| 273 |
+
if record.RefreshToken != "" && time.Now().After(record.TokenExpiry.Add(-time.Hour)) {
|
| 274 |
+
log.Printf("[AutoGen] Token记录 %d 的token即将过期,尝试刷新", record.ID)
|
| 275 |
+
if err := UpdateTokenRecordToken(&record); err != nil {
|
| 276 |
+
log.Printf("[AutoGen] Token记录 %d 刷新失败,停止生成任务: %v", record.ID, err)
|
| 277 |
+
return
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
// 创建生成任务记录
|
| 282 |
+
task := model.GenerationTask{
|
| 283 |
+
TokenRecordID: record.ID,
|
| 284 |
+
Token: record.Token,
|
| 285 |
+
BatchSize: record.GenerateBatch,
|
| 286 |
+
Status: "running",
|
| 287 |
+
StartedAt: time.Now(),
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
if err := database.GetDB().Create(&task).Error; err != nil {
|
| 291 |
+
log.Printf("[AutoGen] 创建任务记录失败: %v", err)
|
| 292 |
+
return
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
// 批量生成凭证
|
| 296 |
+
credentials, errs := BatchGenerateCredentials(record.Token, record.GenerateBatch)
|
| 297 |
+
|
| 298 |
+
// 检查生成过程中是否有token失效的错误
|
| 299 |
+
for _, err := range errs {
|
| 300 |
+
if strings.Contains(err.Error(), "locked out") || strings.Contains(err.Error(), "User is locked out") {
|
| 301 |
+
log.Printf("[AutoGen] 检测到原始token被锁定,禁用token记录 %d: %v", record.ID, err)
|
| 302 |
+
// 将token记录标记为封禁状态
|
| 303 |
+
if markErr := markTokenRecordAsBanned(&record, "原始token被锁定: "+err.Error()); markErr != nil {
|
| 304 |
+
log.Printf("[AutoGen] 标记token记录封禁状态失败: %v", markErr)
|
| 305 |
+
}
|
| 306 |
+
// 根据邮箱禁用相关的token记录
|
| 307 |
+
if record.Email != "" {
|
| 308 |
+
if disableErr := disableTokenRecordsByEmail(record.Email, "关联账号被锁定"); disableErr != nil {
|
| 309 |
+
log.Printf("[AutoGen] 禁用相关token记录失败: %v", disableErr)
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
// 提前结束任务
|
| 313 |
+
task.Status = "failed"
|
| 314 |
+
task.ErrorMessage = "原始token被锁定"
|
| 315 |
+
task.CompletedAt = time.Now()
|
| 316 |
+
database.GetDB().Save(&task)
|
| 317 |
+
return
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
successCount := 0
|
| 322 |
+
failCount := len(errs)
|
| 323 |
+
|
| 324 |
+
// 处理生成的凭证
|
| 325 |
+
for _, cred := range credentials {
|
| 326 |
+
account := model.Account{
|
| 327 |
+
ClientID: cred.ClientID,
|
| 328 |
+
ClientSecret: cred.Secret,
|
| 329 |
+
IsActive: true,
|
| 330 |
+
Status: "normal",
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// 获取Token并解析信息
|
| 334 |
+
if _, err := RefreshToken(&account); err != nil {
|
| 335 |
+
failCount++
|
| 336 |
+
// 检查是否是账号锁定错误
|
| 337 |
+
if lockoutErr, ok := err.(*AccountLockoutError); ok {
|
| 338 |
+
log.Printf("[AutoGen] 账号 %s 被锁定: %s", cred.ClientID, lockoutErr.Body)
|
| 339 |
+
} else {
|
| 340 |
+
log.Printf("[AutoGen] 账号 %s 认证失败: %v", cred.ClientID, err)
|
| 341 |
+
}
|
| 342 |
+
continue
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
// 解析JWT获取详细信息
|
| 346 |
+
if payload, err := ParseJWT(account.AccessToken); err == nil {
|
| 347 |
+
account.Email = payload.Email
|
| 348 |
+
account.SubscriptionStartDate = GetSubscriptionDate(payload)
|
| 349 |
+
|
| 350 |
+
if payload.Expiration > 0 {
|
| 351 |
+
account.TokenExpiry = time.Unix(payload.Expiration, 0)
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
// 设置计划类型
|
| 355 |
+
plan := "Free"
|
| 356 |
+
if payload.CustomClaims.Plan != "" {
|
| 357 |
+
plan = payload.CustomClaims.Plan
|
| 358 |
+
}
|
| 359 |
+
account.PlanType = model.PlanType(plan)
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
// 保存账号
|
| 363 |
+
var existing model.Account
|
| 364 |
+
err := database.GetDB().Where("client_id = ?", account.ClientID).First(&existing).Error
|
| 365 |
+
|
| 366 |
+
if err == nil {
|
| 367 |
+
// 更新已存在的账号
|
| 368 |
+
existing.AccessToken = account.AccessToken
|
| 369 |
+
existing.TokenExpiry = account.TokenExpiry
|
| 370 |
+
existing.PlanType = account.PlanType
|
| 371 |
+
existing.Email = account.Email
|
| 372 |
+
existing.SubscriptionStartDate = account.SubscriptionStartDate
|
| 373 |
+
existing.IsActive = true
|
| 374 |
+
existing.Status = "normal"
|
| 375 |
+
existing.ClientSecret = account.ClientSecret
|
| 376 |
+
|
| 377 |
+
if err := database.GetDB().Save(&existing).Error; err != nil {
|
| 378 |
+
failCount++
|
| 379 |
+
log.Printf("[AutoGen] 更新账号 %s 失败: %v", account.ClientID, err)
|
| 380 |
+
} else {
|
| 381 |
+
successCount++
|
| 382 |
+
}
|
| 383 |
+
} else if errors.Is(err, gorm.ErrRecordNotFound) {
|
| 384 |
+
// 记录不存在是正常的,创建新账号(不输出错误日志)
|
| 385 |
+
if err := database.GetDB().Create(&account).Error; err != nil {
|
| 386 |
+
failCount++
|
| 387 |
+
log.Printf("[AutoGen] 创建账号 %s 失败: %v", account.ClientID, err)
|
| 388 |
+
} else {
|
| 389 |
+
successCount++
|
| 390 |
+
}
|
| 391 |
+
} else {
|
| 392 |
+
// 其他数据库错误(非record not found的真实错误)
|
| 393 |
+
failCount++
|
| 394 |
+
log.Printf("[AutoGen] 查询账号 %s 时发生数据库错误: %v", account.ClientID, err)
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
// 更新任务状态
|
| 399 |
+
task.SuccessCount = successCount
|
| 400 |
+
task.FailCount = failCount
|
| 401 |
+
task.Status = "completed"
|
| 402 |
+
if successCount == 0 && failCount > 0 {
|
| 403 |
+
task.Status = "failed"
|
| 404 |
+
task.ErrorMessage = fmt.Sprintf("所有账号生成失败")
|
| 405 |
+
}
|
| 406 |
+
task.CompletedAt = time.Now()
|
| 407 |
+
|
| 408 |
+
if err := database.GetDB().Save(&task).Error; err != nil {
|
| 409 |
+
log.Printf("[AutoGen] 更新任务记录失败: %v", err)
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// 更新token记录,累计所有统计数据
|
| 413 |
+
updates := map[string]interface{}{
|
| 414 |
+
"last_generated_at": time.Now(),
|
| 415 |
+
"generated_count": gorm.Expr("generated_count + ?", successCount),
|
| 416 |
+
"total_success": gorm.Expr("total_success + ?", successCount),
|
| 417 |
+
"total_fail": gorm.Expr("total_fail + ?", failCount),
|
| 418 |
+
"total_tasks": gorm.Expr("total_tasks + 1"),
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
if err := database.GetDB().Model(&model.TokenRecord{}).
|
| 422 |
+
Where("id = ?", record.ID).
|
| 423 |
+
Updates(updates).Error; err != nil {
|
| 424 |
+
log.Printf("[AutoGen] 更新token记录失败: %v", err)
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
// 刷新账号池
|
| 428 |
+
RefreshAccountPool()
|
| 429 |
+
|
| 430 |
+
log.Printf("[AutoGen] 自动生成完成 - 成功: %d, 失败: %d", successCount, failCount)
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
// ManualTriggerGeneration 手动触发生成
|
| 434 |
+
func ManualTriggerGeneration(tokenRecordID uint) error {
|
| 435 |
+
var record model.TokenRecord
|
| 436 |
+
if err := database.GetDB().First(&record, tokenRecordID).Error; err != nil {
|
| 437 |
+
return fmt.Errorf("token记录不存在: %v", err)
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
if !record.IsActive {
|
| 441 |
+
return fmt.Errorf("token记录未激活")
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
go autoGenService.executeGeneration(record)
|
| 445 |
+
return nil
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
// RefreshAccountPool 刷新账号池
|
| 449 |
+
func RefreshAccountPool() {
|
| 450 |
+
if pool != nil {
|
| 451 |
+
pool.refresh()
|
| 452 |
+
}
|
| 453 |
+
}
|
internal/service/credential.go
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"math/rand"
|
| 9 |
+
"net/http"
|
| 10 |
+
"time"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
const (
|
| 14 |
+
CredentialGenerateURL = "https://fe.zencoder.ai/frontegg/identity/resources/users/api-tokens/v1"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
type CredentialGenerateRequest struct {
|
| 18 |
+
Description string `json:"description"`
|
| 19 |
+
ExpiresInMinutes int `json:"expiresInMinutes"`
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
type CredentialGenerateResponse struct {
|
| 23 |
+
ClientID string `json:"clientId"`
|
| 24 |
+
Description string `json:"description"`
|
| 25 |
+
CreatedAt string `json:"createdAt"`
|
| 26 |
+
Secret string `json:"secret"`
|
| 27 |
+
Expires string `json:"expires"`
|
| 28 |
+
RefreshToken string `json:"refreshToken,omitempty"` // 添加 RefreshToken 字段
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// GenerateRandomDescription 生成随机5字符描述
|
| 32 |
+
func GenerateRandomDescription() string {
|
| 33 |
+
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
| 34 |
+
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
| 35 |
+
b := make([]byte, 5)
|
| 36 |
+
for i := range b {
|
| 37 |
+
b[i] = charset[rng.Intn(len(charset))]
|
| 38 |
+
}
|
| 39 |
+
return string(b)
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// GenerateCredential 使用 token 生成一个新凭证
|
| 43 |
+
func GenerateCredential(token string) (*CredentialGenerateResponse, error) {
|
| 44 |
+
reqBody := CredentialGenerateRequest{
|
| 45 |
+
Description: GenerateRandomDescription(),
|
| 46 |
+
ExpiresInMinutes: 525600, // 1 year
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
bodyBytes, err := json.Marshal(reqBody)
|
| 50 |
+
if err != nil {
|
| 51 |
+
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
req, err := http.NewRequest("POST", CredentialGenerateURL, bytes.NewReader(bodyBytes))
|
| 55 |
+
if err != nil {
|
| 56 |
+
return nil, fmt.Errorf("failed to create request: %w", err)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// 设置请求头
|
| 60 |
+
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
| 61 |
+
req.Header.Set("Connection", "keep-alive")
|
| 62 |
+
req.Header.Set("accept", "*/*")
|
| 63 |
+
req.Header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,ja;q=0.6")
|
| 64 |
+
req.Header.Set("authorization", "Bearer "+token)
|
| 65 |
+
req.Header.Set("cache-control", "no-cache")
|
| 66 |
+
req.Header.Set("content-type", "application/json")
|
| 67 |
+
req.Header.Set("frontegg-source", "admin-portal")
|
| 68 |
+
req.Header.Set("origin", "https://auth.zencoder.ai")
|
| 69 |
+
req.Header.Set("pragma", "no-cache")
|
| 70 |
+
req.Header.Set("priority", "u=1, i")
|
| 71 |
+
req.Header.Set("referer", "https://auth.zencoder.ai/")
|
| 72 |
+
req.Header.Set("sec-ch-ua", `"Google Chrome";v="143", "Chromium";v="143", "Not A(Brand";v="24"`)
|
| 73 |
+
req.Header.Set("sec-ch-ua-mobile", "?0")
|
| 74 |
+
req.Header.Set("sec-ch-ua-platform", `"Windows"`)
|
| 75 |
+
req.Header.Set("sec-fetch-dest", "empty")
|
| 76 |
+
req.Header.Set("sec-fetch-mode", "cors")
|
| 77 |
+
req.Header.Set("sec-fetch-site", "same-site")
|
| 78 |
+
req.Header.Set("user-agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/143.0.0.0 Safari/537.36")
|
| 79 |
+
req.Header.Set("x-frontegg-framework", "next@15.3.8")
|
| 80 |
+
req.Header.Set("x-frontegg-sdk", "@frontegg/nextjs@9.2.10")
|
| 81 |
+
|
| 82 |
+
client := &http.Client{Timeout: 30 * time.Second}
|
| 83 |
+
resp, err := client.Do(req)
|
| 84 |
+
if err != nil {
|
| 85 |
+
return nil, fmt.Errorf("failed to send request: %w", err)
|
| 86 |
+
}
|
| 87 |
+
defer resp.Body.Close()
|
| 88 |
+
|
| 89 |
+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
| 90 |
+
body, _ := io.ReadAll(resp.Body)
|
| 91 |
+
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
var result CredentialGenerateResponse
|
| 95 |
+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
| 96 |
+
return nil, fmt.Errorf("failed to decode response: %w", err)
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return &result, nil
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// BatchGenerateCredentials 批量生成凭证
|
| 103 |
+
func BatchGenerateCredentials(token string, count int) ([]*CredentialGenerateResponse, []error) {
|
| 104 |
+
var results []*CredentialGenerateResponse
|
| 105 |
+
var errors []error
|
| 106 |
+
|
| 107 |
+
for i := 0; i < count; i++ {
|
| 108 |
+
cred, err := GenerateCredential(token)
|
| 109 |
+
if err != nil {
|
| 110 |
+
errors = append(errors, fmt.Errorf("credential %d: %w", i+1, err))
|
| 111 |
+
continue
|
| 112 |
+
}
|
| 113 |
+
results = append(results, cred)
|
| 114 |
+
|
| 115 |
+
// 添加短暂延迟避免请求过快
|
| 116 |
+
if i < count-1 {
|
| 117 |
+
time.Sleep(500 * time.Millisecond)
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
return results, errors
|
| 122 |
+
}
|
internal/service/debug.go
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"log"
|
| 7 |
+
"os"
|
| 8 |
+
"sync"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
var (
|
| 12 |
+
debugMode bool
|
| 13 |
+
debugModeOnce sync.Once
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
// IsDebugMode 检查是否启用调试模式
|
| 17 |
+
func IsDebugMode() bool {
|
| 18 |
+
debugModeOnce.Do(func() {
|
| 19 |
+
debugMode = os.Getenv("DEBUG") == "true" || os.Getenv("DEBUG") == "1"
|
| 20 |
+
})
|
| 21 |
+
return debugMode
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// RequestLogger 用于收集请求级日志
|
| 25 |
+
type RequestLogger struct {
|
| 26 |
+
logs []string
|
| 27 |
+
mu sync.Mutex
|
| 28 |
+
hasError bool
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// NewRequestLogger 创建新的请求日志记录器
|
| 32 |
+
func NewRequestLogger() *RequestLogger {
|
| 33 |
+
return &RequestLogger{
|
| 34 |
+
logs: make([]string, 0, 20),
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Log 记录一条日志
|
| 39 |
+
func (l *RequestLogger) Log(format string, args ...interface{}) {
|
| 40 |
+
msg := fmt.Sprintf(format, args...)
|
| 41 |
+
|
| 42 |
+
// 如果全局 DEBUG 开启,直接打印
|
| 43 |
+
if IsDebugMode() {
|
| 44 |
+
log.Print("[DEBUG] " + msg)
|
| 45 |
+
return
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// 否则缓冲
|
| 49 |
+
l.mu.Lock()
|
| 50 |
+
l.logs = append(l.logs, "[DEBUG] " + msg)
|
| 51 |
+
l.mu.Unlock()
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// MarkError 标记发生错误
|
| 55 |
+
func (l *RequestLogger) MarkError() {
|
| 56 |
+
l.mu.Lock()
|
| 57 |
+
l.hasError = true
|
| 58 |
+
l.mu.Unlock()
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// Flush 输出缓冲的日志(如果有错误)
|
| 62 |
+
func (l *RequestLogger) Flush() {
|
| 63 |
+
// 只有在非 Debug 模式且发生错误时才需要 Flush (Debug 模式下已经实时打印了)
|
| 64 |
+
if !IsDebugMode() && l.hasError {
|
| 65 |
+
l.mu.Lock()
|
| 66 |
+
defer l.mu.Unlock()
|
| 67 |
+
for _, msg := range l.logs {
|
| 68 |
+
log.Print(msg)
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
type contextKey string
|
| 74 |
+
|
| 75 |
+
const loggerContextKey contextKey = "request_logger"
|
| 76 |
+
|
| 77 |
+
// WithLogger 将 logger 注入 context
|
| 78 |
+
func WithLogger(ctx context.Context, logger *RequestLogger) context.Context {
|
| 79 |
+
return context.WithValue(ctx, loggerContextKey, logger)
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// GetLogger 从 context 获取 logger
|
| 83 |
+
func GetLogger(ctx context.Context) *RequestLogger {
|
| 84 |
+
val := ctx.Value(loggerContextKey)
|
| 85 |
+
if val != nil {
|
| 86 |
+
if logger, ok := val.(*RequestLogger); ok {
|
| 87 |
+
return logger
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
return nil
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// 辅助函数:获取 logger 并记录
|
| 94 |
+
func logToContext(ctx context.Context, format string, args ...interface{}) {
|
| 95 |
+
logger := GetLogger(ctx)
|
| 96 |
+
if logger != nil {
|
| 97 |
+
logger.Log(format, args...)
|
| 98 |
+
} else if IsDebugMode() {
|
| 99 |
+
log.Printf("[DEBUG] "+format, args...)
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// DebugLog 调试日志输出
|
| 104 |
+
func DebugLog(ctx context.Context, format string, args ...interface{}) {
|
| 105 |
+
logToContext(ctx, format, args...)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// DebugLogRequest 请求开始日志
|
| 109 |
+
func DebugLogRequest(ctx context.Context, provider, endpoint, model string) {
|
| 110 |
+
logToContext(ctx, "[%s] >>> 请求开始: endpoint=%s, model=%s", provider, endpoint, model)
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// DebugLogRetry 重试日志
|
| 114 |
+
func DebugLogRetry(ctx context.Context, provider string, attempt int, accountID uint, err error) {
|
| 115 |
+
if logger := GetLogger(ctx); logger != nil {
|
| 116 |
+
logger.MarkError()
|
| 117 |
+
}
|
| 118 |
+
logToContext(ctx, "[%s] ↻ 重试 #%d: accountID=%d, error=%v", provider, attempt, accountID, err)
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// DebugLogAccountSelected 账号选择日志
|
| 122 |
+
func DebugLogAccountSelected(ctx context.Context, provider string, accountID uint, email string) {
|
| 123 |
+
logToContext(ctx, "[%s] ✓ 选择账号: id=%d, email=%s", provider, accountID, email)
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// DebugLogRequestSent 请求发送日志
|
| 127 |
+
func DebugLogRequestSent(ctx context.Context, provider, url string) {
|
| 128 |
+
logToContext(ctx, "[%s] → 发送请求: %s", provider, url)
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// DebugLogResponseReceived 响应接收日志
|
| 132 |
+
func DebugLogResponseReceived(ctx context.Context, provider string, statusCode int) {
|
| 133 |
+
logToContext(ctx, "[%s] ← 收到响应: status=%d", provider, statusCode)
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// DebugLogRequestEnd 请求结束日志
|
| 137 |
+
func DebugLogRequestEnd(ctx context.Context, provider string, success bool, err error) {
|
| 138 |
+
if !success || err != nil {
|
| 139 |
+
if logger := GetLogger(ctx); logger != nil {
|
| 140 |
+
logger.MarkError()
|
| 141 |
+
}
|
| 142 |
+
logToContext(ctx, "[%s] <<< 请求完成: success=false, error=%v", provider, err)
|
| 143 |
+
} else {
|
| 144 |
+
logToContext(ctx, "[%s] <<< 请求完成: success=true", provider)
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
// DebugLogRequestHeaders 请求头日志
|
| 149 |
+
func DebugLogRequestHeaders(ctx context.Context, provider string, headers map[string][]string) {
|
| 150 |
+
logToContext(ctx, "[%s] 请求头:", provider)
|
| 151 |
+
for k, v := range headers {
|
| 152 |
+
// 隐藏敏感信息
|
| 153 |
+
if k == "Authorization" || k == "x-api-key" {
|
| 154 |
+
logToContext(ctx, "[%s] %s: ***", provider, k)
|
| 155 |
+
} else {
|
| 156 |
+
logToContext(ctx, "[%s] %s: %v", provider, k, v)
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// DebugLogResponseHeaders 响应头日志
|
| 162 |
+
func DebugLogResponseHeaders(ctx context.Context, provider string, headers map[string][]string) {
|
| 163 |
+
logToContext(ctx, "[%s] 响应头:", provider)
|
| 164 |
+
for k, v := range headers {
|
| 165 |
+
// 隐藏敏感信息
|
| 166 |
+
if k == "X-Api-Key" || k == "Authorization" {
|
| 167 |
+
logToContext(ctx, "[%s] %s: ***", provider, k)
|
| 168 |
+
} else {
|
| 169 |
+
logToContext(ctx, "[%s] %s: %v", provider, k, v)
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
// DebugLogActualModel 实际调用模型日志
|
| 175 |
+
func DebugLogActualModel(ctx context.Context, provider, requestModel, actualModel string) {
|
| 176 |
+
logToContext(ctx, "[%s] 模型映射: %s → %s", provider, requestModel, actualModel)
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
// DebugLogErrorResponse 错误响应内容日志
|
| 180 |
+
func DebugLogErrorResponse(ctx context.Context, provider string, statusCode int, body string) {
|
| 181 |
+
if logger := GetLogger(ctx); logger != nil {
|
| 182 |
+
logger.MarkError()
|
| 183 |
+
}
|
| 184 |
+
logToContext(ctx, "[%s] ✗ 错误响应 [%d]: %s", provider, statusCode, body)
|
| 185 |
+
}
|
internal/service/errors.go
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import "errors"
|
| 4 |
+
|
| 5 |
+
var (
|
| 6 |
+
ErrNoAvailableAccount = errors.New("没有可用token")
|
| 7 |
+
ErrNoPermission = errors.New("没有账号有权限使用此模型")
|
| 8 |
+
ErrTokenExpired = errors.New("token已过期")
|
| 9 |
+
ErrRequestFailed = errors.New("请求失败")
|
| 10 |
+
)
|
internal/service/gemini.go
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"context"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"log"
|
| 9 |
+
"net/http"
|
| 10 |
+
|
| 11 |
+
"zencoder-2api/internal/model"
|
| 12 |
+
"zencoder-2api/internal/service/provider"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
const GeminiBaseURL = "https://api.zencoder.ai/gemini"
|
| 16 |
+
|
| 17 |
+
type GeminiService struct{}
|
| 18 |
+
|
| 19 |
+
func NewGeminiService() *GeminiService {
|
| 20 |
+
return &GeminiService{}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// GenerateContent 处理generateContent请求
|
| 24 |
+
func (s *GeminiService) GenerateContent(ctx context.Context, modelName string, body []byte) (*http.Response, error) {
|
| 25 |
+
// 检查模型是否存在于模型字典中
|
| 26 |
+
_, exists := model.GetZenModel(modelName)
|
| 27 |
+
if !exists {
|
| 28 |
+
DebugLog(ctx, "[Gemini] 模型不存在: %s", modelName)
|
| 29 |
+
return nil, ErrNoAvailableAccount
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
DebugLogRequest(ctx, "Gemini", "generateContent", modelName)
|
| 33 |
+
|
| 34 |
+
var lastErr error
|
| 35 |
+
for i := 0; i < MaxRetries; i++ {
|
| 36 |
+
account, err := GetNextAccountForModel(modelName)
|
| 37 |
+
if err != nil {
|
| 38 |
+
DebugLogRequestEnd(ctx, "Gemini", false, err)
|
| 39 |
+
return nil, err
|
| 40 |
+
}
|
| 41 |
+
DebugLogAccountSelected(ctx, "Gemini", account.ID, account.Email)
|
| 42 |
+
|
| 43 |
+
resp, err := s.doRequest(ctx, account, modelName, body, false)
|
| 44 |
+
if err != nil {
|
| 45 |
+
MarkAccountError(account)
|
| 46 |
+
lastErr = err
|
| 47 |
+
DebugLogRetry(ctx, "Gemini", i+1, account.ID, err)
|
| 48 |
+
continue
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
DebugLogResponseReceived(ctx, "Gemini", resp.StatusCode)
|
| 52 |
+
DebugLogResponseHeaders(ctx, "Gemini", resp.Header)
|
| 53 |
+
|
| 54 |
+
// 总是输出重要的响应头信息
|
| 55 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 56 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 57 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 58 |
+
log.Printf("[Gemini] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 59 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 60 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 61 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if resp.StatusCode >= 400 {
|
| 65 |
+
// 读取错误响应内容用于日志
|
| 66 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 67 |
+
resp.Body.Close()
|
| 68 |
+
DebugLogErrorResponse(ctx, "Gemini", resp.StatusCode, string(errBody))
|
| 69 |
+
|
| 70 |
+
// 400和500错误直接返回,不进行账号错误计数
|
| 71 |
+
if resp.StatusCode == 400 || resp.StatusCode == 500 {
|
| 72 |
+
DebugLogRequestEnd(ctx, "Gemini", false, fmt.Errorf("API error: %d", resp.StatusCode))
|
| 73 |
+
return nil, fmt.Errorf("API error: %d - %s", resp.StatusCode, string(errBody))
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// 429 错误特殊处理
|
| 77 |
+
if resp.StatusCode == 429 {
|
| 78 |
+
log.Printf("[Gemini] 429限流错误,尝试使用代理重试")
|
| 79 |
+
|
| 80 |
+
// 尝试使用代理池重试
|
| 81 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, modelName, body, false)
|
| 82 |
+
if proxyErr == nil && proxyResp != nil {
|
| 83 |
+
// 代理重试成功
|
| 84 |
+
return proxyResp, nil
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
log.Printf("[Gemini] 代理重试失败: %v", proxyErr)
|
| 88 |
+
MarkAccountRateLimitedWithResponse(account, resp)
|
| 89 |
+
} else {
|
| 90 |
+
MarkAccountError(account)
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 94 |
+
DebugLogRetry(ctx, "Gemini", i+1, account.ID, lastErr)
|
| 95 |
+
continue
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
ResetAccountError(account)
|
| 99 |
+
zenModel, exists := model.GetZenModel(modelName)
|
| 100 |
+
if !exists {
|
| 101 |
+
// 模型不存在,使用默认倍率
|
| 102 |
+
UpdateAccountCreditsFromResponse(account, resp, 1.0)
|
| 103 |
+
} else {
|
| 104 |
+
// 使用统一的积分更新函数,自动处理响应头中的积分信息
|
| 105 |
+
UpdateAccountCreditsFromResponse(account, resp, zenModel.Multiplier)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
DebugLogRequestEnd(ctx, "Gemini", true, nil)
|
| 109 |
+
return resp, nil
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
DebugLogRequestEnd(ctx, "Gemini", false, lastErr)
|
| 113 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// StreamGenerateContent 处理streamGenerateContent请求
|
| 117 |
+
func (s *GeminiService) StreamGenerateContent(ctx context.Context, modelName string, body []byte) (*http.Response, error) {
|
| 118 |
+
// 检查模型是否存在于模型字典中
|
| 119 |
+
_, exists := model.GetZenModel(modelName)
|
| 120 |
+
if !exists {
|
| 121 |
+
DebugLog(ctx, "[Gemini] 模型不存在: %s", modelName)
|
| 122 |
+
return nil, ErrNoAvailableAccount
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
DebugLogRequest(ctx, "Gemini", "streamGenerateContent", modelName)
|
| 126 |
+
|
| 127 |
+
var lastErr error
|
| 128 |
+
for i := 0; i < MaxRetries; i++ {
|
| 129 |
+
account, err := GetNextAccountForModel(modelName)
|
| 130 |
+
if err != nil {
|
| 131 |
+
DebugLogRequestEnd(ctx, "Gemini", false, err)
|
| 132 |
+
return nil, err
|
| 133 |
+
}
|
| 134 |
+
DebugLogAccountSelected(ctx, "Gemini", account.ID, account.Email)
|
| 135 |
+
|
| 136 |
+
resp, err := s.doRequest(ctx, account, modelName, body, true)
|
| 137 |
+
if err != nil {
|
| 138 |
+
MarkAccountError(account)
|
| 139 |
+
lastErr = err
|
| 140 |
+
DebugLogRetry(ctx, "Gemini", i+1, account.ID, err)
|
| 141 |
+
continue
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
DebugLogResponseReceived(ctx, "Gemini", resp.StatusCode)
|
| 145 |
+
DebugLogResponseHeaders(ctx, "Gemini", resp.Header)
|
| 146 |
+
|
| 147 |
+
// 总是输出重要的响应头信息
|
| 148 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 149 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 150 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 151 |
+
log.Printf("[Gemini] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 152 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 153 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 154 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
if resp.StatusCode >= 400 {
|
| 158 |
+
// 读取错误响应内容用于日志
|
| 159 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 160 |
+
resp.Body.Close()
|
| 161 |
+
DebugLogErrorResponse(ctx, "Gemini", resp.StatusCode, string(errBody))
|
| 162 |
+
|
| 163 |
+
// 400和500错误直接返回,不进行账号错误计数
|
| 164 |
+
if resp.StatusCode == 400 || resp.StatusCode == 500 {
|
| 165 |
+
DebugLogRequestEnd(ctx, "Gemini", false, fmt.Errorf("API error: %d", resp.StatusCode))
|
| 166 |
+
return nil, fmt.Errorf("API error: %d - %s", resp.StatusCode, string(errBody))
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// 429 错误特殊处理
|
| 170 |
+
if resp.StatusCode == 429 {
|
| 171 |
+
log.Printf("[Gemini] 429限流错误,尝试使用代理重试")
|
| 172 |
+
|
| 173 |
+
// 尝试使用代理池重试
|
| 174 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, modelName, body, true)
|
| 175 |
+
if proxyErr == nil && proxyResp != nil {
|
| 176 |
+
// 代理重试成功
|
| 177 |
+
return proxyResp, nil
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
log.Printf("[Gemini] 代理重试失败: %v", proxyErr)
|
| 181 |
+
MarkAccountRateLimitedWithResponse(account, resp)
|
| 182 |
+
} else {
|
| 183 |
+
MarkAccountError(account)
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 187 |
+
DebugLogRetry(ctx, "Gemini", i+1, account.ID, lastErr)
|
| 188 |
+
continue
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
ResetAccountError(account)
|
| 192 |
+
zenModel, exists := model.GetZenModel(modelName)
|
| 193 |
+
if !exists {
|
| 194 |
+
// 模型不存在,使用默认倍率
|
| 195 |
+
UseCredit(account, 1.0)
|
| 196 |
+
} else {
|
| 197 |
+
// 流式响应,暂时使用模型倍率(因为没有完整响应头)
|
| 198 |
+
UseCredit(account, zenModel.Multiplier)
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
DebugLogRequestEnd(ctx, "Gemini", true, nil)
|
| 202 |
+
return resp, nil
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
DebugLogRequestEnd(ctx, "Gemini", false, lastErr)
|
| 206 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
func (s *GeminiService) doRequest(ctx context.Context, account *model.Account, modelName string, body []byte, stream bool) (*http.Response, error) {
|
| 210 |
+
zenModel, exists := model.GetZenModel(modelName)
|
| 211 |
+
if !exists {
|
| 212 |
+
return nil, ErrNoAvailableAccount
|
| 213 |
+
}
|
| 214 |
+
httpClient := provider.NewHTTPClient(account.Proxy, 0)
|
| 215 |
+
|
| 216 |
+
action := "generateContent"
|
| 217 |
+
queryParam := ""
|
| 218 |
+
if stream {
|
| 219 |
+
action = "streamGenerateContent"
|
| 220 |
+
queryParam = "?alt=sse"
|
| 221 |
+
}
|
| 222 |
+
reqURL := fmt.Sprintf("%s/v1beta/models/%s:%s%s", GeminiBaseURL, modelName, action, queryParam)
|
| 223 |
+
DebugLogRequestSent(ctx, "Gemini", reqURL)
|
| 224 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(body))
|
| 225 |
+
if err != nil {
|
| 226 |
+
return nil, err
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
// 设置Zencoder自定义请求头
|
| 230 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 231 |
+
|
| 232 |
+
// 流式请求禁用压缩,确保可以逐行读取
|
| 233 |
+
if stream {
|
| 234 |
+
httpReq.Header.Set("Accept-Encoding", "identity")
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
// 添加模型配置的额外请求头
|
| 238 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 239 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 240 |
+
httpReq.Header.Set(k, v)
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
// 记录请求头用于调试
|
| 245 |
+
DebugLogRequestHeaders(ctx, "Gemini", httpReq.Header)
|
| 246 |
+
|
| 247 |
+
return httpClient.Do(httpReq)
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
// GenerateContentProxy 代理generateContent请求
|
| 251 |
+
func (s *GeminiService) GenerateContentProxy(ctx context.Context, w http.ResponseWriter, modelName string, body []byte) error {
|
| 252 |
+
resp, err := s.GenerateContent(ctx, modelName, body)
|
| 253 |
+
if err != nil {
|
| 254 |
+
return err
|
| 255 |
+
}
|
| 256 |
+
defer resp.Body.Close()
|
| 257 |
+
|
| 258 |
+
return StreamResponse(w, resp)
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
// retryWithProxy 使用代理池重试Gemini请求
|
| 262 |
+
func (s *GeminiService) retryWithProxy(ctx context.Context, account *model.Account, modelName string, body []byte, stream bool) (*http.Response, error) {
|
| 263 |
+
// 获取模型配置
|
| 264 |
+
zenModel, exists := model.GetZenModel(modelName)
|
| 265 |
+
if !exists {
|
| 266 |
+
return nil, fmt.Errorf("模型配置不存在: %s", modelName)
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
proxyPool := provider.GetProxyPool()
|
| 270 |
+
if !proxyPool.HasProxies() {
|
| 271 |
+
return nil, fmt.Errorf("没有可用的代理")
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
maxRetries := 3
|
| 275 |
+
for i := 0; i < maxRetries; i++ {
|
| 276 |
+
// 获取随机代理
|
| 277 |
+
proxyURL := proxyPool.GetRandomProxy()
|
| 278 |
+
if proxyURL == "" {
|
| 279 |
+
continue
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
log.Printf("[Gemini] 尝试代理 %s (重试 %d/%d)", proxyURL, i+1, maxRetries)
|
| 283 |
+
|
| 284 |
+
// 创建使用代理的HTTP客户端
|
| 285 |
+
proxyClient, err := provider.NewHTTPClientWithProxy(proxyURL, 0)
|
| 286 |
+
if err != nil {
|
| 287 |
+
log.Printf("[Gemini] 创建代理客户端失败: %v", err)
|
| 288 |
+
continue
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
// 创建新请求
|
| 292 |
+
action := "generateContent"
|
| 293 |
+
queryParam := ""
|
| 294 |
+
if stream {
|
| 295 |
+
action = "streamGenerateContent"
|
| 296 |
+
queryParam = "?alt=sse"
|
| 297 |
+
}
|
| 298 |
+
reqURL := fmt.Sprintf("%s/v1beta/models/%s:%s%s", GeminiBaseURL, modelName, action, queryParam)
|
| 299 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(body))
|
| 300 |
+
if err != nil {
|
| 301 |
+
log.Printf("[Gemini] 创建请求失败: %v", err)
|
| 302 |
+
continue
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
// 设置请求头
|
| 306 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 307 |
+
|
| 308 |
+
// 流式请求禁用压缩,确保��以逐行读取
|
| 309 |
+
if stream {
|
| 310 |
+
httpReq.Header.Set("Accept-Encoding", "identity")
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
// 添加模型配置的额外请求头
|
| 314 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 315 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 316 |
+
httpReq.Header.Set(k, v)
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
// 执行请求
|
| 321 |
+
resp, err := proxyClient.Do(httpReq)
|
| 322 |
+
if err != nil {
|
| 323 |
+
log.Printf("[Gemini] 代理请求失败: %v", err)
|
| 324 |
+
continue
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
// 检查响应状态
|
| 328 |
+
if resp.StatusCode == 429 {
|
| 329 |
+
// 仍然是429,尝试下一个代理
|
| 330 |
+
resp.Body.Close()
|
| 331 |
+
log.Printf("[Gemini] 代理 %s 仍返回429,尝试下一个", proxyURL)
|
| 332 |
+
continue
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
if resp.StatusCode >= 400 {
|
| 336 |
+
// 其他错误,记录并尝试下一个代理
|
| 337 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 338 |
+
resp.Body.Close()
|
| 339 |
+
log.Printf("[Gemini] 代理 %s 返回错误 %d: %s", proxyURL, resp.StatusCode, string(errBody))
|
| 340 |
+
continue
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
// 成功
|
| 344 |
+
log.Printf("[Gemini] 代理 %s 请求成功", proxyURL)
|
| 345 |
+
return resp, nil
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
return nil, fmt.Errorf("所有代理重试均失败")
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
// StreamGenerateContentProxy 代理streamGenerateContent请求
|
| 352 |
+
func (s *GeminiService) StreamGenerateContentProxy(ctx context.Context, w http.ResponseWriter, modelName string, body []byte) error {
|
| 353 |
+
resp, err := s.StreamGenerateContent(ctx, modelName, body)
|
| 354 |
+
if err != nil {
|
| 355 |
+
return err
|
| 356 |
+
}
|
| 357 |
+
defer resp.Body.Close()
|
| 358 |
+
|
| 359 |
+
return StreamResponse(w, resp)
|
| 360 |
+
}
|
internal/service/grok.go
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"context"
|
| 6 |
+
"encoding/json"
|
| 7 |
+
"fmt"
|
| 8 |
+
"io"
|
| 9 |
+
"log"
|
| 10 |
+
"net/http"
|
| 11 |
+
"strings"
|
| 12 |
+
|
| 13 |
+
"zencoder-2api/internal/model"
|
| 14 |
+
"zencoder-2api/internal/service/provider"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
const GrokBaseURL = "https://api.zencoder.ai/xai"
|
| 18 |
+
|
| 19 |
+
type GrokService struct{}
|
| 20 |
+
|
| 21 |
+
func NewGrokService() *GrokService {
|
| 22 |
+
return &GrokService{}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// ChatCompletions 处理/v1/chat/completions请求
|
| 26 |
+
func (s *GrokService) ChatCompletions(ctx context.Context, body []byte) (*http.Response, error) {
|
| 27 |
+
var req struct {
|
| 28 |
+
Model string `json:"model"`
|
| 29 |
+
}
|
| 30 |
+
if err := json.Unmarshal(body, &req); err != nil {
|
| 31 |
+
return nil, fmt.Errorf("invalid request body: %w", err)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// 检查模型是否存在于模型字典中
|
| 35 |
+
_, exists := model.GetZenModel(req.Model)
|
| 36 |
+
if !exists {
|
| 37 |
+
DebugLog(ctx, "[Grok] 模型不存在: %s", req.Model)
|
| 38 |
+
return nil, ErrNoAvailableAccount
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
DebugLogRequest(ctx, "Grok", "/v1/chat/completions", req.Model)
|
| 42 |
+
|
| 43 |
+
var lastErr error
|
| 44 |
+
for i := 0; i < MaxRetries; i++ {
|
| 45 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 46 |
+
if err != nil {
|
| 47 |
+
DebugLogRequestEnd(ctx, "Grok", false, err)
|
| 48 |
+
return nil, err
|
| 49 |
+
}
|
| 50 |
+
DebugLogAccountSelected(ctx, "Grok", account.ID, account.Email)
|
| 51 |
+
|
| 52 |
+
resp, err := s.doRequest(ctx, account, req.Model, body)
|
| 53 |
+
if err != nil {
|
| 54 |
+
MarkAccountError(account)
|
| 55 |
+
lastErr = err
|
| 56 |
+
DebugLogRetry(ctx, "Grok", i+1, account.ID, err)
|
| 57 |
+
continue
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
DebugLogResponseReceived(ctx, "Grok", resp.StatusCode)
|
| 61 |
+
DebugLogResponseHeaders(ctx, "Grok", resp.Header)
|
| 62 |
+
|
| 63 |
+
// 总是输出重要的响应头信息
|
| 64 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 65 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 66 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 67 |
+
log.Printf("[Grok] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 68 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 69 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 70 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
if resp.StatusCode >= 400 {
|
| 74 |
+
// 读取错误响应内容
|
| 75 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 76 |
+
resp.Body.Close()
|
| 77 |
+
|
| 78 |
+
// 429 错误特殊处理 - 直接返回,不重试
|
| 79 |
+
if resp.StatusCode == 429 {
|
| 80 |
+
log.Printf("[Grok] 429限流错误,尝试使用代理重试")
|
| 81 |
+
|
| 82 |
+
// 尝试使用代理池重试
|
| 83 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, req.Model, body)
|
| 84 |
+
if proxyErr == nil && proxyResp != nil {
|
| 85 |
+
// 代理重试成功
|
| 86 |
+
return proxyResp, nil
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
log.Printf("[Grok] 代理重试失败: %v", proxyErr)
|
| 90 |
+
// 在DEBUG模式下记录详细信息
|
| 91 |
+
DebugLogErrorResponse(ctx, "Grok", resp.StatusCode, string(errBody))
|
| 92 |
+
// 将账号放入短期冷却(5秒)
|
| 93 |
+
MarkAccountRateLimitedShort(account)
|
| 94 |
+
// 标记错误并结束请求
|
| 95 |
+
DebugLogRequestEnd(ctx, "Grok", false, ErrNoAvailableAccount)
|
| 96 |
+
// 返回通用错误
|
| 97 |
+
return nil, ErrNoAvailableAccount
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
DebugLogErrorResponse(ctx, "Grok", resp.StatusCode, string(errBody))
|
| 101 |
+
|
| 102 |
+
// 400和500错误直接返回,不进行账号错误计数
|
| 103 |
+
if resp.StatusCode == 400 || resp.StatusCode == 500 {
|
| 104 |
+
DebugLogRequestEnd(ctx, "Grok", false, fmt.Errorf("API error: %d", resp.StatusCode))
|
| 105 |
+
return nil, fmt.Errorf("API error: %d - %s", resp.StatusCode, string(errBody))
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
MarkAccountError(account)
|
| 109 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 110 |
+
DebugLogRetry(ctx, "Grok", i+1, account.ID, lastErr)
|
| 111 |
+
continue
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
ResetAccountError(account)
|
| 115 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 116 |
+
if !exists {
|
| 117 |
+
// 模型不存在,使用默认倍率
|
| 118 |
+
UpdateAccountCreditsFromResponse(account, resp, 1.0)
|
| 119 |
+
} else {
|
| 120 |
+
// 使用统一的积分更新函数,自动处理响应头中的积分信息
|
| 121 |
+
UpdateAccountCreditsFromResponse(account, resp, zenModel.Multiplier)
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
DebugLogRequestEnd(ctx, "Grok", true, nil)
|
| 125 |
+
return resp, nil
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
DebugLogRequestEnd(ctx, "Grok", false, lastErr)
|
| 129 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
func (s *GrokService) doRequest(ctx context.Context, account *model.Account, modelID string, body []byte) (*http.Response, error) {
|
| 133 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 134 |
+
if !exists {
|
| 135 |
+
return nil, ErrNoAvailableAccount
|
| 136 |
+
}
|
| 137 |
+
httpClient := provider.NewHTTPClient(account.Proxy, 0)
|
| 138 |
+
|
| 139 |
+
// 处理请求体,Grok Code 模型要求 temperature=0
|
| 140 |
+
modifiedBody := body
|
| 141 |
+
if strings.Contains(modelID, "grok-code") {
|
| 142 |
+
modifiedBody, _ = s.setTemperatureZero(body)
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
reqURL := GrokBaseURL + "/v1/chat/completions"
|
| 146 |
+
DebugLogRequestSent(ctx, "Grok", reqURL)
|
| 147 |
+
|
| 148 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(modifiedBody))
|
| 149 |
+
if err != nil {
|
| 150 |
+
return nil, err
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// 设置Zencoder自定义请求头
|
| 154 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 155 |
+
|
| 156 |
+
// 添加模型配置的额外请求头
|
| 157 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 158 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 159 |
+
httpReq.Header.Set(k, v)
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// 记录请求头用于调试
|
| 164 |
+
DebugLogRequestHeaders(ctx, "Grok", httpReq.Header)
|
| 165 |
+
|
| 166 |
+
return httpClient.Do(httpReq)
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// setTemperatureZero 设置 temperature=0
|
| 170 |
+
func (s *GrokService) setTemperatureZero(body []byte) ([]byte, error) {
|
| 171 |
+
var reqMap map[string]interface{}
|
| 172 |
+
if err := json.Unmarshal(body, &reqMap); err != nil {
|
| 173 |
+
return body, err
|
| 174 |
+
}
|
| 175 |
+
reqMap["temperature"] = 0
|
| 176 |
+
return json.Marshal(reqMap)
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
// ChatCompletionsProxy 代理chat completions请求
|
| 180 |
+
func (s *GrokService) ChatCompletionsProxy(ctx context.Context, w http.ResponseWriter, body []byte) error {
|
| 181 |
+
resp, err := s.ChatCompletions(ctx, body)
|
| 182 |
+
if err != nil {
|
| 183 |
+
return err
|
| 184 |
+
}
|
| 185 |
+
defer resp.Body.Close()
|
| 186 |
+
|
| 187 |
+
return StreamResponse(w, resp)
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
// retryWithProxy 使用代理池重试Grok请求
|
| 191 |
+
func (s *GrokService) retryWithProxy(ctx context.Context, account *model.Account, modelID string, body []byte) (*http.Response, error) {
|
| 192 |
+
// 获取模型配置
|
| 193 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 194 |
+
if !exists {
|
| 195 |
+
return nil, fmt.Errorf("模型配置不存在: %s", modelID)
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
proxyPool := provider.GetProxyPool()
|
| 199 |
+
if !proxyPool.HasProxies() {
|
| 200 |
+
return nil, fmt.Errorf("没有可用的代理")
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
maxRetries := 3
|
| 204 |
+
for i := 0; i < maxRetries; i++ {
|
| 205 |
+
// 获取随机代理
|
| 206 |
+
proxyURL := proxyPool.GetRandomProxy()
|
| 207 |
+
if proxyURL == "" {
|
| 208 |
+
continue
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
log.Printf("[Grok] 尝试代理 %s (重试 %d/%d)", proxyURL, i+1, maxRetries)
|
| 212 |
+
|
| 213 |
+
// 创建使用代理的HTTP客户端
|
| 214 |
+
proxyClient, err := provider.NewHTTPClientWithProxy(proxyURL, 0)
|
| 215 |
+
if err != nil {
|
| 216 |
+
log.Printf("[Grok] 创建代理客户端失败: %v", err)
|
| 217 |
+
continue
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
// 处理请求体,Grok Code 模型要求 temperature=0
|
| 221 |
+
modifiedBody := body
|
| 222 |
+
if strings.Contains(modelID, "grok-code") {
|
| 223 |
+
modifiedBody, _ = s.setTemperatureZero(body)
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
// 创建新请求
|
| 227 |
+
reqURL := GrokBaseURL + "/v1/chat/completions"
|
| 228 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(modifiedBody))
|
| 229 |
+
if err != nil {
|
| 230 |
+
log.Printf("[Grok] 创建请求失败: %v", err)
|
| 231 |
+
continue
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// 设置请求头
|
| 235 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 236 |
+
|
| 237 |
+
// 添加模型配置的额外请求头
|
| 238 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 239 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 240 |
+
httpReq.Header.Set(k, v)
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
// 执行请求
|
| 245 |
+
resp, err := proxyClient.Do(httpReq)
|
| 246 |
+
if err != nil {
|
| 247 |
+
log.Printf("[Grok] 代理请求失败: %v", err)
|
| 248 |
+
continue
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// 检查响应状态
|
| 252 |
+
if resp.StatusCode == 429 {
|
| 253 |
+
// 仍然是429,尝试下一个代理
|
| 254 |
+
resp.Body.Close()
|
| 255 |
+
log.Printf("[Grok] 代理 %s 仍返回429,尝试下一个", proxyURL)
|
| 256 |
+
continue
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
if resp.StatusCode >= 400 {
|
| 260 |
+
// 其他错误,记录并尝试下一个代理
|
| 261 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 262 |
+
resp.Body.Close()
|
| 263 |
+
log.Printf("[Grok] 代理 %s 返回错误 %d: %s", proxyURL, resp.StatusCode, string(errBody))
|
| 264 |
+
continue
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
// 成功
|
| 268 |
+
log.Printf("[Grok] 代理 %s 请求成功", proxyURL)
|
| 269 |
+
return resp, nil
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
return nil, fmt.Errorf("所有代理重试均失败")
|
| 273 |
+
}
|
internal/service/headers.go
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"math/rand"
|
| 5 |
+
"net/http"
|
| 6 |
+
"time"
|
| 7 |
+
|
| 8 |
+
"zencoder-2api/internal/model"
|
| 9 |
+
|
| 10 |
+
"github.com/google/uuid"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
var (
|
| 14 |
+
// 可变的 User-Agent 列表
|
| 15 |
+
userAgents = []string{
|
| 16 |
+
"zen-cli/0.9.0-SNAPSHOT_4c6ffdd-windows-x64",
|
| 17 |
+
"zen-cli/0.9.0-SNAPSHOT_5d7ggee-windows-x64",
|
| 18 |
+
"zen-cli/0.9.0-SNAPSHOT_6e8hhff-windows-x64",
|
| 19 |
+
"zen-cli/0.8.9-SNAPSHOT_3b5eedd-windows-x64",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
// 可变的 Node 版本
|
| 23 |
+
nodeVersions = []string{
|
| 24 |
+
"v24.3.0",
|
| 25 |
+
"v24.2.0",
|
| 26 |
+
"v24.1.0",
|
| 27 |
+
"v23.5.0",
|
| 28 |
+
"v22.11.0",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// 可变的 zencoder 版本
|
| 32 |
+
zencoderVersions = []string{
|
| 33 |
+
"3.24.0",
|
| 34 |
+
"3.23.9",
|
| 35 |
+
"3.23.8",
|
| 36 |
+
"3.24.1",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// 可变的 package 版本
|
| 40 |
+
packageVersions = []string{
|
| 41 |
+
"6.9.1",
|
| 42 |
+
"6.9.0",
|
| 43 |
+
"6.8.9",
|
| 44 |
+
"6.8.8",
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
rng = rand.New(rand.NewSource(time.Now().UnixNano()))
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
// 随机选择一个元素
|
| 51 |
+
func randomChoice(items []string) string {
|
| 52 |
+
return items[rng.Intn(len(items))]
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
// SetZencoderHeaders 设置Zencoder自定义请求头
|
| 56 |
+
func SetZencoderHeaders(req *http.Request, account *model.Account, zenModel model.ZenModel) {
|
| 57 |
+
// 基础请求头 - 使用随机 User-Agent
|
| 58 |
+
req.Header.Set("User-Agent", "zen-cli/0.9.0-SNAPSHOT_4c6ffdd-windows-x64")
|
| 59 |
+
req.Header.Set("Accept", "application/json")
|
| 60 |
+
req.Header.Set("Content-Type", "application/json")
|
| 61 |
+
req.Header.Set("Connection", "keep-alive")
|
| 62 |
+
|
| 63 |
+
// 认证头
|
| 64 |
+
req.Header.Set("Authorization", "Bearer "+account.AccessToken)
|
| 65 |
+
|
| 66 |
+
// x-stainless 系列
|
| 67 |
+
req.Header.Set("x-stainless-arch", "x64")
|
| 68 |
+
req.Header.Set("x-stainless-lang", "js")
|
| 69 |
+
req.Header.Set("x-stainless-os", "Windows")
|
| 70 |
+
req.Header.Set("x-stainless-package-version", "0.70.1")
|
| 71 |
+
req.Header.Set("x-stainless-retry-count", "0")
|
| 72 |
+
req.Header.Set("x-stainless-runtime", "node")
|
| 73 |
+
req.Header.Set("x-stainless-runtime-version", "v24.3.0")
|
| 74 |
+
|
| 75 |
+
// zen/zencoder 系列 - 使用随机版本和唯一 ID
|
| 76 |
+
req.Header.Set("zen-model-id", zenModel.ID)
|
| 77 |
+
req.Header.Set("zencoder-arch", "x64")
|
| 78 |
+
req.Header.Set("zencoder-auto-model", "false")
|
| 79 |
+
req.Header.Set("zencoder-client-type", "vscode")
|
| 80 |
+
req.Header.Set("zencoder-operation-id", uuid.New().String())
|
| 81 |
+
req.Header.Set("zencoder-operation-type", "agent_call")
|
| 82 |
+
req.Header.Set("zencoder-os", "windows")
|
| 83 |
+
req.Header.Set("zencoder-version", "3.24.0")
|
| 84 |
+
}
|
internal/service/jwt.go
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/base64"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"errors"
|
| 7 |
+
"strings"
|
| 8 |
+
"time"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
type CustomClaims struct {
|
| 12 |
+
Plan string `json:"plan"`
|
| 13 |
+
Autobots struct {
|
| 14 |
+
SubscriptionStartDate string `json:"subscription_start_date"`
|
| 15 |
+
} `json:"autobots"`
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
type JWTPayload struct {
|
| 19 |
+
Subject string `json:"sub"`
|
| 20 |
+
ClientID string `json:"client_id"`
|
| 21 |
+
Email string `json:"email"`
|
| 22 |
+
CustomClaims CustomClaims `json:"customClaims"`
|
| 23 |
+
IssuedAt int64 `json:"iat"`
|
| 24 |
+
Expiration int64 `json:"exp"`
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
func ParseJWT(tokenString string) (*JWTPayload, error) {
|
| 28 |
+
parts := strings.Split(tokenString, ".")
|
| 29 |
+
if len(parts) != 3 {
|
| 30 |
+
return nil, errors.New("invalid token format")
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
payloadPart := parts[1]
|
| 34 |
+
|
| 35 |
+
// Add padding if missing
|
| 36 |
+
if l := len(payloadPart) % 4; l > 0 {
|
| 37 |
+
payloadPart += strings.Repeat("=", 4-l)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
decoded, err := base64.URLEncoding.DecodeString(payloadPart)
|
| 41 |
+
if err != nil {
|
| 42 |
+
// Try standard encoding if URL encoding fails
|
| 43 |
+
decoded, err = base64.StdEncoding.DecodeString(payloadPart)
|
| 44 |
+
if err != nil {
|
| 45 |
+
return nil, err
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
var payload JWTPayload
|
| 50 |
+
if err := json.Unmarshal(decoded, &payload); err != nil {
|
| 51 |
+
return nil, err
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
return &payload, nil
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
func GetSubscriptionDate(payload *JWTPayload) time.Time {
|
| 58 |
+
// Try to parse SubscriptionStartDate from CustomClaims
|
| 59 |
+
if v := payload.CustomClaims.Autobots.SubscriptionStartDate; v != "" {
|
| 60 |
+
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
| 61 |
+
return t
|
| 62 |
+
}
|
| 63 |
+
if t, err := time.Parse("2006-01-02", v); err == nil {
|
| 64 |
+
return t
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// Fallback to IssuedAt
|
| 69 |
+
if payload.IssuedAt > 0 {
|
| 70 |
+
return time.Unix(payload.IssuedAt, 0)
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
return time.Now()
|
| 74 |
+
}
|
internal/service/openai.go
ADDED
|
@@ -0,0 +1,868 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"bytes"
|
| 6 |
+
"context"
|
| 7 |
+
"encoding/json"
|
| 8 |
+
"fmt"
|
| 9 |
+
"io"
|
| 10 |
+
"log"
|
| 11 |
+
"net/http"
|
| 12 |
+
"strings"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"zencoder-2api/internal/model"
|
| 16 |
+
"zencoder-2api/internal/service/provider"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
const OpenAIBaseURL = "https://api.zencoder.ai/openai"
|
| 20 |
+
|
| 21 |
+
type OpenAIService struct{}
|
| 22 |
+
|
| 23 |
+
func NewOpenAIService() *OpenAIService {
|
| 24 |
+
return &OpenAIService{}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// ChatCompletions 处理/v1/chat/completions请求
|
| 28 |
+
func (s *OpenAIService) ChatCompletions(ctx context.Context, body []byte) (*http.Response, error) {
|
| 29 |
+
var req struct {
|
| 30 |
+
Model string `json:"model"`
|
| 31 |
+
}
|
| 32 |
+
if err := json.Unmarshal(body, &req); err != nil {
|
| 33 |
+
return nil, fmt.Errorf("invalid request body: %w", err)
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// 检查模型是否存在于模型字典中
|
| 37 |
+
_, exists := model.GetZenModel(req.Model)
|
| 38 |
+
if !exists {
|
| 39 |
+
DebugLog(ctx, "[OpenAI] 模型不存在: %s", req.Model)
|
| 40 |
+
return nil, ErrNoAvailableAccount
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
DebugLogRequest(ctx, "OpenAI", "/v1/chat/completions", req.Model)
|
| 44 |
+
|
| 45 |
+
var lastErr error
|
| 46 |
+
for i := 0; i < MaxRetries; i++ {
|
| 47 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 48 |
+
if err != nil {
|
| 49 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, err)
|
| 50 |
+
return nil, err
|
| 51 |
+
}
|
| 52 |
+
DebugLogAccountSelected(ctx, "OpenAI", account.ID, account.Email)
|
| 53 |
+
|
| 54 |
+
// Zencoder API使用/v1/responses端点
|
| 55 |
+
// 需要转换请求体:messages -> input
|
| 56 |
+
convertedBody, err := s.convertChatToResponsesBody(body)
|
| 57 |
+
if err != nil {
|
| 58 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, err)
|
| 59 |
+
return nil, fmt.Errorf("failed to convert request body: %w", err)
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
resp, err := s.doRequest(ctx, account, req.Model, "/v1/responses", convertedBody)
|
| 63 |
+
if err != nil {
|
| 64 |
+
MarkAccountError(account)
|
| 65 |
+
lastErr = err
|
| 66 |
+
DebugLogRetry(ctx, "OpenAI", i+1, account.ID, err)
|
| 67 |
+
continue
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
DebugLogResponseReceived(ctx, "OpenAI", resp.StatusCode)
|
| 71 |
+
DebugLogResponseHeaders(ctx, "OpenAI", resp.Header)
|
| 72 |
+
|
| 73 |
+
// 总是输出重要的响应头信息
|
| 74 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 75 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 76 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 77 |
+
log.Printf("[OpenAI] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 78 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 79 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 80 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
if resp.StatusCode >= 400 {
|
| 84 |
+
// 读取错误响应内容
|
| 85 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 86 |
+
resp.Body.Close()
|
| 87 |
+
DebugLogErrorResponse(ctx, "OpenAI", resp.StatusCode, string(errBody))
|
| 88 |
+
|
| 89 |
+
// 400和500错误直接返回,不进行账号错误计数
|
| 90 |
+
if resp.StatusCode == 400 || resp.StatusCode == 500 {
|
| 91 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, fmt.Errorf("API error: %d", resp.StatusCode))
|
| 92 |
+
return nil, fmt.Errorf("API error: %d - %s", resp.StatusCode, string(errBody))
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// 429 错误特殊处理
|
| 96 |
+
if resp.StatusCode == 429 {
|
| 97 |
+
log.Printf("[OpenAI] 429限流错误,尝试使用代理重试")
|
| 98 |
+
|
| 99 |
+
// 尝试使用代理池重试
|
| 100 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, req.Model, "/v1/responses", convertedBody)
|
| 101 |
+
if proxyErr == nil && proxyResp != nil {
|
| 102 |
+
// 代理重试成功
|
| 103 |
+
return proxyResp, nil
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
log.Printf("[OpenAI] 代理重试失败: %v", proxyErr)
|
| 107 |
+
MarkAccountRateLimitedWithResponse(account, resp)
|
| 108 |
+
} else {
|
| 109 |
+
MarkAccountError(account)
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 113 |
+
DebugLogRetry(ctx, "OpenAI", i+1, account.ID, lastErr)
|
| 114 |
+
continue
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
ResetAccountError(account)
|
| 118 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 119 |
+
if !exists {
|
| 120 |
+
// 模型不存在,使用默认倍率
|
| 121 |
+
UpdateAccountCreditsFromResponse(account, resp, 1.0)
|
| 122 |
+
} else {
|
| 123 |
+
// 使用统一的积分更新函数,自动处理响应头中的积分信息
|
| 124 |
+
UpdateAccountCreditsFromResponse(account, resp, zenModel.Multiplier)
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
DebugLogRequestEnd(ctx, "OpenAI", true, nil)
|
| 128 |
+
return resp, nil
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, lastErr)
|
| 132 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// Responses 处理/v1/responses请求
|
| 136 |
+
func (s *OpenAIService) Responses(ctx context.Context, body []byte) (*http.Response, error) {
|
| 137 |
+
var req struct {
|
| 138 |
+
Model string `json:"model"`
|
| 139 |
+
}
|
| 140 |
+
if err := json.Unmarshal(body, &req); err != nil {
|
| 141 |
+
return nil, fmt.Errorf("invalid request body: %w", err)
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// 检查模型是否存在于模型字典中
|
| 145 |
+
_, exists := model.GetZenModel(req.Model)
|
| 146 |
+
if !exists {
|
| 147 |
+
DebugLog(ctx, "[OpenAI] 模型不存在: %s", req.Model)
|
| 148 |
+
return nil, ErrNoAvailableAccount
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
DebugLogRequest(ctx, "OpenAI", "/v1/responses", req.Model)
|
| 152 |
+
|
| 153 |
+
var lastErr error
|
| 154 |
+
for i := 0; i < MaxRetries; i++ {
|
| 155 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 156 |
+
if err != nil {
|
| 157 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, err)
|
| 158 |
+
return nil, err
|
| 159 |
+
}
|
| 160 |
+
DebugLogAccountSelected(ctx, "OpenAI", account.ID, account.Email)
|
| 161 |
+
|
| 162 |
+
resp, err := s.doRequest(ctx, account, req.Model, "/v1/responses", body)
|
| 163 |
+
if err != nil {
|
| 164 |
+
MarkAccountError(account)
|
| 165 |
+
lastErr = err
|
| 166 |
+
DebugLogRetry(ctx, "OpenAI", i+1, account.ID, err)
|
| 167 |
+
continue
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
DebugLogResponseReceived(ctx, "OpenAI", resp.StatusCode)
|
| 171 |
+
DebugLogResponseHeaders(ctx, "OpenAI", resp.Header)
|
| 172 |
+
|
| 173 |
+
// 总是输出重要的响应头信息
|
| 174 |
+
if resp.Header.Get("Zen-Pricing-Period-Limit") != "" ||
|
| 175 |
+
resp.Header.Get("Zen-Pricing-Period-Cost") != "" ||
|
| 176 |
+
resp.Header.Get("Zen-Request-Cost") != "" {
|
| 177 |
+
log.Printf("[OpenAI] 积分信息 - 周期限额: %s, 周期消耗: %s, 本次消耗: %s",
|
| 178 |
+
resp.Header.Get("Zen-Pricing-Period-Limit"),
|
| 179 |
+
resp.Header.Get("Zen-Pricing-Period-Cost"),
|
| 180 |
+
resp.Header.Get("Zen-Request-Cost"))
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
if resp.StatusCode >= 400 {
|
| 184 |
+
// 读取错误响应内容
|
| 185 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 186 |
+
resp.Body.Close()
|
| 187 |
+
|
| 188 |
+
// 429 错误特殊处理 - 直接返回,不重试
|
| 189 |
+
if resp.StatusCode == 429 {
|
| 190 |
+
log.Printf("[OpenAI] 429限流错误,尝试使用代理重试")
|
| 191 |
+
|
| 192 |
+
// 尝试使用代理池重试
|
| 193 |
+
proxyResp, proxyErr := s.retryWithProxy(ctx, account, req.Model, "/v1/responses", body)
|
| 194 |
+
if proxyErr == nil && proxyResp != nil {
|
| 195 |
+
// 代理重试成功
|
| 196 |
+
return proxyResp, nil
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
log.Printf("[OpenAI] 代理重试失败: %v", proxyErr)
|
| 200 |
+
// 将账号放入短期冷却(5秒)
|
| 201 |
+
MarkAccountRateLimitedShort(account)
|
| 202 |
+
// 不输出错误日志,直接返回
|
| 203 |
+
return nil, ErrNoAvailableAccount
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
DebugLogErrorResponse(ctx, "OpenAI", resp.StatusCode, string(errBody))
|
| 207 |
+
|
| 208 |
+
// 400和500错误直接返回,不进行账号错误计数
|
| 209 |
+
if resp.StatusCode == 400 || resp.StatusCode == 500 {
|
| 210 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, fmt.Errorf("API error: %d", resp.StatusCode))
|
| 211 |
+
return nil, fmt.Errorf("API error: %d - %s", resp.StatusCode, string(errBody))
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
MarkAccountError(account)
|
| 215 |
+
lastErr = fmt.Errorf("API error: %d", resp.StatusCode)
|
| 216 |
+
DebugLogRetry(ctx, "OpenAI", i+1, account.ID, lastErr)
|
| 217 |
+
continue
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
ResetAccountError(account)
|
| 221 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 222 |
+
if !exists {
|
| 223 |
+
// 模型不存在,使用默认倍率
|
| 224 |
+
UpdateAccountCreditsFromResponse(account, resp, 1.0)
|
| 225 |
+
} else {
|
| 226 |
+
// 使用统一的积分更新函数,自动处理响应头中的积分信息
|
| 227 |
+
UpdateAccountCreditsFromResponse(account, resp, zenModel.Multiplier)
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
DebugLogRequestEnd(ctx, "OpenAI", true, nil)
|
| 231 |
+
return resp, nil
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
DebugLogRequestEnd(ctx, "OpenAI", false, lastErr)
|
| 235 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
// convertChatToResponsesBody 将 Chat Completion 的请求体转换为 Responses API 的请求体
|
| 239 |
+
func (s *OpenAIService) convertChatToResponsesBody(body []byte) ([]byte, error) {
|
| 240 |
+
var raw map[string]interface{}
|
| 241 |
+
if err := json.Unmarshal(body, &raw); err != nil {
|
| 242 |
+
return nil, err
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// 移除 /v1/responses API 不支持的参数
|
| 246 |
+
delete(raw, "stream_options") // 不支持 stream_options.include_usage 等
|
| 247 |
+
delete(raw, "function_call") // 旧版函数调用参数
|
| 248 |
+
delete(raw, "functions") // 旧版函数定义参数
|
| 249 |
+
|
| 250 |
+
// 转换 token 限制参数
|
| 251 |
+
// max_completion_tokens (新) / max_tokens (旧) -> max_output_tokens (Responses API)
|
| 252 |
+
if val, ok := raw["max_completion_tokens"]; ok {
|
| 253 |
+
raw["max_output_tokens"] = val
|
| 254 |
+
delete(raw, "max_completion_tokens")
|
| 255 |
+
} else if val, ok := raw["max_tokens"]; ok {
|
| 256 |
+
raw["max_output_tokens"] = val
|
| 257 |
+
delete(raw, "max_tokens")
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
modelStr, _ := raw["model"].(string)
|
| 261 |
+
|
| 262 |
+
// 检查是否有 messages 字段
|
| 263 |
+
if messages, ok := raw["messages"].([]interface{}); ok {
|
| 264 |
+
if modelStr == "gpt-5-nano-2025-08-07" {
|
| 265 |
+
// gpt-5-nano 特殊处理:转换为复杂的 input 结构
|
| 266 |
+
newInput := make([]map[string]interface{}, 0)
|
| 267 |
+
for _, m := range messages {
|
| 268 |
+
if msgMap, ok := m.(map[string]interface{}); ok {
|
| 269 |
+
role, _ := msgMap["role"].(string)
|
| 270 |
+
content := msgMap["content"]
|
| 271 |
+
|
| 272 |
+
newItem := map[string]interface{}{
|
| 273 |
+
"type": "message",
|
| 274 |
+
"role": role,
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
newContent := make([]map[string]interface{}, 0)
|
| 278 |
+
if contentStr, ok := content.(string); ok {
|
| 279 |
+
newContent = append(newContent, map[string]interface{}{
|
| 280 |
+
"type": "input_text",
|
| 281 |
+
"text": contentStr,
|
| 282 |
+
})
|
| 283 |
+
}
|
| 284 |
+
// 这里的 content 如果是数组,暂时忽略或假设是纯文本场景
|
| 285 |
+
// 如果需要支持多模态,需要进一步解析 content 数组
|
| 286 |
+
|
| 287 |
+
newItem["content"] = newContent
|
| 288 |
+
newInput = append(newInput, newItem)
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
raw["input"] = newInput
|
| 292 |
+
} else {
|
| 293 |
+
// 标准转换:直接移动到 input
|
| 294 |
+
raw["input"] = messages
|
| 295 |
+
}
|
| 296 |
+
delete(raw, "messages")
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
// gpt-5-nano-2025-08-07 特殊处理参数
|
| 300 |
+
if modelStr == "gpt-5-nano-2025-08-07" {
|
| 301 |
+
// 添加该模型所需的特定参数
|
| 302 |
+
raw["prompt_cache_key"] = "generate-name"
|
| 303 |
+
raw["store"] = false
|
| 304 |
+
raw["include"] = []string{"reasoning.encrypted_content"}
|
| 305 |
+
raw["service_tier"] = "auto"
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
return json.Marshal(raw)
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
func (s *OpenAIService) doRequest(ctx context.Context, account *model.Account, modelID, path string, body []byte) (*http.Response, error) {
|
| 312 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 313 |
+
if !exists {
|
| 314 |
+
return nil, ErrNoAvailableAccount
|
| 315 |
+
}
|
| 316 |
+
httpClient := provider.NewHTTPClient(account.Proxy, 0)
|
| 317 |
+
|
| 318 |
+
// 将模型参数合并到请求体中
|
| 319 |
+
modifiedBody := body
|
| 320 |
+
if zenModel.Parameters != nil {
|
| 321 |
+
var raw map[string]interface{}
|
| 322 |
+
if json.Unmarshal(modifiedBody, &raw) == nil {
|
| 323 |
+
// 添加 reasoning 配置
|
| 324 |
+
if zenModel.Parameters.Reasoning != nil && raw["reasoning"] == nil {
|
| 325 |
+
reasoningMap := map[string]interface{}{
|
| 326 |
+
"effort": zenModel.Parameters.Reasoning.Effort,
|
| 327 |
+
}
|
| 328 |
+
if zenModel.Parameters.Reasoning.Summary != "" {
|
| 329 |
+
reasoningMap["summary"] = zenModel.Parameters.Reasoning.Summary
|
| 330 |
+
}
|
| 331 |
+
raw["reasoning"] = reasoningMap
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
// 添加 text 配置
|
| 335 |
+
if zenModel.Parameters.Text != nil && raw["text"] == nil {
|
| 336 |
+
raw["text"] = map[string]interface{}{
|
| 337 |
+
"verbosity": zenModel.Parameters.Text.Verbosity,
|
| 338 |
+
}
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
// 添加 temperature 配置
|
| 342 |
+
if zenModel.Parameters.Temperature != nil && raw["temperature"] == nil {
|
| 343 |
+
raw["temperature"] = *zenModel.Parameters.Temperature
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
modifiedBody, _ = json.Marshal(raw)
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
// gpt-5-nano-2025-08-07 特殊处理参数
|
| 351 |
+
if modelID == "gpt-5-nano-2025-08-07" {
|
| 352 |
+
var raw map[string]interface{}
|
| 353 |
+
if json.Unmarshal(modifiedBody, &raw) == nil {
|
| 354 |
+
// 添加 text 参数
|
| 355 |
+
if _, ok := raw["text"]; !ok {
|
| 356 |
+
raw["text"] = map[string]string{"verbosity": "medium"}
|
| 357 |
+
}
|
| 358 |
+
// 添加 temperature 参数 (如果缺失)
|
| 359 |
+
if _, ok := raw["temperature"]; !ok {
|
| 360 |
+
raw["temperature"] = 1
|
| 361 |
+
}
|
| 362 |
+
// 强制开启 stream,因为该模型似乎不支持非流式
|
| 363 |
+
raw["stream"] = true
|
| 364 |
+
|
| 365 |
+
// 修正 reasoning 参数,添加 summary
|
| 366 |
+
if reasoning, ok := raw["reasoning"].(map[string]interface{}); ok {
|
| 367 |
+
reasoning["summary"] = "auto"
|
| 368 |
+
raw["reasoning"] = reasoning
|
| 369 |
+
} else {
|
| 370 |
+
raw["reasoning"] = map[string]interface{}{
|
| 371 |
+
"effort": "minimal",
|
| 372 |
+
"summary": "auto",
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
modifiedBody, _ = json.Marshal(raw)
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
// 注意:已移除模型重定向逻辑,直接使用用户请求的模型名
|
| 380 |
+
DebugLogActualModel(ctx, "OpenAI", modelID, modelID)
|
| 381 |
+
|
| 382 |
+
reqURL := OpenAIBaseURL + path
|
| 383 |
+
DebugLogRequestSent(ctx, "OpenAI", reqURL)
|
| 384 |
+
|
| 385 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(modifiedBody))
|
| 386 |
+
if err != nil {
|
| 387 |
+
return nil, err
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
// 设置Zencoder自定义请求头
|
| 391 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 392 |
+
|
| 393 |
+
// 添加模型配置的额外请求头
|
| 394 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 395 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 396 |
+
httpReq.Header.Set(k, v)
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// 记录请求头用于调试
|
| 401 |
+
DebugLogRequestHeaders(ctx, "OpenAI", httpReq.Header)
|
| 402 |
+
|
| 403 |
+
// 强制记录请求体用于调试
|
| 404 |
+
log.Printf("[DEBUG] [OpenAI] 请求体:")
|
| 405 |
+
log.Printf("[DEBUG] [OpenAI] %s", string(modifiedBody))
|
| 406 |
+
|
| 407 |
+
return httpClient.Do(httpReq)
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
// ChatCompletionsProxy 代理chat completions请求
|
| 411 |
+
func (s *OpenAIService) ChatCompletionsProxy(ctx context.Context, w http.ResponseWriter, body []byte) error {
|
| 412 |
+
// 解析 model 和 stream 参数
|
| 413 |
+
var req struct {
|
| 414 |
+
Model string `json:"model"`
|
| 415 |
+
Stream bool `json:"stream"`
|
| 416 |
+
}
|
| 417 |
+
// 忽略错误,因为ChatCompletions会再次解析并处理错误
|
| 418 |
+
_ = json.Unmarshal(body, &req)
|
| 419 |
+
|
| 420 |
+
resp, err := s.ChatCompletions(ctx, body)
|
| 421 |
+
if err != nil {
|
| 422 |
+
return err
|
| 423 |
+
}
|
| 424 |
+
defer resp.Body.Close()
|
| 425 |
+
|
| 426 |
+
if req.Stream {
|
| 427 |
+
return s.streamConvertedResponse(w, resp, req.Model)
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
return s.handleNonStreamResponse(w, resp, req.Model)
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
func (s *OpenAIService) handleNonStreamResponse(w http.ResponseWriter, resp *http.Response, modelID string) error {
|
| 434 |
+
// 读取全部响应体
|
| 435 |
+
bodyBytes, err := io.ReadAll(resp.Body)
|
| 436 |
+
if err != nil {
|
| 437 |
+
return err
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
// 复制响应头
|
| 441 |
+
for k, v := range resp.Header {
|
| 442 |
+
// 过滤掉 Content-Length (会重新计算) 和 Content-Encoding (Go会自动解压)
|
| 443 |
+
if k != "Content-Length" && k != "Content-Encoding" {
|
| 444 |
+
for _, vv := range v {
|
| 445 |
+
w.Header().Add(k, vv)
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
w.WriteHeader(resp.StatusCode)
|
| 450 |
+
|
| 451 |
+
// 尝试解析响应
|
| 452 |
+
var raw map[string]interface{}
|
| 453 |
+
if err := json.Unmarshal(bodyBytes, &raw); err != nil {
|
| 454 |
+
// 如果不是 JSON,检查是否是 SSE 流 (可能是因为我们强制开启了 stream)
|
| 455 |
+
bodyStr := string(bodyBytes)
|
| 456 |
+
trimmedBody := strings.TrimSpace(bodyStr)
|
| 457 |
+
contentType := resp.Header.Get("Content-Type")
|
| 458 |
+
isSSE := strings.Contains(contentType, "text/event-stream") ||
|
| 459 |
+
strings.HasPrefix(trimmedBody, "data:") ||
|
| 460 |
+
strings.HasPrefix(trimmedBody, "event:") ||
|
| 461 |
+
strings.HasPrefix(trimmedBody, ":") ||
|
| 462 |
+
modelID == "gpt-5-nano-2025-08-07" // 强制该模型走 SSE 解析
|
| 463 |
+
|
| 464 |
+
if isSSE {
|
| 465 |
+
var fullContent string
|
| 466 |
+
scanner := bufio.NewScanner(bytes.NewReader(bodyBytes))
|
| 467 |
+
for scanner.Scan() {
|
| 468 |
+
line := strings.TrimSpace(scanner.Text())
|
| 469 |
+
if !strings.HasPrefix(line, "data: ") {
|
| 470 |
+
continue
|
| 471 |
+
}
|
| 472 |
+
data := strings.TrimPrefix(line, "data: ")
|
| 473 |
+
if data == "[DONE]" {
|
| 474 |
+
break
|
| 475 |
+
}
|
| 476 |
+
var chunk map[string]interface{}
|
| 477 |
+
if json.Unmarshal([]byte(data), &chunk) == nil {
|
| 478 |
+
// 尝试提取 content
|
| 479 |
+
if val, ok := chunk["text"].(string); ok {
|
| 480 |
+
fullContent += val
|
| 481 |
+
} else if val, ok := chunk["content"].(string); ok {
|
| 482 |
+
fullContent += val
|
| 483 |
+
} else if val, ok := chunk["response"].(string); ok {
|
| 484 |
+
fullContent += val
|
| 485 |
+
}
|
| 486 |
+
// 标准 chunk
|
| 487 |
+
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
|
| 488 |
+
if choice, ok := choices[0].(map[string]interface{}); ok {
|
| 489 |
+
if delta, ok := choice["delta"].(map[string]interface{}); ok {
|
| 490 |
+
if content, ok := delta["content"].(string); ok {
|
| 491 |
+
fullContent += content
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
// 如果提取到了内容,或者是强制模型(即使没提取到也返回空内容以避免透传错误格式)
|
| 500 |
+
if fullContent != "" || modelID == "gpt-5-nano-2025-08-07" {
|
| 501 |
+
timestamp := time.Now().Unix()
|
| 502 |
+
respObj := model.ChatCompletionResponse{
|
| 503 |
+
ID: fmt.Sprintf("chatcmpl-%d", timestamp),
|
| 504 |
+
Object: "chat.completion",
|
| 505 |
+
Created: timestamp,
|
| 506 |
+
Model: modelID,
|
| 507 |
+
Choices: []model.Choice{
|
| 508 |
+
{
|
| 509 |
+
Index: 0,
|
| 510 |
+
Message: model.ChatMessage{
|
| 511 |
+
Role: "assistant",
|
| 512 |
+
Content: fullContent,
|
| 513 |
+
},
|
| 514 |
+
FinishReason: "stop",
|
| 515 |
+
},
|
| 516 |
+
},
|
| 517 |
+
}
|
| 518 |
+
return json.NewEncoder(w).Encode(respObj)
|
| 519 |
+
}
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
// 既不是 JSON 也不是 SSE,直接透传
|
| 523 |
+
w.Write(bodyBytes)
|
| 524 |
+
return nil
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
// 检查是否已经是 OpenAI 格式 (包含 choices)
|
| 528 |
+
if _, ok := raw["choices"]; ok {
|
| 529 |
+
w.Write(bodyBytes)
|
| 530 |
+
return nil
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
// 尝试从常见字段提取内容进行转换
|
| 534 |
+
var content string
|
| 535 |
+
if val, ok := raw["text"].(string); ok {
|
| 536 |
+
content = val
|
| 537 |
+
} else if val, ok := raw["content"].(string); ok {
|
| 538 |
+
content = val
|
| 539 |
+
} else if val, ok := raw["response"].(string); ok {
|
| 540 |
+
content = val
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
if content != "" {
|
| 544 |
+
timestamp := time.Now().Unix()
|
| 545 |
+
respObj := model.ChatCompletionResponse{
|
| 546 |
+
ID: fmt.Sprintf("chatcmpl-%d", timestamp),
|
| 547 |
+
Object: "chat.completion",
|
| 548 |
+
Created: timestamp,
|
| 549 |
+
Model: modelID,
|
| 550 |
+
Choices: []model.Choice{
|
| 551 |
+
{
|
| 552 |
+
Index: 0,
|
| 553 |
+
Message: model.ChatMessage{
|
| 554 |
+
Role: "assistant",
|
| 555 |
+
Content: content,
|
| 556 |
+
},
|
| 557 |
+
FinishReason: "stop",
|
| 558 |
+
},
|
| 559 |
+
},
|
| 560 |
+
}
|
| 561 |
+
return json.NewEncoder(w).Encode(respObj)
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
// 无法识别格式,直接透传
|
| 565 |
+
w.Write(bodyBytes)
|
| 566 |
+
return nil
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
func (s *OpenAIService) streamConvertedResponse(w http.ResponseWriter, resp *http.Response, modelID string) error {
|
| 570 |
+
// 复制响应头
|
| 571 |
+
for k, v := range resp.Header {
|
| 572 |
+
// 过滤掉 Content-Encoding 和 Content-Length
|
| 573 |
+
if k != "Content-Encoding" && k != "Content-Length" {
|
| 574 |
+
for _, vv := range v {
|
| 575 |
+
w.Header().Add(k, vv)
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
}
|
| 579 |
+
w.WriteHeader(resp.StatusCode)
|
| 580 |
+
|
| 581 |
+
flusher, ok := w.(http.Flusher)
|
| 582 |
+
if !ok {
|
| 583 |
+
// 如果不支持Flusher,回退到普通复制
|
| 584 |
+
_, err := io.Copy(w, resp.Body)
|
| 585 |
+
return err
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
reader := bufio.NewReader(resp.Body)
|
| 589 |
+
timestamp := time.Now().Unix()
|
| 590 |
+
id := fmt.Sprintf("chatcmpl-%d", timestamp)
|
| 591 |
+
|
| 592 |
+
for {
|
| 593 |
+
line, err := reader.ReadString('\n')
|
| 594 |
+
if err != nil {
|
| 595 |
+
if err == io.EOF {
|
| 596 |
+
return nil
|
| 597 |
+
}
|
| 598 |
+
return err
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
// 处理空行
|
| 602 |
+
trimmedLine := strings.TrimSpace(line)
|
| 603 |
+
if trimmedLine == "" {
|
| 604 |
+
fmt.Fprintf(w, "\n")
|
| 605 |
+
flusher.Flush()
|
| 606 |
+
continue
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
// 解析 data: 前缀
|
| 610 |
+
if !strings.HasPrefix(trimmedLine, "data: ") {
|
| 611 |
+
// 尝试解析为 JSON 对象 (处理被强制转为非流式的响应)
|
| 612 |
+
var rawObj map[string]interface{}
|
| 613 |
+
if json.Unmarshal([]byte(trimmedLine), &rawObj) == nil {
|
| 614 |
+
// 尝试从 JSON 中提取内容
|
| 615 |
+
var content string
|
| 616 |
+
if val, ok := rawObj["text"].(string); ok {
|
| 617 |
+
content = val
|
| 618 |
+
} else if val, ok := rawObj["content"].(string); ok {
|
| 619 |
+
content = val
|
| 620 |
+
} else if val, ok := rawObj["response"].(string); ok {
|
| 621 |
+
content = val
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
if content != "" {
|
| 625 |
+
// 构造并发送 SSE chunk
|
| 626 |
+
chunk := model.ChatCompletionChunk{
|
| 627 |
+
ID: id,
|
| 628 |
+
Object: "chat.completion.chunk",
|
| 629 |
+
Created: timestamp,
|
| 630 |
+
Model: modelID,
|
| 631 |
+
Choices: []model.StreamChoice{
|
| 632 |
+
{
|
| 633 |
+
Index: 0,
|
| 634 |
+
Delta: model.ChatMessage{
|
| 635 |
+
Content: content,
|
| 636 |
+
},
|
| 637 |
+
FinishReason: nil,
|
| 638 |
+
},
|
| 639 |
+
},
|
| 640 |
+
}
|
| 641 |
+
newBytes, _ := json.Marshal(chunk)
|
| 642 |
+
fmt.Fprintf(w, "data: %s\n\n", string(newBytes))
|
| 643 |
+
|
| 644 |
+
// 发送结束标记
|
| 645 |
+
fmt.Fprintf(w, "data: [DONE]\n\n")
|
| 646 |
+
flusher.Flush()
|
| 647 |
+
return nil
|
| 648 |
+
}
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
// 非 data 行直接通过
|
| 652 |
+
fmt.Fprint(w, line)
|
| 653 |
+
flusher.Flush()
|
| 654 |
+
continue
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
data := strings.TrimPrefix(trimmedLine, "data: ")
|
| 658 |
+
if data == "[DONE]" {
|
| 659 |
+
fmt.Fprintf(w, "data: [DONE]\n\n")
|
| 660 |
+
flusher.Flush()
|
| 661 |
+
return nil
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
// 尝试解析 JSON
|
| 665 |
+
var raw map[string]interface{}
|
| 666 |
+
if err := json.Unmarshal([]byte(data), &raw); err != nil {
|
| 667 |
+
// 解析失败,直接透传
|
| 668 |
+
fmt.Fprint(w, line)
|
| 669 |
+
flusher.Flush()
|
| 670 |
+
continue
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
// 检查是否已经是 OpenAI 格式
|
| 674 |
+
if _, hasChoices := raw["choices"]; hasChoices {
|
| 675 |
+
fmt.Fprint(w, line)
|
| 676 |
+
flusher.Flush()
|
| 677 |
+
continue
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
// 尝试转换非标准格式
|
| 681 |
+
// 假设可能有 text, content, response 等字段
|
| 682 |
+
var content string
|
| 683 |
+
if val, ok := raw["text"].(string); ok {
|
| 684 |
+
content = val
|
| 685 |
+
} else if val, ok := raw["content"].(string); ok {
|
| 686 |
+
content = val
|
| 687 |
+
} else if val, ok := raw["response"].(string); ok {
|
| 688 |
+
content = val
|
| 689 |
+
}
|
| 690 |
+
|
| 691 |
+
if content != "" {
|
| 692 |
+
// 构造标准 OpenAI Chunk
|
| 693 |
+
chunk := model.ChatCompletionChunk{
|
| 694 |
+
ID: id,
|
| 695 |
+
Object: "chat.completion.chunk",
|
| 696 |
+
Created: timestamp,
|
| 697 |
+
Model: modelID,
|
| 698 |
+
Choices: []model.StreamChoice{
|
| 699 |
+
{
|
| 700 |
+
Index: 0,
|
| 701 |
+
Delta: model.ChatMessage{
|
| 702 |
+
Content: content,
|
| 703 |
+
},
|
| 704 |
+
FinishReason: nil,
|
| 705 |
+
},
|
| 706 |
+
},
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
newBytes, _ := json.Marshal(chunk)
|
| 710 |
+
fmt.Fprintf(w, "data: %s\n\n", string(newBytes))
|
| 711 |
+
flusher.Flush()
|
| 712 |
+
} else {
|
| 713 |
+
// 无法识别内容,直接透传
|
| 714 |
+
fmt.Fprint(w, line)
|
| 715 |
+
flusher.Flush()
|
| 716 |
+
}
|
| 717 |
+
}
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
// ResponsesProxy 代理responses请求
|
| 721 |
+
func (s *OpenAIService) ResponsesProxy(ctx context.Context, w http.ResponseWriter, body []byte) error {
|
| 722 |
+
resp, err := s.Responses(ctx, body)
|
| 723 |
+
if err != nil {
|
| 724 |
+
return err
|
| 725 |
+
}
|
| 726 |
+
defer resp.Body.Close()
|
| 727 |
+
|
| 728 |
+
return StreamResponse(w, resp)
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
// retryWithProxy 使用代理池重试OpenAI请求
|
| 732 |
+
func (s *OpenAIService) retryWithProxy(ctx context.Context, account *model.Account, modelID, path string, body []byte) (*http.Response, error) {
|
| 733 |
+
// 获取模型配置
|
| 734 |
+
zenModel, exists := model.GetZenModel(modelID)
|
| 735 |
+
if !exists {
|
| 736 |
+
return nil, fmt.Errorf("模型配置不存在: %s", modelID)
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
proxyPool := provider.GetProxyPool()
|
| 740 |
+
if !proxyPool.HasProxies() {
|
| 741 |
+
return nil, fmt.Errorf("没有可用的代理")
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
maxRetries := 3
|
| 745 |
+
for i := 0; i < maxRetries; i++ {
|
| 746 |
+
// 获取随机代理
|
| 747 |
+
proxyURL := proxyPool.GetRandomProxy()
|
| 748 |
+
if proxyURL == "" {
|
| 749 |
+
continue
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
log.Printf("[OpenAI] 尝试代理 %s (重试 %d/%d)", proxyURL, i+1, maxRetries)
|
| 753 |
+
|
| 754 |
+
// 创建使用代理的HTTP客户端
|
| 755 |
+
proxyClient, err := provider.NewHTTPClientWithProxy(proxyURL, 0)
|
| 756 |
+
if err != nil {
|
| 757 |
+
log.Printf("[OpenAI] 创建代理客户端失败: %v", err)
|
| 758 |
+
continue
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
// 将模型参数合并到请求体中
|
| 762 |
+
modifiedBody := body
|
| 763 |
+
if zenModel.Parameters != nil {
|
| 764 |
+
var raw map[string]interface{}
|
| 765 |
+
if json.Unmarshal(modifiedBody, &raw) == nil {
|
| 766 |
+
// 添加 reasoning 配置
|
| 767 |
+
if zenModel.Parameters.Reasoning != nil && raw["reasoning"] == nil {
|
| 768 |
+
reasoningMap := map[string]interface{}{
|
| 769 |
+
"effort": zenModel.Parameters.Reasoning.Effort,
|
| 770 |
+
}
|
| 771 |
+
if zenModel.Parameters.Reasoning.Summary != "" {
|
| 772 |
+
reasoningMap["summary"] = zenModel.Parameters.Reasoning.Summary
|
| 773 |
+
}
|
| 774 |
+
raw["reasoning"] = reasoningMap
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
// 添加 text 配置
|
| 778 |
+
if zenModel.Parameters.Text != nil && raw["text"] == nil {
|
| 779 |
+
raw["text"] = map[string]interface{}{
|
| 780 |
+
"verbosity": zenModel.Parameters.Text.Verbosity,
|
| 781 |
+
}
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
// 添加 temperature 配置
|
| 785 |
+
if zenModel.Parameters.Temperature != nil && raw["temperature"] == nil {
|
| 786 |
+
raw["temperature"] = *zenModel.Parameters.Temperature
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
modifiedBody, _ = json.Marshal(raw)
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
// 特殊模型的额外处理
|
| 794 |
+
if modelID == "gpt-5-nano-2025-08-07" {
|
| 795 |
+
var raw map[string]interface{}
|
| 796 |
+
if json.Unmarshal(modifiedBody, &raw) == nil {
|
| 797 |
+
if _, ok := raw["text"]; !ok {
|
| 798 |
+
raw["text"] = map[string]string{"verbosity": "medium"}
|
| 799 |
+
}
|
| 800 |
+
if _, ok := raw["temperature"]; !ok {
|
| 801 |
+
raw["temperature"] = 1
|
| 802 |
+
}
|
| 803 |
+
raw["stream"] = true
|
| 804 |
+
if reasoning, ok := raw["reasoning"].(map[string]interface{}); ok {
|
| 805 |
+
reasoning["summary"] = "auto"
|
| 806 |
+
raw["reasoning"] = reasoning
|
| 807 |
+
} else {
|
| 808 |
+
raw["reasoning"] = map[string]interface{}{
|
| 809 |
+
"effort": "minimal",
|
| 810 |
+
"summary": "auto",
|
| 811 |
+
}
|
| 812 |
+
}
|
| 813 |
+
modifiedBody, _ = json.Marshal(raw)
|
| 814 |
+
}
|
| 815 |
+
}
|
| 816 |
+
|
| 817 |
+
// 创建新请求
|
| 818 |
+
reqURL := OpenAIBaseURL + path
|
| 819 |
+
httpReq, err := http.NewRequest("POST", reqURL, bytes.NewReader(modifiedBody))
|
| 820 |
+
if err != nil {
|
| 821 |
+
log.Printf("[OpenAI] 创建请求失败: %v", err)
|
| 822 |
+
continue
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
// 设置请求头
|
| 826 |
+
SetZencoderHeaders(httpReq, account, zenModel)
|
| 827 |
+
|
| 828 |
+
// 添加模型配置的额外请求头
|
| 829 |
+
if zenModel.Parameters != nil && zenModel.Parameters.ExtraHeaders != nil {
|
| 830 |
+
for k, v := range zenModel.Parameters.ExtraHeaders {
|
| 831 |
+
httpReq.Header.Set(k, v)
|
| 832 |
+
}
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
// 强制记录代理请求体用于调试
|
| 836 |
+
log.Printf("[DEBUG] [OpenAI] 代理请求体:")
|
| 837 |
+
log.Printf("[DEBUG] [OpenAI] %s", string(modifiedBody))
|
| 838 |
+
|
| 839 |
+
// 执行请求
|
| 840 |
+
resp, err := proxyClient.Do(httpReq)
|
| 841 |
+
if err != nil {
|
| 842 |
+
log.Printf("[OpenAI] 代理请求失败: %v", err)
|
| 843 |
+
continue
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
// 检查响应状态
|
| 847 |
+
if resp.StatusCode == 429 {
|
| 848 |
+
// 仍然是429,尝试下一个代理
|
| 849 |
+
resp.Body.Close()
|
| 850 |
+
log.Printf("[OpenAI] 代理 %s 仍返回429,尝试下一个", proxyURL)
|
| 851 |
+
continue
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
if resp.StatusCode >= 400 {
|
| 855 |
+
// 其他错误,记录并尝试下一个代理
|
| 856 |
+
errBody, _ := io.ReadAll(resp.Body)
|
| 857 |
+
resp.Body.Close()
|
| 858 |
+
log.Printf("[OpenAI] 代理 %s 返回错误 %d: %s", proxyURL, resp.StatusCode, string(errBody))
|
| 859 |
+
continue
|
| 860 |
+
}
|
| 861 |
+
|
| 862 |
+
// 成功
|
| 863 |
+
log.Printf("[OpenAI] 代理 %s 请求成功", proxyURL)
|
| 864 |
+
return resp, nil
|
| 865 |
+
}
|
| 866 |
+
|
| 867 |
+
return nil, fmt.Errorf("所有代理重试均失败")
|
| 868 |
+
}
|
internal/service/pool.go
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"log"
|
| 6 |
+
"net/http"
|
| 7 |
+
"sync"
|
| 8 |
+
"time"
|
| 9 |
+
|
| 10 |
+
"zencoder-2api/internal/database"
|
| 11 |
+
"zencoder-2api/internal/model"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
type AccountPool struct {
|
| 15 |
+
mu sync.RWMutex
|
| 16 |
+
accounts []*model.Account
|
| 17 |
+
index uint64
|
| 18 |
+
maxErrs int
|
| 19 |
+
stopChan chan struct{}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
var pool *AccountPool
|
| 23 |
+
|
| 24 |
+
func init() {
|
| 25 |
+
pool = &AccountPool{
|
| 26 |
+
maxErrs: 3,
|
| 27 |
+
accounts: make([]*model.Account, 0),
|
| 28 |
+
stopChan: make(chan struct{}),
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// InitAccountPool 初始化账号池并启动刷新协程
|
| 33 |
+
func InitAccountPool() {
|
| 34 |
+
// 数据迁移:将旧字段状态迁移到 Status
|
| 35 |
+
pool.migrateData()
|
| 36 |
+
|
| 37 |
+
// 初始加载
|
| 38 |
+
pool.refresh()
|
| 39 |
+
// 启动后台刷新
|
| 40 |
+
go pool.refreshLoop()
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
func (p *AccountPool) migrateData() {
|
| 44 |
+
db := database.GetDB()
|
| 45 |
+
// 默认设为 normal
|
| 46 |
+
db.Model(&model.Account{}).Where("status = '' OR status IS NULL").Update("status", "normal")
|
| 47 |
+
|
| 48 |
+
// 迁移冷却状态
|
| 49 |
+
db.Model(&model.Account{}).Where("is_cooling = ?", true).Update("status", "cooling")
|
| 50 |
+
|
| 51 |
+
// 迁移错误封禁状态
|
| 52 |
+
db.Model(&model.Account{}).Where("is_active = ? AND error_count >= ?", false, p.maxErrs).Update("status", "error")
|
| 53 |
+
|
| 54 |
+
// 迁移手动禁用状态 (!Active && !Cooling && Error < Max)
|
| 55 |
+
db.Model(&model.Account{}).Where("is_active = ? AND is_cooling = ? AND error_count < ?", false, false, p.maxErrs).Update("status", "disabled")
|
| 56 |
+
|
| 57 |
+
// 迁移 category 到 status (如果 category 是 banned/error/cooling/abnormal)
|
| 58 |
+
db.Model(&model.Account{}).Where("category = ?", "banned").Update("status", "banned")
|
| 59 |
+
db.Model(&model.Account{}).Where("category = ?", "error").Update("status", "error")
|
| 60 |
+
db.Model(&model.Account{}).Where("category = ?", "cooling").Update("status", "cooling")
|
| 61 |
+
db.Model(&model.Account{}).Where("category = ?", "abnormal").Update("status", "cooling")
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
func (p *AccountPool) refreshLoop() {
|
| 65 |
+
ticker := time.NewTicker(30 * time.Second)
|
| 66 |
+
defer ticker.Stop()
|
| 67 |
+
|
| 68 |
+
for {
|
| 69 |
+
select {
|
| 70 |
+
case <-ticker.C:
|
| 71 |
+
p.refresh()
|
| 72 |
+
p.cleanupTimeoutAccounts() // 清理超时账号
|
| 73 |
+
case <-p.stopChan:
|
| 74 |
+
return
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
// cleanupTimeoutAccounts 定期清理超时的账号状态
|
| 80 |
+
func (p *AccountPool) cleanupTimeoutAccounts() {
|
| 81 |
+
now := time.Now()
|
| 82 |
+
statusMu.Lock()
|
| 83 |
+
defer statusMu.Unlock()
|
| 84 |
+
|
| 85 |
+
cleanedCount := 0
|
| 86 |
+
for _, status := range accountStatuses {
|
| 87 |
+
// 清理超过60秒还在使用中的账号
|
| 88 |
+
if status.InUse && !status.InUseSince.IsZero() && now.Sub(status.InUseSince) > 60*time.Second {
|
| 89 |
+
status.InUse = false
|
| 90 |
+
status.InUseSince = time.Time{}
|
| 91 |
+
cleanedCount++
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
if cleanedCount > 0 {
|
| 96 |
+
log.Printf("[INFO] 定期清理:释放了 %d 个超时账号", cleanedCount)
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
func (p *AccountPool) refresh() {
|
| 101 |
+
// 先恢复冷却账号
|
| 102 |
+
recoverCoolingAccounts()
|
| 103 |
+
|
| 104 |
+
// 刷新即将过期的token(1小时内过期)
|
| 105 |
+
p.refreshExpiredTokens()
|
| 106 |
+
|
| 107 |
+
var dbAccounts []model.Account
|
| 108 |
+
// 只查询状态为 normal 的账号
|
| 109 |
+
result := database.GetDB().Where("status = ?", "normal").
|
| 110 |
+
Where("token_expiry > ?", time.Now()).
|
| 111 |
+
Find(&dbAccounts)
|
| 112 |
+
|
| 113 |
+
if result.Error != nil {
|
| 114 |
+
log.Printf("[Error] Failed to refresh account pool: %v", result.Error)
|
| 115 |
+
return
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
p.mu.Lock()
|
| 119 |
+
defer p.mu.Unlock()
|
| 120 |
+
|
| 121 |
+
// 重新构建缓存,但保留现有对象的指针以维持状态(如果ID匹配)
|
| 122 |
+
// 或者简单全量替换,依赖 30s 的一致性窗口
|
| 123 |
+
// 为了简化并防止并发问题,这里使用全量替换,将 DB 数据作为 Source of Truth
|
| 124 |
+
newAccounts := make([]*model.Account, len(dbAccounts))
|
| 125 |
+
for i := range dbAccounts {
|
| 126 |
+
newAccounts[i] = &dbAccounts[i]
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// 如果账号数量有显著变化,记录日志
|
| 130 |
+
oldCount := len(p.accounts)
|
| 131 |
+
newCount := len(newAccounts)
|
| 132 |
+
if oldCount != newCount {
|
| 133 |
+
log.Printf("[AccountPool] 账号池刷新:%d -> %d 个可用账号", oldCount, newCount)
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
p.accounts = newAccounts
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// refreshExpiredTokens 刷新即将过期的账号token
|
| 140 |
+
func (p *AccountPool) refreshExpiredTokens() {
|
| 141 |
+
now := time.Now()
|
| 142 |
+
threshold := now.Add(time.Hour) // 1小时内即将过期的token
|
| 143 |
+
|
| 144 |
+
var expiredAccounts []model.Account
|
| 145 |
+
// 只排除banned状态的账号,其他状态的账号仍可以刷新token
|
| 146 |
+
result := database.GetDB().Where("status != ?", "banned").
|
| 147 |
+
Where("client_id != '' AND client_secret != ''").
|
| 148 |
+
Where("token_expiry < ?", threshold).
|
| 149 |
+
Find(&expiredAccounts)
|
| 150 |
+
|
| 151 |
+
if result.Error != nil {
|
| 152 |
+
log.Printf("[AccountPool] 查询即将过期的账号失败: %v", result.Error)
|
| 153 |
+
return
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// 额外验证:再次过滤掉banned状态的账号
|
| 157 |
+
var validAccounts []model.Account
|
| 158 |
+
for _, acc := range expiredAccounts {
|
| 159 |
+
if acc.Status != "banned" {
|
| 160 |
+
validAccounts = append(validAccounts, acc)
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
expiredAccounts = validAccounts
|
| 164 |
+
|
| 165 |
+
if len(expiredAccounts) == 0 {
|
| 166 |
+
return
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
log.Printf("[AccountPool] 发现 %d 个非封禁账号的token需要刷新", len(expiredAccounts))
|
| 170 |
+
|
| 171 |
+
// 限制并发刷新数量,避免对API造成压力
|
| 172 |
+
semaphore := make(chan struct{}, 10) // 最多10个并发
|
| 173 |
+
var refreshCount int32
|
| 174 |
+
var successCount int32
|
| 175 |
+
|
| 176 |
+
// 并发刷新token
|
| 177 |
+
for i := range expiredAccounts {
|
| 178 |
+
account := &expiredAccounts[i]
|
| 179 |
+
|
| 180 |
+
go func(acc *model.Account) {
|
| 181 |
+
semaphore <- struct{}{} // 获取信号量
|
| 182 |
+
defer func() { <-semaphore }() // 释放信号量
|
| 183 |
+
|
| 184 |
+
refreshCount++
|
| 185 |
+
|
| 186 |
+
// 根据账号类型选择不同的刷新方式
|
| 187 |
+
if acc.ClientSecret == "refresh-token-login" {
|
| 188 |
+
// refresh-token-login 账号使用 refresh_token 刷新
|
| 189 |
+
if err := p.refreshRefreshTokenAccount(acc); err != nil {
|
| 190 |
+
log.Printf("[AccountPool] refresh-token账号 %s (ID:%d) token刷新失败: %v",
|
| 191 |
+
acc.ClientID, acc.ID, err)
|
| 192 |
+
} else {
|
| 193 |
+
successCount++
|
| 194 |
+
log.Printf("[AccountPool] refresh-token账号 %s (ID:%d) token刷新成功,新过期时间: %s",
|
| 195 |
+
acc.ClientID, acc.ID, acc.TokenExpiry.Format("2006-01-02 15:04:05"))
|
| 196 |
+
}
|
| 197 |
+
} else {
|
| 198 |
+
// 普通账号使用 OAuth client credentials 刷新
|
| 199 |
+
if err := p.refreshSingleAccountToken(acc); err != nil {
|
| 200 |
+
log.Printf("[AccountPool] 账号 %s (ID:%d) token刷新失败: %v",
|
| 201 |
+
acc.ClientID, acc.ID, err)
|
| 202 |
+
} else {
|
| 203 |
+
successCount++
|
| 204 |
+
log.Printf("[AccountPool] 账号 %s (ID:%d) token刷新成功,新过期时间: %s",
|
| 205 |
+
acc.ClientID, acc.ID, acc.TokenExpiry.Format("2006-01-02 15:04:05"))
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
}(account)
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
// 等待所有刷新完成(最多等待30秒)
|
| 212 |
+
timeout := time.After(30 * time.Second)
|
| 213 |
+
ticker := time.NewTicker(100 * time.Millisecond)
|
| 214 |
+
defer ticker.Stop()
|
| 215 |
+
|
| 216 |
+
for {
|
| 217 |
+
select {
|
| 218 |
+
case <-timeout:
|
| 219 |
+
log.Printf("[AccountPool] Token刷新超时,已完成 %d/%d", refreshCount, len(expiredAccounts))
|
| 220 |
+
return
|
| 221 |
+
case <-ticker.C:
|
| 222 |
+
if int(refreshCount) >= len(expiredAccounts) {
|
| 223 |
+
log.Printf("[AccountPool] Token刷新完成:成功 %d/%d", successCount, len(expiredAccounts))
|
| 224 |
+
return
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
// refreshSingleAccountToken 刷新单个账号的token
|
| 231 |
+
func (p *AccountPool) refreshSingleAccountToken(account *model.Account) error {
|
| 232 |
+
return refreshAccountToken(account)
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
func GetNextAccount() (*model.Account, error) {
|
| 236 |
+
return GetNextAccountForModel("")
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
// AccountStatus 账号运行时状态
|
| 240 |
+
type AccountStatus struct {
|
| 241 |
+
LastUsed time.Time
|
| 242 |
+
InUse bool
|
| 243 |
+
FrozenUntil time.Time
|
| 244 |
+
InUseSince time.Time // 记录开始使用的时间
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// 账号运行时状态管理
|
| 248 |
+
var (
|
| 249 |
+
accountStatuses = make(map[uint]*AccountStatus)
|
| 250 |
+
statusMu sync.RWMutex
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
// GetNextAccountForModel 获取可用于指定模型的账号
|
| 254 |
+
// 使用内存状态管理,避免高并发下的竞态条件
|
| 255 |
+
func GetNextAccountForModel(modelID string) (*model.Account, error) {
|
| 256 |
+
pool.mu.RLock()
|
| 257 |
+
accounts := pool.accounts // 获取账号列表引用
|
| 258 |
+
pool.mu.RUnlock()
|
| 259 |
+
|
| 260 |
+
if len(accounts) == 0 {
|
| 261 |
+
return nil, ErrNoAvailableAccount
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
// 获取候选账号
|
| 265 |
+
var candidates []*model.Account
|
| 266 |
+
now := time.Now()
|
| 267 |
+
statusMu.RLock()
|
| 268 |
+
for _, acc := range accounts {
|
| 269 |
+
// 检查模型权限
|
| 270 |
+
if modelID != "" && !model.CanUseModel(acc.PlanType, modelID) {
|
| 271 |
+
continue
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
// 获取或初始化状态
|
| 275 |
+
status, exists := accountStatuses[acc.ID]
|
| 276 |
+
if !exists {
|
| 277 |
+
// 初始化状态
|
| 278 |
+
accountStatuses[acc.ID] = &AccountStatus{
|
| 279 |
+
LastUsed: acc.LastUsed,
|
| 280 |
+
InUse: false,
|
| 281 |
+
FrozenUntil: acc.CoolingUntil,
|
| 282 |
+
InUseSince: time.Time{},
|
| 283 |
+
}
|
| 284 |
+
status = accountStatuses[acc.ID]
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
// 自动释放超时账号(超过30秒未释放的账号)
|
| 288 |
+
if status.InUse && !status.InUseSince.IsZero() && now.Sub(status.InUseSince) > 30*time.Second {
|
| 289 |
+
status.InUse = false
|
| 290 |
+
status.InUseSince = time.Time{}
|
| 291 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 使用超时,已自动释放", acc.Email, acc.ID)
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
// 检查是否可用(未被使用且未被冻结)
|
| 295 |
+
if !status.InUse && now.After(status.FrozenUntil) {
|
| 296 |
+
candidates = append(candidates, acc)
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
statusMu.RUnlock()
|
| 300 |
+
|
| 301 |
+
if len(candidates) == 0 {
|
| 302 |
+
// 提供详细的调试信息
|
| 303 |
+
totalAccounts := len(accounts)
|
| 304 |
+
inUseCount := 0
|
| 305 |
+
frozenCount := 0
|
| 306 |
+
noPermissionCount := 0
|
| 307 |
+
|
| 308 |
+
statusMu.RLock()
|
| 309 |
+
for _, acc := range accounts {
|
| 310 |
+
if modelID != "" && !model.CanUseModel(acc.PlanType, modelID) {
|
| 311 |
+
noPermissionCount++
|
| 312 |
+
continue
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
if status, exists := accountStatuses[acc.ID]; exists {
|
| 316 |
+
if status.InUse {
|
| 317 |
+
inUseCount++
|
| 318 |
+
} else if !now.After(status.FrozenUntil) {
|
| 319 |
+
frozenCount++
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
statusMu.RUnlock()
|
| 324 |
+
|
| 325 |
+
log.Printf("[ERROR] 无可用账号 - 总账号数: %d, 权限不足: %d, 使用中: %d, 冻结中: %d, 模型: %s",
|
| 326 |
+
totalAccounts, noPermissionCount, inUseCount, frozenCount, modelID)
|
| 327 |
+
|
| 328 |
+
return nil, ErrNoPermission
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// 选择最长时间未使用的账号
|
| 332 |
+
var selected *model.Account
|
| 333 |
+
oldestTime := time.Now()
|
| 334 |
+
|
| 335 |
+
statusMu.RLock()
|
| 336 |
+
for _, acc := range candidates {
|
| 337 |
+
status := accountStatuses[acc.ID]
|
| 338 |
+
if status == nil {
|
| 339 |
+
continue
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// 如果账号从未使用过,优先选择
|
| 343 |
+
if status.LastUsed.IsZero() {
|
| 344 |
+
selected = acc
|
| 345 |
+
break
|
| 346 |
+
}
|
| 347 |
+
// 选择最长时间未使用的账号
|
| 348 |
+
if status.LastUsed.Before(oldestTime) {
|
| 349 |
+
oldestTime = status.LastUsed
|
| 350 |
+
selected = acc
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
statusMu.RUnlock()
|
| 354 |
+
|
| 355 |
+
// 如果没有找到合适的账号,使用轮询
|
| 356 |
+
if selected == nil {
|
| 357 |
+
selected = candidates[time.Now().UnixNano()%int64(len(candidates))]
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// 立即在内存中标记账号为使用中
|
| 361 |
+
statusMu.Lock()
|
| 362 |
+
currentTime := time.Now()
|
| 363 |
+
if status, exists := accountStatuses[selected.ID]; exists {
|
| 364 |
+
status.InUse = true
|
| 365 |
+
status.LastUsed = currentTime
|
| 366 |
+
status.InUseSince = currentTime
|
| 367 |
+
} else {
|
| 368 |
+
accountStatuses[selected.ID] = &AccountStatus{
|
| 369 |
+
LastUsed: currentTime,
|
| 370 |
+
InUse: true,
|
| 371 |
+
FrozenUntil: time.Time{},
|
| 372 |
+
InUseSince: currentTime,
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
statusMu.Unlock()
|
| 376 |
+
|
| 377 |
+
// 异步更新数据库
|
| 378 |
+
go func(acc *model.Account, usedTime time.Time) {
|
| 379 |
+
database.GetDB().Model(acc).Update("last_used", usedTime)
|
| 380 |
+
}(selected, time.Now())
|
| 381 |
+
|
| 382 |
+
return selected, nil
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
// ReleaseAccount 释放账号(标记为未使用)
|
| 386 |
+
func ReleaseAccount(account *model.Account) {
|
| 387 |
+
if account == nil {
|
| 388 |
+
return
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
statusMu.Lock()
|
| 392 |
+
defer statusMu.Unlock()
|
| 393 |
+
|
| 394 |
+
if status, exists := accountStatuses[account.ID]; exists {
|
| 395 |
+
status.InUse = false
|
| 396 |
+
status.InUseSince = time.Time{} // 重置使用开始时间
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// recoverCoolingAccounts 恢复冷却期已过的账号
|
| 401 |
+
func recoverCoolingAccounts() {
|
| 402 |
+
var coolingAccounts []model.Account
|
| 403 |
+
// 查询 status = cooling 且时间已到的账号(使用 UTC 时间)
|
| 404 |
+
nowUTC := time.Now().UTC()
|
| 405 |
+
database.GetDB().Where("status = ?", "cooling").
|
| 406 |
+
Where("cooling_until < ?", nowUTC).
|
| 407 |
+
Find(&coolingAccounts)
|
| 408 |
+
|
| 409 |
+
for _, acc := range coolingAccounts {
|
| 410 |
+
acc.IsCooling = false
|
| 411 |
+
acc.IsActive = true
|
| 412 |
+
acc.Category = "normal" // 保持兼容
|
| 413 |
+
acc.Status = "normal" // 恢复状态
|
| 414 |
+
acc.BanReason = "" // 清除封禁原因
|
| 415 |
+
database.GetDB().Save(&acc)
|
| 416 |
+
log.Printf("[INFO] 账号 %s (ID:%d) 冷却期结束,已恢复 (冷却结束时间: %s UTC)",
|
| 417 |
+
acc.Email, acc.ID, acc.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 418 |
+
}
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
func MarkAccountError(account *model.Account) {
|
| 422 |
+
account.ErrorCount++
|
| 423 |
+
if account.ErrorCount >= pool.maxErrs {
|
| 424 |
+
account.IsActive = false
|
| 425 |
+
account.Status = "error" // 更新状态
|
| 426 |
+
account.Category = "error"
|
| 427 |
+
account.BanReason = "Error count exceeded limit"
|
| 428 |
+
}
|
| 429 |
+
database.GetDB().Save(account)
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
// MarkAccountRateLimited 标记账号遇到 429 限流错误
|
| 433 |
+
func MarkAccountRateLimited(account *model.Account) {
|
| 434 |
+
account.RateLimitHits++
|
| 435 |
+
account.IsCooling = true
|
| 436 |
+
account.IsActive = false
|
| 437 |
+
|
| 438 |
+
// 设置冷却时间:1小时(使用UTC时间)
|
| 439 |
+
account.CoolingUntil = time.Now().UTC().Add(1 * time.Hour)
|
| 440 |
+
|
| 441 |
+
// 更新状态
|
| 442 |
+
oldStatus := account.Status
|
| 443 |
+
account.Status = "cooling"
|
| 444 |
+
account.Category = "cooling"
|
| 445 |
+
account.BanReason = "Rate limited (429)"
|
| 446 |
+
|
| 447 |
+
database.GetDB().Save(account)
|
| 448 |
+
|
| 449 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 遇到 429 限流 (第 %d 次),已移至冷却分组,冷却至 %s UTC",
|
| 450 |
+
account.Email, account.ID, account.RateLimitHits, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 451 |
+
|
| 452 |
+
if oldStatus != "cooling" {
|
| 453 |
+
log.Printf("[INFO] 账号 %s 状态变更: %s -> cooling", account.Email, oldStatus)
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
// MarkAccountRateLimitedWithResponse 根据响应头信息处理429限流错误
|
| 458 |
+
func MarkAccountRateLimitedWithResponse(account *model.Account, resp *http.Response) {
|
| 459 |
+
if resp == nil || resp.Header == nil {
|
| 460 |
+
// 如果没有响应头,使用默认处理
|
| 461 |
+
MarkAccountRateLimited(account)
|
| 462 |
+
return
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// 获取响应头中的积分信息
|
| 466 |
+
periodLimit := resp.Header.Get("Zen-Pricing-Period-Limit")
|
| 467 |
+
periodCost := resp.Header.Get("Zen-Pricing-Period-Cost")
|
| 468 |
+
periodEnd := resp.Header.Get("Zen-Pricing-Period-End")
|
| 469 |
+
|
| 470 |
+
// 检查是否为积分耗尽导致的429
|
| 471 |
+
isQuotaExhausted := false
|
| 472 |
+
if periodLimit != "" && periodCost != "" {
|
| 473 |
+
limit := parseFloat(periodLimit)
|
| 474 |
+
used := parseFloat(periodCost)
|
| 475 |
+
|
| 476 |
+
// 如果使用积分 >= 最大积分,说明积分已满
|
| 477 |
+
if limit > 0 && used >= limit {
|
| 478 |
+
isQuotaExhausted = true
|
| 479 |
+
}
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
account.RateLimitHits++
|
| 483 |
+
account.IsCooling = true
|
| 484 |
+
account.IsActive = false
|
| 485 |
+
|
| 486 |
+
oldStatus := account.Status
|
| 487 |
+
account.Status = "cooling"
|
| 488 |
+
account.Category = "cooling"
|
| 489 |
+
|
| 490 |
+
if isQuotaExhausted {
|
| 491 |
+
// 积分耗尽导致的429,根据periodEnd设置冷却时间
|
| 492 |
+
if periodEnd != "" {
|
| 493 |
+
if endTime, err := time.Parse(time.RFC3339, periodEnd); err == nil {
|
| 494 |
+
account.CoolingUntil = endTime
|
| 495 |
+
account.BanReason = "Quota exhausted (429)"
|
| 496 |
+
// 同时更新积分刷新时间
|
| 497 |
+
account.CreditRefreshTime = endTime
|
| 498 |
+
|
| 499 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 积分耗尽导致429限���,冷却至积分刷新时间: %s UTC",
|
| 500 |
+
account.Email, account.ID, endTime.Format("2006-01-02 15:04:05"))
|
| 501 |
+
} else {
|
| 502 |
+
// 解析失败,使用默认冷却时间
|
| 503 |
+
account.CoolingUntil = time.Now().UTC().Add(1 * time.Hour)
|
| 504 |
+
account.BanReason = "Quota exhausted (429) - fallback cooling"
|
| 505 |
+
|
| 506 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 积分耗尽但无法解析刷新时间,使用默认冷却: %s UTC",
|
| 507 |
+
account.Email, account.ID, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 508 |
+
}
|
| 509 |
+
} else {
|
| 510 |
+
// 没有periodEnd,使用默认冷却时间
|
| 511 |
+
account.CoolingUntil = time.Now().UTC().Add(1 * time.Hour)
|
| 512 |
+
account.BanReason = "Quota exhausted (429) - no end time"
|
| 513 |
+
|
| 514 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 积分耗尽但无刷新时间信息,使用默认冷却: %s UTC",
|
| 515 |
+
account.Email, account.ID, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 516 |
+
}
|
| 517 |
+
} else {
|
| 518 |
+
// 常规429限流错误,使用默认冷却时间
|
| 519 |
+
account.CoolingUntil = time.Now().UTC().Add(1 * time.Hour)
|
| 520 |
+
account.BanReason = "Rate limited (429)"
|
| 521 |
+
|
| 522 |
+
log.Printf("[WARN] 账号 %s (ID:%d) 遇到常规429限流 (第 %d 次),冷却至: %s UTC",
|
| 523 |
+
account.Email, account.ID, account.RateLimitHits, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
database.GetDB().Save(account)
|
| 527 |
+
|
| 528 |
+
if oldStatus != "cooling" {
|
| 529 |
+
log.Printf("[INFO] 账号 %s 状态变更: %s -> cooling", account.Email, oldStatus)
|
| 530 |
+
}
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
// MarkAccountRateLimitedShort 标记账号遇到 429 限流错误(短期冷却)
|
| 534 |
+
func MarkAccountRateLimitedShort(account *model.Account) {
|
| 535 |
+
account.RateLimitHits++
|
| 536 |
+
account.IsCooling = true
|
| 537 |
+
account.IsActive = false
|
| 538 |
+
|
| 539 |
+
// 设置短期冷却时间:5秒(使用UTC时间)
|
| 540 |
+
account.CoolingUntil = time.Now().UTC().Add(5 * time.Second)
|
| 541 |
+
|
| 542 |
+
// 更新状态
|
| 543 |
+
account.Status = "cooling"
|
| 544 |
+
account.Category = "cooling"
|
| 545 |
+
account.BanReason = "Rate limited (429) - short cooling"
|
| 546 |
+
|
| 547 |
+
database.GetDB().Save(account)
|
| 548 |
+
|
| 549 |
+
log.Printf("[INFO] 账号 %s (ID:%d) 短期冷却,冷却至 %s UTC",
|
| 550 |
+
account.Email, account.ID, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
// FreezeAccount 冻结账号指定时间(用于500错误限速)
|
| 554 |
+
func FreezeAccount(account *model.Account, duration time.Duration) {
|
| 555 |
+
if account == nil {
|
| 556 |
+
return
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
freezeUntil := time.Now().Add(duration)
|
| 560 |
+
|
| 561 |
+
// 立即在内存中更新冻结状态
|
| 562 |
+
statusMu.Lock()
|
| 563 |
+
if status, exists := accountStatuses[account.ID]; exists {
|
| 564 |
+
status.FrozenUntil = freezeUntil
|
| 565 |
+
status.InUse = false // 释放账号
|
| 566 |
+
status.InUseSince = time.Time{} // 重置使用开始时间
|
| 567 |
+
} else {
|
| 568 |
+
accountStatuses[account.ID] = &AccountStatus{
|
| 569 |
+
LastUsed: time.Now(),
|
| 570 |
+
InUse: false,
|
| 571 |
+
FrozenUntil: freezeUntil,
|
| 572 |
+
InUseSince: time.Time{},
|
| 573 |
+
}
|
| 574 |
+
}
|
| 575 |
+
statusMu.Unlock()
|
| 576 |
+
|
| 577 |
+
// 异步更新数据库
|
| 578 |
+
go func() {
|
| 579 |
+
// 设置冷却时间(使用UTC时间)
|
| 580 |
+
account.CoolingUntil = freezeUntil.UTC()
|
| 581 |
+
account.IsCooling = true
|
| 582 |
+
account.IsActive = false
|
| 583 |
+
|
| 584 |
+
// 更新状态
|
| 585 |
+
account.Status = "cooling"
|
| 586 |
+
account.Category = "cooling"
|
| 587 |
+
account.BanReason = "Rate limit tracking problem (500)"
|
| 588 |
+
|
| 589 |
+
database.GetDB().Save(account)
|
| 590 |
+
}()
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
func ResetAccountError(account *model.Account) {
|
| 594 |
+
account.ErrorCount = 0
|
| 595 |
+
database.GetDB().Save(account)
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
// 扣减积分并检查是否需要冷却
|
| 599 |
+
func UseCredit(account *model.Account, multiplier float64) {
|
| 600 |
+
account.DailyUsed += multiplier
|
| 601 |
+
account.TotalUsed += multiplier
|
| 602 |
+
account.LastUsed = time.Now() // 更新最后使用时间
|
| 603 |
+
|
| 604 |
+
limit := float64(model.PlanLimits[account.PlanType])
|
| 605 |
+
if account.DailyUsed >= limit {
|
| 606 |
+
account.IsCooling = true
|
| 607 |
+
account.Status = "cooling" // 更新状态
|
| 608 |
+
account.Category = "cooling"
|
| 609 |
+
account.BanReason = "Daily quota exceeded"
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
database.GetDB().Save(account)
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
// UpdateAccountCreditsFromResponse 根据响应头中的积分信息更新账号
|
| 616 |
+
// 如果响应头中有积分信息,使用实际值;否则使用模型倍率
|
| 617 |
+
func UpdateAccountCreditsFromResponse(account *model.Account, resp *http.Response, modelMultiplier float64) {
|
| 618 |
+
// 无论如何都要更新最后使用时间
|
| 619 |
+
account.LastUsed = time.Now()
|
| 620 |
+
|
| 621 |
+
if resp == nil || resp.Header == nil {
|
| 622 |
+
// 如果没有响应头,使用模型倍率
|
| 623 |
+
UseCredit(account, modelMultiplier)
|
| 624 |
+
return
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
// 获取响应头中的积分信息
|
| 628 |
+
periodLimit := resp.Header.Get("Zen-Pricing-Period-Limit")
|
| 629 |
+
periodCost := resp.Header.Get("Zen-Pricing-Period-Cost")
|
| 630 |
+
requestCost := resp.Header.Get("Zen-Request-Cost")
|
| 631 |
+
periodEnd := resp.Header.Get("Zen-Pricing-Period-End")
|
| 632 |
+
|
| 633 |
+
// 解析本次请求消耗的积分
|
| 634 |
+
var creditUsed float64
|
| 635 |
+
hasAPICredits := false
|
| 636 |
+
|
| 637 |
+
if requestCost != "" {
|
| 638 |
+
if val := parseFloat(requestCost); val > 0 {
|
| 639 |
+
creditUsed = val
|
| 640 |
+
hasAPICredits = true
|
| 641 |
+
}
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
// 如果有 periodCost,更新账号的总使用量(当日总计)
|
| 645 |
+
if periodCost != "" {
|
| 646 |
+
if val := parseFloat(periodCost); val >= 0 {
|
| 647 |
+
// 直接使用API返回的当日使用量
|
| 648 |
+
account.DailyUsed = val
|
| 649 |
+
hasAPICredits = true
|
| 650 |
+
}
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
// 如果有 periodLimit,可以用于验证账号计划类型
|
| 654 |
+
if periodLimit != "" {
|
| 655 |
+
if limit := parseFloat(periodLimit); limit > 0 {
|
| 656 |
+
// 可选:验证或更新账号的计划类型
|
| 657 |
+
// 这里只记录日志,不改变计划类型
|
| 658 |
+
expectedLimit := float64(model.PlanLimits[account.PlanType])
|
| 659 |
+
if limit != expectedLimit && IsDebugMode() {
|
| 660 |
+
log.Printf("[INFO] 账号 %s (ID:%d) API限额(%v)与本地限额(%v)不一致",
|
| 661 |
+
account.Email, account.ID, limit, expectedLimit)
|
| 662 |
+
}
|
| 663 |
+
}
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
// 解析冷却到期时间(UTC时间)和积分刷新时间
|
| 667 |
+
var coolingEndTime time.Time
|
| 668 |
+
if periodEnd != "" {
|
| 669 |
+
if t, err := time.Parse(time.RFC3339, periodEnd); err == nil {
|
| 670 |
+
coolingEndTime = t
|
| 671 |
+
// 同时更新积分刷新时间
|
| 672 |
+
account.CreditRefreshTime = t
|
| 673 |
+
} else {
|
| 674 |
+
// 如果解析失败,记录日志
|
| 675 |
+
log.Printf("[WARN] 无法解析 Zen-Pricing-Period-End: %s, error: %v", periodEnd, err)
|
| 676 |
+
}
|
| 677 |
+
}
|
| 678 |
+
|
| 679 |
+
if hasAPICredits {
|
| 680 |
+
// 使用API返回的积分值
|
| 681 |
+
if requestCost != "" && creditUsed > 0 {
|
| 682 |
+
account.TotalUsed += creditUsed
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
// 检查是否需要冷却
|
| 686 |
+
limit := float64(model.PlanLimits[account.PlanType])
|
| 687 |
+
if account.DailyUsed >= limit {
|
| 688 |
+
account.IsCooling = true
|
| 689 |
+
account.Status = "cooling"
|
| 690 |
+
account.Category = "cooling"
|
| 691 |
+
account.BanReason = "Daily quota exceeded"
|
| 692 |
+
|
| 693 |
+
// 如果有响应头中的冷却到期时间,使用它;否则使用默认时间
|
| 694 |
+
if !coolingEndTime.IsZero() {
|
| 695 |
+
account.CoolingUntil = coolingEndTime
|
| 696 |
+
log.Printf("[INFO] 账号 %s (ID:%d) 积分耗尽,进入冷却,到期时间: %s (UTC)",
|
| 697 |
+
account.Email, account.ID, coolingEndTime.Format("2006-01-02 15:04:05"))
|
| 698 |
+
} else {
|
| 699 |
+
// 默认冷却到第二天的 UTC 0点
|
| 700 |
+
now := time.Now().UTC()
|
| 701 |
+
tomorrow := now.Add(24 * time.Hour)
|
| 702 |
+
account.CoolingUntil = time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, time.UTC)
|
| 703 |
+
log.Printf("[INFO] 账号 %s (ID:%d) 积分耗尽,进入冷却至: %s (UTC)",
|
| 704 |
+
account.Email, account.ID, account.CoolingUntil.Format("2006-01-02 15:04:05"))
|
| 705 |
+
}
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
database.GetDB().Save(account)
|
| 709 |
+
|
| 710 |
+
// 输出调试日志(仅在调试模式下)
|
| 711 |
+
if IsDebugMode() && (requestCost != "" || periodCost != "") {
|
| 712 |
+
log.Printf("[DEBUG] 使用API积分: 账号=%s, RequestCost=%s, PeriodCost=%s, PeriodLimit=%s, PeriodEnd=%s",
|
| 713 |
+
account.Email, requestCost, periodCost, periodLimit, periodEnd)
|
| 714 |
+
}
|
| 715 |
+
} else {
|
| 716 |
+
// 没有API积分信息,使用模型倍率(UseCredit 会自动更新 LastUsed)
|
| 717 |
+
UseCredit(account, modelMultiplier)
|
| 718 |
+
}
|
| 719 |
+
}
|
| 720 |
+
|
| 721 |
+
// parseFloat 安全地解析字符串为浮点数
|
| 722 |
+
func parseFloat(s string) float64 {
|
| 723 |
+
if s == "" {
|
| 724 |
+
return 0
|
| 725 |
+
}
|
| 726 |
+
var val float64
|
| 727 |
+
_, _ = fmt.Sscanf(s, "%f", &val)
|
| 728 |
+
return val
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
// refreshRefreshTokenAccount 使用 refresh_token 刷新账号 token (用于 refresh-token-login 类型的账号)
|
| 732 |
+
func (p *AccountPool) refreshRefreshTokenAccount(account *model.Account) error {
|
| 733 |
+
if account.RefreshToken == "" {
|
| 734 |
+
return fmt.Errorf("账号 %s 缺少 refresh_token", account.ClientID)
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
// 调用 zencoder auth API 刷新 token
|
| 738 |
+
tokenResp, err := RefreshAccessToken(account.RefreshToken, account.Proxy)
|
| 739 |
+
if err != nil {
|
| 740 |
+
return fmt.Errorf("调用 zencoder auth API 失败: %w", err)
|
| 741 |
+
}
|
| 742 |
+
|
| 743 |
+
// 计算过期时间
|
| 744 |
+
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 745 |
+
|
| 746 |
+
// 更新数据库
|
| 747 |
+
updates := map[string]interface{}{
|
| 748 |
+
"access_token": tokenResp.AccessToken,
|
| 749 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 750 |
+
"token_expiry": expiry,
|
| 751 |
+
"updated_at": time.Now(),
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
if err := database.GetDB().Model(&model.Account{}).
|
| 755 |
+
Where("id = ?", account.ID).
|
| 756 |
+
Updates(updates).Error; err != nil {
|
| 757 |
+
return fmt.Errorf("更新数据库失败: %w", err)
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
// 更新内存中的值
|
| 761 |
+
account.AccessToken = tokenResp.AccessToken
|
| 762 |
+
account.RefreshToken = tokenResp.RefreshToken
|
| 763 |
+
account.TokenExpiry = expiry
|
| 764 |
+
|
| 765 |
+
return nil
|
| 766 |
+
}
|
internal/service/provider/anthropic.go
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"net/http"
|
| 8 |
+
|
| 9 |
+
"github.com/anthropics/anthropic-sdk-go"
|
| 10 |
+
"github.com/anthropics/anthropic-sdk-go/option"
|
| 11 |
+
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
|
| 12 |
+
"zencoder-2api/internal/model"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
const DefaultAnthropicBaseURL = "https://api.anthropic.com"
|
| 16 |
+
|
| 17 |
+
type AnthropicProvider struct {
|
| 18 |
+
client *anthropic.Client
|
| 19 |
+
config Config
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func NewAnthropicProvider(cfg Config) *AnthropicProvider {
|
| 23 |
+
if cfg.BaseURL == "" {
|
| 24 |
+
cfg.BaseURL = DefaultAnthropicBaseURL
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
opts := []option.RequestOption{
|
| 28 |
+
option.WithAPIKey(cfg.APIKey),
|
| 29 |
+
option.WithBaseURL(cfg.BaseURL),
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
for k, v := range cfg.ExtraHeaders {
|
| 33 |
+
opts = append(opts, option.WithHeader(k, v))
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
if cfg.Proxy != "" {
|
| 37 |
+
httpClient := NewHTTPClient(cfg.Proxy, 0)
|
| 38 |
+
opts = append(opts, option.WithHTTPClient(httpClient))
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
return &AnthropicProvider{
|
| 42 |
+
client: anthropic.NewClient(opts...),
|
| 43 |
+
config: cfg,
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func (p *AnthropicProvider) Name() string {
|
| 48 |
+
return "anthropic"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
func (p *AnthropicProvider) ValidateToken() error {
|
| 52 |
+
_, err := p.client.Messages.New(context.Background(), anthropic.MessageNewParams{
|
| 53 |
+
Model: anthropic.F(anthropic.ModelClaude3_5SonnetLatest),
|
| 54 |
+
MaxTokens: anthropic.F(int64(1)),
|
| 55 |
+
Messages: anthropic.F([]anthropic.MessageParam{{Role: anthropic.F(anthropic.MessageParamRoleUser), Content: anthropic.F([]anthropic.ContentBlockParamUnion{anthropic.NewTextBlock("hi")})}}),
|
| 56 |
+
})
|
| 57 |
+
return err
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
func (p *AnthropicProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 61 |
+
messages := p.convertMessages(req.Messages)
|
| 62 |
+
|
| 63 |
+
maxTokens := int64(4096)
|
| 64 |
+
if req.MaxTokens > 0 {
|
| 65 |
+
maxTokens = int64(req.MaxTokens)
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
resp, err := p.client.Messages.New(context.Background(), anthropic.MessageNewParams{
|
| 69 |
+
Model: anthropic.F(req.Model),
|
| 70 |
+
MaxTokens: anthropic.F(maxTokens),
|
| 71 |
+
Messages: anthropic.F(messages),
|
| 72 |
+
})
|
| 73 |
+
if err != nil {
|
| 74 |
+
return nil, err
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
return p.convertResponse(resp), nil
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
func (p *AnthropicProvider) convertMessages(msgs []model.ChatMessage) []anthropic.MessageParam {
|
| 81 |
+
var messages []anthropic.MessageParam
|
| 82 |
+
for _, msg := range msgs {
|
| 83 |
+
if msg.Role == "system" {
|
| 84 |
+
continue
|
| 85 |
+
}
|
| 86 |
+
role := anthropic.MessageParamRoleUser
|
| 87 |
+
if msg.Role == "assistant" {
|
| 88 |
+
role = anthropic.MessageParamRoleAssistant
|
| 89 |
+
}
|
| 90 |
+
messages = append(messages, anthropic.MessageParam{
|
| 91 |
+
Role: anthropic.F(role),
|
| 92 |
+
Content: anthropic.F([]anthropic.ContentBlockParamUnion{anthropic.NewTextBlock(msg.Content)}),
|
| 93 |
+
})
|
| 94 |
+
}
|
| 95 |
+
return messages
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
func (p *AnthropicProvider) convertResponse(resp *anthropic.Message) *model.ChatCompletionResponse {
|
| 99 |
+
content := ""
|
| 100 |
+
for _, block := range resp.Content {
|
| 101 |
+
if block.Type == anthropic.ContentBlockTypeText {
|
| 102 |
+
content += block.Text
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
return &model.ChatCompletionResponse{
|
| 107 |
+
ID: resp.ID,
|
| 108 |
+
Object: "chat.completion",
|
| 109 |
+
Model: string(resp.Model),
|
| 110 |
+
Choices: []model.Choice{{
|
| 111 |
+
Index: 0,
|
| 112 |
+
Message: model.ChatMessage{Role: "assistant", Content: content},
|
| 113 |
+
FinishReason: string(resp.StopReason),
|
| 114 |
+
}},
|
| 115 |
+
Usage: model.Usage{
|
| 116 |
+
PromptTokens: int(resp.Usage.InputTokens),
|
| 117 |
+
CompletionTokens: int(resp.Usage.OutputTokens),
|
| 118 |
+
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
| 119 |
+
},
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
func (p *AnthropicProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 124 |
+
messages := p.convertMessages(req.Messages)
|
| 125 |
+
|
| 126 |
+
maxTokens := int64(4096)
|
| 127 |
+
if req.MaxTokens > 0 {
|
| 128 |
+
maxTokens = int64(req.MaxTokens)
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
stream := p.client.Messages.NewStreaming(context.Background(), anthropic.MessageNewParams{
|
| 132 |
+
Model: anthropic.F(req.Model),
|
| 133 |
+
MaxTokens: anthropic.F(maxTokens),
|
| 134 |
+
Messages: anthropic.F(messages),
|
| 135 |
+
})
|
| 136 |
+
|
| 137 |
+
return p.handleStream(stream, writer)
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
func (p *AnthropicProvider) handleStream(stream *ssestream.Stream[anthropic.MessageStreamEvent], writer http.ResponseWriter) error {
|
| 141 |
+
flusher, ok := writer.(http.Flusher)
|
| 142 |
+
if !ok {
|
| 143 |
+
return fmt.Errorf("streaming not supported")
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 147 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 148 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 149 |
+
|
| 150 |
+
for stream.Next() {
|
| 151 |
+
event := stream.Current()
|
| 152 |
+
data, _ := json.Marshal(event)
|
| 153 |
+
fmt.Fprintf(writer, "data: %s\n\n", data)
|
| 154 |
+
flusher.Flush()
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
if err := stream.Err(); err != nil {
|
| 158 |
+
return err
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
fmt.Fprintf(writer, "data: [DONE]\n\n")
|
| 162 |
+
flusher.Flush()
|
| 163 |
+
return nil
|
| 164 |
+
}
|
internal/service/provider/client.go
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
"net/http"
|
| 6 |
+
"net/url"
|
| 7 |
+
"strings"
|
| 8 |
+
"time"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
// NewHTTPClient 创建HTTP客户端
|
| 12 |
+
// 支持HTTP和SOCKS5代理
|
| 13 |
+
func NewHTTPClient(proxy string, timeout time.Duration) *http.Client {
|
| 14 |
+
// 如果代理是SOCKS5格式,使用新的代理客户端创建函数
|
| 15 |
+
if strings.HasPrefix(proxy, "socks5://") {
|
| 16 |
+
client, err := NewHTTPClientWithProxy(proxy, timeout)
|
| 17 |
+
if err != nil {
|
| 18 |
+
log.Printf("创建SOCKS5代理客户端失败: %v, 使用默认客户端", err)
|
| 19 |
+
client, _ := NewHTTPClientWithProxy("", timeout)
|
| 20 |
+
return client
|
| 21 |
+
}
|
| 22 |
+
return client
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// 原有的HTTP代理逻辑
|
| 26 |
+
transport := &http.Transport{}
|
| 27 |
+
|
| 28 |
+
if proxy != "" {
|
| 29 |
+
if proxyURL, err := url.Parse(proxy); err == nil {
|
| 30 |
+
transport.Proxy = http.ProxyURL(proxyURL)
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
if timeout == 0 {
|
| 35 |
+
timeout = 600 * time.Second // 10分钟超时,支持长时间流式响应
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
return &http.Client{
|
| 39 |
+
Transport: transport,
|
| 40 |
+
Timeout: timeout,
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// NewHTTPClientWithPoolProxy 使用代理池创建HTTP客户端
|
| 45 |
+
func NewHTTPClientWithPoolProxy(useProxy bool, timeout time.Duration) *http.Client {
|
| 46 |
+
if !useProxy {
|
| 47 |
+
// 不使用代理
|
| 48 |
+
client, _ := NewHTTPClientWithProxy("", timeout)
|
| 49 |
+
return client
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
pool := GetProxyPool()
|
| 53 |
+
if !pool.HasProxies() {
|
| 54 |
+
// 没有可用代理,使用默认客户端
|
| 55 |
+
client, _ := NewHTTPClientWithProxy("", timeout)
|
| 56 |
+
return client
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// 获取下一个代理
|
| 60 |
+
proxyURL := pool.GetNextProxy()
|
| 61 |
+
client, err := NewHTTPClientWithProxy(proxyURL, timeout)
|
| 62 |
+
if err != nil {
|
| 63 |
+
log.Printf("使用代理 %s 创建客户端失败: %v, 使用默认客户端", proxyURL, err)
|
| 64 |
+
client, _ := NewHTTPClientWithProxy("", timeout)
|
| 65 |
+
return client
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
return client
|
| 69 |
+
}
|
internal/service/provider/errors.go
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import "errors"
|
| 4 |
+
|
| 5 |
+
var (
|
| 6 |
+
ErrStreamNotSupported = errors.New("streaming not supported")
|
| 7 |
+
ErrInvalidToken = errors.New("invalid token")
|
| 8 |
+
ErrRequestFailed = errors.New("request failed")
|
| 9 |
+
ErrUnknownProvider = errors.New("unknown provider")
|
| 10 |
+
)
|
internal/service/provider/factory.go
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import "fmt"
|
| 4 |
+
|
| 5 |
+
type ProviderType string
|
| 6 |
+
|
| 7 |
+
const (
|
| 8 |
+
ProviderOpenAI ProviderType = "openai"
|
| 9 |
+
ProviderAnthropic ProviderType = "anthropic"
|
| 10 |
+
ProviderGemini ProviderType = "gemini"
|
| 11 |
+
ProviderGrok ProviderType = "grok"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
func NewProvider(providerType ProviderType, cfg Config) (Provider, error) {
|
| 15 |
+
switch providerType {
|
| 16 |
+
case ProviderOpenAI:
|
| 17 |
+
return NewOpenAIProvider(cfg), nil
|
| 18 |
+
case ProviderAnthropic:
|
| 19 |
+
return NewAnthropicProvider(cfg), nil
|
| 20 |
+
case ProviderGemini:
|
| 21 |
+
return NewGeminiProvider(cfg)
|
| 22 |
+
case ProviderGrok:
|
| 23 |
+
return NewGrokProvider(cfg), nil
|
| 24 |
+
default:
|
| 25 |
+
return nil, fmt.Errorf("%w: %s", ErrUnknownProvider, providerType)
|
| 26 |
+
}
|
| 27 |
+
}
|
internal/service/provider/gemini.go
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"net/http"
|
| 8 |
+
|
| 9 |
+
"github.com/google/generative-ai-go/genai"
|
| 10 |
+
"google.golang.org/api/option"
|
| 11 |
+
"zencoder-2api/internal/model"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
type GeminiProvider struct {
|
| 15 |
+
client *genai.Client
|
| 16 |
+
config Config
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
func NewGeminiProvider(cfg Config) (*GeminiProvider, error) {
|
| 20 |
+
ctx := context.Background()
|
| 21 |
+
|
| 22 |
+
opts := []option.ClientOption{
|
| 23 |
+
option.WithAPIKey(cfg.APIKey),
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
if cfg.BaseURL != "" {
|
| 27 |
+
opts = append(opts, option.WithEndpoint(cfg.BaseURL))
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
client, err := genai.NewClient(ctx, opts...)
|
| 31 |
+
if err != nil {
|
| 32 |
+
return nil, err
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
return &GeminiProvider{
|
| 36 |
+
client: client,
|
| 37 |
+
config: cfg,
|
| 38 |
+
}, nil
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
func (p *GeminiProvider) Name() string {
|
| 42 |
+
return "gemini"
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func (p *GeminiProvider) ValidateToken() error {
|
| 46 |
+
model := p.client.GenerativeModel("gemini-1.5-flash")
|
| 47 |
+
_, err := model.GenerateContent(context.Background(), genai.Text("hi"))
|
| 48 |
+
return err
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
func (p *GeminiProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 52 |
+
geminiModel := p.client.GenerativeModel(req.Model)
|
| 53 |
+
|
| 54 |
+
var parts []genai.Part
|
| 55 |
+
for _, msg := range req.Messages {
|
| 56 |
+
parts = append(parts, genai.Text(msg.Content))
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
resp, err := geminiModel.GenerateContent(context.Background(), parts...)
|
| 60 |
+
if err != nil {
|
| 61 |
+
return nil, err
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
return p.convertResponse(resp, req.Model), nil
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
func (p *GeminiProvider) convertResponse(resp *genai.GenerateContentResponse, modelName string) *model.ChatCompletionResponse {
|
| 68 |
+
content := ""
|
| 69 |
+
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
|
| 70 |
+
for _, part := range resp.Candidates[0].Content.Parts {
|
| 71 |
+
if text, ok := part.(genai.Text); ok {
|
| 72 |
+
content += string(text)
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
return &model.ChatCompletionResponse{
|
| 78 |
+
ID: "gemini-" + modelName,
|
| 79 |
+
Object: "chat.completion",
|
| 80 |
+
Model: modelName,
|
| 81 |
+
Choices: []model.Choice{{
|
| 82 |
+
Index: 0,
|
| 83 |
+
Message: model.ChatMessage{Role: "assistant", Content: content},
|
| 84 |
+
FinishReason: "stop",
|
| 85 |
+
}},
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
func (p *GeminiProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 90 |
+
geminiModel := p.client.GenerativeModel(req.Model)
|
| 91 |
+
|
| 92 |
+
var parts []genai.Part
|
| 93 |
+
for _, msg := range req.Messages {
|
| 94 |
+
parts = append(parts, genai.Text(msg.Content))
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
iter := geminiModel.GenerateContentStream(context.Background(), parts...)
|
| 98 |
+
return p.handleStream(iter, writer)
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
func (p *GeminiProvider) handleStream(iter *genai.GenerateContentResponseIterator, writer http.ResponseWriter) error {
|
| 102 |
+
flusher, ok := writer.(http.Flusher)
|
| 103 |
+
if !ok {
|
| 104 |
+
return fmt.Errorf("streaming not supported")
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 108 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 109 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 110 |
+
|
| 111 |
+
for {
|
| 112 |
+
resp, err := iter.Next()
|
| 113 |
+
if err != nil {
|
| 114 |
+
break
|
| 115 |
+
}
|
| 116 |
+
data, _ := json.Marshal(resp)
|
| 117 |
+
fmt.Fprintf(writer, "data: %s\n\n", data)
|
| 118 |
+
flusher.Flush()
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
fmt.Fprintf(writer, "data: [DONE]\n\n")
|
| 122 |
+
flusher.Flush()
|
| 123 |
+
return nil
|
| 124 |
+
}
|
internal/service/provider/grok.go
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"net/http"
|
| 8 |
+
|
| 9 |
+
"github.com/openai/openai-go"
|
| 10 |
+
"github.com/openai/openai-go/option"
|
| 11 |
+
"github.com/openai/openai-go/packages/ssestream"
|
| 12 |
+
"zencoder-2api/internal/model"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
const DefaultGrokBaseURL = "https://api.x.ai/v1"
|
| 16 |
+
|
| 17 |
+
type GrokProvider struct {
|
| 18 |
+
client *openai.Client
|
| 19 |
+
config Config
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func NewGrokProvider(cfg Config) *GrokProvider {
|
| 23 |
+
if cfg.BaseURL == "" {
|
| 24 |
+
cfg.BaseURL = DefaultGrokBaseURL
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
opts := []option.RequestOption{
|
| 28 |
+
option.WithAPIKey(cfg.APIKey),
|
| 29 |
+
option.WithBaseURL(cfg.BaseURL),
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
for k, v := range cfg.ExtraHeaders {
|
| 33 |
+
opts = append(opts, option.WithHeader(k, v))
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
if cfg.Proxy != "" {
|
| 37 |
+
httpClient := NewHTTPClient(cfg.Proxy, 0)
|
| 38 |
+
opts = append(opts, option.WithHTTPClient(httpClient))
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
return &GrokProvider{
|
| 42 |
+
client: openai.NewClient(opts...),
|
| 43 |
+
config: cfg,
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func (p *GrokProvider) Name() string {
|
| 48 |
+
return "grok"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
func (p *GrokProvider) ValidateToken() error {
|
| 52 |
+
_, err := p.client.Models.List(context.Background())
|
| 53 |
+
return err
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func (p *GrokProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 57 |
+
messages := make([]openai.ChatCompletionMessageParamUnion, len(req.Messages))
|
| 58 |
+
for i, msg := range req.Messages {
|
| 59 |
+
switch msg.Role {
|
| 60 |
+
case "system":
|
| 61 |
+
messages[i] = openai.SystemMessage(msg.Content)
|
| 62 |
+
case "user":
|
| 63 |
+
messages[i] = openai.UserMessage(msg.Content)
|
| 64 |
+
case "assistant":
|
| 65 |
+
messages[i] = openai.AssistantMessage(msg.Content)
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
resp, err := p.client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
|
| 70 |
+
Model: openai.F(req.Model),
|
| 71 |
+
Messages: openai.F(messages),
|
| 72 |
+
})
|
| 73 |
+
if err != nil {
|
| 74 |
+
return nil, err
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
return p.convertResponse(resp), nil
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
func (p *GrokProvider) convertResponse(resp *openai.ChatCompletion) *model.ChatCompletionResponse {
|
| 81 |
+
choices := make([]model.Choice, len(resp.Choices))
|
| 82 |
+
for i, c := range resp.Choices {
|
| 83 |
+
choices[i] = model.Choice{
|
| 84 |
+
Index: int(c.Index),
|
| 85 |
+
Message: model.ChatMessage{Role: string(c.Message.Role), Content: c.Message.Content},
|
| 86 |
+
FinishReason: string(c.FinishReason),
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
return &model.ChatCompletionResponse{
|
| 91 |
+
ID: resp.ID,
|
| 92 |
+
Object: string(resp.Object),
|
| 93 |
+
Created: resp.Created,
|
| 94 |
+
Model: resp.Model,
|
| 95 |
+
Choices: choices,
|
| 96 |
+
Usage: model.Usage{
|
| 97 |
+
PromptTokens: int(resp.Usage.PromptTokens),
|
| 98 |
+
CompletionTokens: int(resp.Usage.CompletionTokens),
|
| 99 |
+
TotalTokens: int(resp.Usage.TotalTokens),
|
| 100 |
+
},
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
func (p *GrokProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 105 |
+
messages := make([]openai.ChatCompletionMessageParamUnion, len(req.Messages))
|
| 106 |
+
for i, msg := range req.Messages {
|
| 107 |
+
switch msg.Role {
|
| 108 |
+
case "system":
|
| 109 |
+
messages[i] = openai.SystemMessage(msg.Content)
|
| 110 |
+
case "user":
|
| 111 |
+
messages[i] = openai.UserMessage(msg.Content)
|
| 112 |
+
case "assistant":
|
| 113 |
+
messages[i] = openai.AssistantMessage(msg.Content)
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
stream := p.client.Chat.Completions.NewStreaming(context.Background(), openai.ChatCompletionNewParams{
|
| 118 |
+
Model: openai.F(req.Model),
|
| 119 |
+
Messages: openai.F(messages),
|
| 120 |
+
})
|
| 121 |
+
|
| 122 |
+
return p.handleStream(stream, writer)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
func (p *GrokProvider) handleStream(stream *ssestream.Stream[openai.ChatCompletionChunk], writer http.ResponseWriter) error {
|
| 126 |
+
flusher, ok := writer.(http.Flusher)
|
| 127 |
+
if !ok {
|
| 128 |
+
return fmt.Errorf("streaming not supported")
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 132 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 133 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 134 |
+
|
| 135 |
+
for stream.Next() {
|
| 136 |
+
chunk := stream.Current()
|
| 137 |
+
data, _ := json.Marshal(chunk)
|
| 138 |
+
fmt.Fprintf(writer, "data: %s\n\n", data)
|
| 139 |
+
flusher.Flush()
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
if err := stream.Err(); err != nil {
|
| 143 |
+
return err
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
fmt.Fprintf(writer, "data: [DONE]\n\n")
|
| 147 |
+
flusher.Flush()
|
| 148 |
+
return nil
|
| 149 |
+
}
|
internal/service/provider/manager.go
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"sync"
|
| 6 |
+
|
| 7 |
+
"zencoder-2api/internal/model"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
// Manager Provider管理器,缓存已创建的provider实例
|
| 11 |
+
type Manager struct {
|
| 12 |
+
mu sync.RWMutex
|
| 13 |
+
providers map[string]Provider
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
var defaultManager = &Manager{
|
| 17 |
+
providers: make(map[string]Provider),
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// GetManager 获取默认管理器
|
| 21 |
+
func GetManager() *Manager {
|
| 22 |
+
return defaultManager
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// GetProvider 根据账号和模型获取或创建Provider
|
| 26 |
+
func (m *Manager) GetProvider(accountID uint, zenModel model.ZenModel, cfg Config) (Provider, error) {
|
| 27 |
+
key := m.buildKey(accountID, zenModel.ProviderID)
|
| 28 |
+
|
| 29 |
+
m.mu.RLock()
|
| 30 |
+
if p, ok := m.providers[key]; ok {
|
| 31 |
+
m.mu.RUnlock()
|
| 32 |
+
return p, nil
|
| 33 |
+
}
|
| 34 |
+
m.mu.RUnlock()
|
| 35 |
+
|
| 36 |
+
return m.createProvider(key, zenModel.ProviderID, cfg)
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
func (m *Manager) buildKey(accountID uint, providerID string) string {
|
| 40 |
+
return fmt.Sprintf("%d:%s", accountID, providerID)
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
func (m *Manager) createProvider(key, providerID string, cfg Config) (Provider, error) {
|
| 44 |
+
m.mu.Lock()
|
| 45 |
+
defer m.mu.Unlock()
|
| 46 |
+
|
| 47 |
+
// 双重检查
|
| 48 |
+
if p, ok := m.providers[key]; ok {
|
| 49 |
+
return p, nil
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
var providerType ProviderType
|
| 53 |
+
switch providerID {
|
| 54 |
+
case "openai":
|
| 55 |
+
providerType = ProviderOpenAI
|
| 56 |
+
case "anthropic":
|
| 57 |
+
providerType = ProviderAnthropic
|
| 58 |
+
case "gemini":
|
| 59 |
+
providerType = ProviderGemini
|
| 60 |
+
case "xai":
|
| 61 |
+
providerType = ProviderGrok
|
| 62 |
+
default:
|
| 63 |
+
providerType = ProviderAnthropic
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
p, err := NewProvider(providerType, cfg)
|
| 67 |
+
if err != nil {
|
| 68 |
+
return nil, err
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
m.providers[key] = p
|
| 72 |
+
return p, nil
|
| 73 |
+
}
|
internal/service/provider/openai.go
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"net/http"
|
| 8 |
+
|
| 9 |
+
"github.com/openai/openai-go"
|
| 10 |
+
"github.com/openai/openai-go/option"
|
| 11 |
+
"github.com/openai/openai-go/packages/ssestream"
|
| 12 |
+
"zencoder-2api/internal/model"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
const DefaultOpenAIBaseURL = "https://api.openai.com/v1"
|
| 16 |
+
|
| 17 |
+
type OpenAIProvider struct {
|
| 18 |
+
client *openai.Client
|
| 19 |
+
config Config
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func NewOpenAIProvider(cfg Config) *OpenAIProvider {
|
| 23 |
+
if cfg.BaseURL == "" {
|
| 24 |
+
cfg.BaseURL = DefaultOpenAIBaseURL
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
opts := []option.RequestOption{
|
| 28 |
+
option.WithAPIKey(cfg.APIKey),
|
| 29 |
+
option.WithBaseURL(cfg.BaseURL),
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
for k, v := range cfg.ExtraHeaders {
|
| 33 |
+
opts = append(opts, option.WithHeader(k, v))
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
if cfg.Proxy != "" {
|
| 37 |
+
httpClient := NewHTTPClient(cfg.Proxy, 0)
|
| 38 |
+
opts = append(opts, option.WithHTTPClient(httpClient))
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
return &OpenAIProvider{
|
| 42 |
+
client: openai.NewClient(opts...),
|
| 43 |
+
config: cfg,
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func (p *OpenAIProvider) Name() string {
|
| 48 |
+
return "openai"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
func (p *OpenAIProvider) ValidateToken() error {
|
| 52 |
+
_, err := p.client.Models.List(context.Background())
|
| 53 |
+
return err
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func (p *OpenAIProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 57 |
+
messages := make([]openai.ChatCompletionMessageParamUnion, len(req.Messages))
|
| 58 |
+
for i, msg := range req.Messages {
|
| 59 |
+
switch msg.Role {
|
| 60 |
+
case "system":
|
| 61 |
+
messages[i] = openai.SystemMessage(msg.Content)
|
| 62 |
+
case "user":
|
| 63 |
+
messages[i] = openai.UserMessage(msg.Content)
|
| 64 |
+
case "assistant":
|
| 65 |
+
messages[i] = openai.AssistantMessage(msg.Content)
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
resp, err := p.client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
|
| 70 |
+
Model: openai.F(req.Model),
|
| 71 |
+
Messages: openai.F(messages),
|
| 72 |
+
})
|
| 73 |
+
if err != nil {
|
| 74 |
+
return nil, err
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
return p.convertResponse(resp), nil
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
func (p *OpenAIProvider) convertResponse(resp *openai.ChatCompletion) *model.ChatCompletionResponse {
|
| 81 |
+
choices := make([]model.Choice, len(resp.Choices))
|
| 82 |
+
for i, c := range resp.Choices {
|
| 83 |
+
choices[i] = model.Choice{
|
| 84 |
+
Index: int(c.Index),
|
| 85 |
+
Message: model.ChatMessage{Role: string(c.Message.Role), Content: c.Message.Content},
|
| 86 |
+
FinishReason: string(c.FinishReason),
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
return &model.ChatCompletionResponse{
|
| 91 |
+
ID: resp.ID,
|
| 92 |
+
Object: string(resp.Object),
|
| 93 |
+
Created: resp.Created,
|
| 94 |
+
Model: resp.Model,
|
| 95 |
+
Choices: choices,
|
| 96 |
+
Usage: model.Usage{
|
| 97 |
+
PromptTokens: int(resp.Usage.PromptTokens),
|
| 98 |
+
CompletionTokens: int(resp.Usage.CompletionTokens),
|
| 99 |
+
TotalTokens: int(resp.Usage.TotalTokens),
|
| 100 |
+
},
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
func (p *OpenAIProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 105 |
+
messages := make([]openai.ChatCompletionMessageParamUnion, len(req.Messages))
|
| 106 |
+
for i, msg := range req.Messages {
|
| 107 |
+
switch msg.Role {
|
| 108 |
+
case "system":
|
| 109 |
+
messages[i] = openai.SystemMessage(msg.Content)
|
| 110 |
+
case "user":
|
| 111 |
+
messages[i] = openai.UserMessage(msg.Content)
|
| 112 |
+
case "assistant":
|
| 113 |
+
messages[i] = openai.AssistantMessage(msg.Content)
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
stream := p.client.Chat.Completions.NewStreaming(context.Background(), openai.ChatCompletionNewParams{
|
| 118 |
+
Model: openai.F(req.Model),
|
| 119 |
+
Messages: openai.F(messages),
|
| 120 |
+
})
|
| 121 |
+
|
| 122 |
+
return p.handleStream(stream, writer)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
func (p *OpenAIProvider) handleStream(stream *ssestream.Stream[openai.ChatCompletionChunk], writer http.ResponseWriter) error {
|
| 126 |
+
flusher, ok := writer.(http.Flusher)
|
| 127 |
+
if !ok {
|
| 128 |
+
return fmt.Errorf("streaming not supported")
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 132 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 133 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 134 |
+
|
| 135 |
+
for stream.Next() {
|
| 136 |
+
chunk := stream.Current()
|
| 137 |
+
data, _ := json.Marshal(chunk)
|
| 138 |
+
fmt.Fprintf(writer, "data: %s\n\n", data)
|
| 139 |
+
flusher.Flush()
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
if err := stream.Err(); err != nil {
|
| 143 |
+
return err
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
fmt.Fprintf(writer, "data: [DONE]\n\n")
|
| 147 |
+
flusher.Flush()
|
| 148 |
+
return nil
|
| 149 |
+
}
|
internal/service/provider/provider.go
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"io"
|
| 5 |
+
"net/http"
|
| 6 |
+
|
| 7 |
+
"zencoder-2api/internal/model"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
// Config Provider配置
|
| 11 |
+
type Config struct {
|
| 12 |
+
BaseURL string // 自定义请求地址
|
| 13 |
+
APIKey string // API密钥
|
| 14 |
+
ExtraHeaders map[string]string // 额外请求头
|
| 15 |
+
Proxy string // 代理地址
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
// Provider AI平台提供者接口
|
| 19 |
+
type Provider interface {
|
| 20 |
+
// Name 返回提供者名称
|
| 21 |
+
Name() string
|
| 22 |
+
|
| 23 |
+
// Chat 非流式聊天
|
| 24 |
+
Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error)
|
| 25 |
+
|
| 26 |
+
// ChatStream 流式聊天
|
| 27 |
+
ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error
|
| 28 |
+
|
| 29 |
+
// ValidateToken 验证token是否有效
|
| 30 |
+
ValidateToken() error
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
// BaseProvider 基础Provider实现
|
| 34 |
+
type BaseProvider struct {
|
| 35 |
+
Config Config
|
| 36 |
+
Client *http.Client
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// SetHeaders 设置通用请求头
|
| 40 |
+
func (b *BaseProvider) SetHeaders(req *http.Request) {
|
| 41 |
+
req.Header.Set("Content-Type", "application/json")
|
| 42 |
+
req.Header.Set("Accept", "application/json")
|
| 43 |
+
|
| 44 |
+
// 设置额外请求头
|
| 45 |
+
for k, v := range b.Config.ExtraHeaders {
|
| 46 |
+
req.Header.Set(k, v)
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// StreamResponse 通用流式响应处理
|
| 51 |
+
func (b *BaseProvider) StreamResponse(body io.Reader, writer http.ResponseWriter) error {
|
| 52 |
+
flusher, ok := writer.(http.Flusher)
|
| 53 |
+
if !ok {
|
| 54 |
+
return ErrStreamNotSupported
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 58 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 59 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 60 |
+
|
| 61 |
+
buf := make([]byte, 4096)
|
| 62 |
+
for {
|
| 63 |
+
n, err := body.Read(buf)
|
| 64 |
+
if n > 0 {
|
| 65 |
+
writer.Write(buf[:n])
|
| 66 |
+
flusher.Flush()
|
| 67 |
+
}
|
| 68 |
+
if err == io.EOF {
|
| 69 |
+
break
|
| 70 |
+
}
|
| 71 |
+
if err != nil {
|
| 72 |
+
return err
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
return nil
|
| 77 |
+
}
|
internal/service/provider/proxy.go
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"math/rand"
|
| 7 |
+
"net"
|
| 8 |
+
"net/http"
|
| 9 |
+
"net/url"
|
| 10 |
+
"os"
|
| 11 |
+
"strings"
|
| 12 |
+
"sync"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"golang.org/x/net/proxy"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// ProxyPool 代理池管理器
|
| 19 |
+
type ProxyPool struct {
|
| 20 |
+
proxies []string
|
| 21 |
+
mu sync.RWMutex
|
| 22 |
+
index int
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
var (
|
| 26 |
+
globalProxyPool *ProxyPool
|
| 27 |
+
once sync.Once
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
// GetProxyPool 获取全局代理池实例
|
| 31 |
+
func GetProxyPool() *ProxyPool {
|
| 32 |
+
once.Do(func() {
|
| 33 |
+
globalProxyPool = NewProxyPool()
|
| 34 |
+
})
|
| 35 |
+
return globalProxyPool
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// NewProxyPool 创建新的代理池
|
| 39 |
+
func NewProxyPool() *ProxyPool {
|
| 40 |
+
pool := &ProxyPool{
|
| 41 |
+
proxies: make([]string, 0),
|
| 42 |
+
}
|
| 43 |
+
pool.loadProxiesFromEnv()
|
| 44 |
+
return pool
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// loadProxiesFromEnv 从环境变量加载代理列表
|
| 48 |
+
func (p *ProxyPool) loadProxiesFromEnv() {
|
| 49 |
+
p.mu.Lock()
|
| 50 |
+
defer p.mu.Unlock()
|
| 51 |
+
|
| 52 |
+
proxyEnv := os.Getenv("SOCKS_PROXY_POOL")
|
| 53 |
+
if proxyEnv == "" {
|
| 54 |
+
return
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// 解析逗号分隔的代理列表
|
| 58 |
+
proxiesStr := strings.Split(proxyEnv, ",")
|
| 59 |
+
for _, proxyStr := range proxiesStr {
|
| 60 |
+
proxyStr = strings.TrimSpace(proxyStr)
|
| 61 |
+
if proxyStr != "" {
|
| 62 |
+
p.proxies = append(p.proxies, proxyStr)
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// GetNextProxy 获取下一个代理(轮询方式)
|
| 68 |
+
func (p *ProxyPool) GetNextProxy() string {
|
| 69 |
+
p.mu.Lock()
|
| 70 |
+
defer p.mu.Unlock()
|
| 71 |
+
|
| 72 |
+
if len(p.proxies) == 0 {
|
| 73 |
+
return ""
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
proxy := p.proxies[p.index]
|
| 77 |
+
p.index = (p.index + 1) % len(p.proxies)
|
| 78 |
+
return proxy
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// GetRandomProxy 获取随机代理
|
| 82 |
+
func (p *ProxyPool) GetRandomProxy() string {
|
| 83 |
+
p.mu.RLock()
|
| 84 |
+
defer p.mu.RUnlock()
|
| 85 |
+
|
| 86 |
+
if len(p.proxies) == 0 {
|
| 87 |
+
return ""
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
index := rand.Intn(len(p.proxies))
|
| 91 |
+
return p.proxies[index]
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
// HasProxies 检查是否有可用代理
|
| 95 |
+
func (p *ProxyPool) HasProxies() bool {
|
| 96 |
+
p.mu.RLock()
|
| 97 |
+
defer p.mu.RUnlock()
|
| 98 |
+
return len(p.proxies) > 0
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Count 返回代理数量
|
| 102 |
+
func (p *ProxyPool) Count() int {
|
| 103 |
+
p.mu.RLock()
|
| 104 |
+
defer p.mu.RUnlock()
|
| 105 |
+
return len(p.proxies)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// GetAllProxies 获取所有代理列表(用于测试)
|
| 109 |
+
func (p *ProxyPool) GetAllProxies() []string {
|
| 110 |
+
p.mu.RLock()
|
| 111 |
+
defer p.mu.RUnlock()
|
| 112 |
+
result := make([]string, len(p.proxies))
|
| 113 |
+
copy(result, p.proxies)
|
| 114 |
+
return result
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// createSOCKS5Transport 创建SOCKS5代理传输层
|
| 118 |
+
func createSOCKS5Transport(proxyURL string, timeout time.Duration) (*http.Transport, error) {
|
| 119 |
+
// 处理自定义格式:socks5://host:port:username:password
|
| 120 |
+
// 转换为标准格式:socks5://username:password@host:port
|
| 121 |
+
if strings.Contains(proxyURL, "socks5://") && strings.Count(proxyURL, ":") == 4 {
|
| 122 |
+
// 解析自定义格式
|
| 123 |
+
parts := strings.Split(proxyURL, ":")
|
| 124 |
+
if len(parts) == 5 {
|
| 125 |
+
// parts[0] = "socks5", parts[1] = "//host", parts[2] = "port", parts[3] = "username", parts[4] = "password"
|
| 126 |
+
host := strings.TrimPrefix(parts[1], "//")
|
| 127 |
+
port := parts[2]
|
| 128 |
+
username := parts[3]
|
| 129 |
+
password := parts[4]
|
| 130 |
+
|
| 131 |
+
// 重构为标准URL格式
|
| 132 |
+
proxyURL = fmt.Sprintf("socks5://%s:%s@%s:%s", username, password, host, port)
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
u, err := url.Parse(proxyURL)
|
| 137 |
+
if err != nil {
|
| 138 |
+
return nil, fmt.Errorf("解析代理URL失败: %v", err)
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
if u.Scheme != "socks5" {
|
| 142 |
+
return nil, fmt.Errorf("仅支持SOCKS5代理")
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
// 解析用户名和密码
|
| 146 |
+
var auth *proxy.Auth
|
| 147 |
+
if u.User != nil {
|
| 148 |
+
password, _ := u.User.Password()
|
| 149 |
+
auth = &proxy.Auth{
|
| 150 |
+
User: u.User.Username(),
|
| 151 |
+
Password: password,
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// 创建SOCKS5拨号器
|
| 156 |
+
dialer, err := proxy.SOCKS5("tcp", u.Host, auth, proxy.Direct)
|
| 157 |
+
if err != nil {
|
| 158 |
+
return nil, fmt.Errorf("创建SOCKS5拨号器失败: %v", err)
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
transport := &http.Transport{
|
| 162 |
+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
| 163 |
+
return dialer.Dial(network, addr)
|
| 164 |
+
},
|
| 165 |
+
MaxIdleConns: 100,
|
| 166 |
+
IdleConnTimeout: 90 * time.Second,
|
| 167 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
return transport, nil
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// parseCustomProxyURL 解析自定义代理URL格式
|
| 174 |
+
func parseCustomProxyURL(proxyURL string) string {
|
| 175 |
+
// 处理自定义格式:socks5://host:port:username:password
|
| 176 |
+
// 转换为标准格式:socks5://username:password@host:port
|
| 177 |
+
if strings.Contains(proxyURL, "socks5://") && strings.Count(proxyURL, ":") == 4 {
|
| 178 |
+
// 解析自定义格式
|
| 179 |
+
parts := strings.Split(proxyURL, ":")
|
| 180 |
+
if len(parts) == 5 {
|
| 181 |
+
// parts[0] = "socks5", parts[1] = "//host", parts[2] = "port", parts[3] = "username", parts[4] = "password"
|
| 182 |
+
host := strings.TrimPrefix(parts[1], "//")
|
| 183 |
+
port := parts[2]
|
| 184 |
+
username := parts[3]
|
| 185 |
+
password := parts[4]
|
| 186 |
+
|
| 187 |
+
// 重构为标准URL格式
|
| 188 |
+
return fmt.Sprintf("socks5://%s:%s@%s:%s", username, password, host, port)
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
return proxyURL
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
// NewHTTPClientWithProxy 创建带指定代理的HTTP客户端
|
| 195 |
+
func NewHTTPClientWithProxy(proxyURL string, timeout time.Duration) (*http.Client, error) {
|
| 196 |
+
if timeout == 0 {
|
| 197 |
+
timeout = 600 * time.Second // 10分钟超时,支持长时间流式响应
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
if proxyURL == "" {
|
| 201 |
+
// 没有代理,使用默认客户端
|
| 202 |
+
return &http.Client{
|
| 203 |
+
Transport: &http.Transport{
|
| 204 |
+
DialContext: (&net.Dialer{
|
| 205 |
+
Timeout: 30 * time.Second,
|
| 206 |
+
KeepAlive: 30 * time.Second,
|
| 207 |
+
}).DialContext,
|
| 208 |
+
MaxIdleConns: 100,
|
| 209 |
+
IdleConnTimeout: 90 * time.Second,
|
| 210 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 211 |
+
},
|
| 212 |
+
Timeout: timeout,
|
| 213 |
+
}, nil
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// 转换自定义格式到标准格式
|
| 217 |
+
standardURL := parseCustomProxyURL(proxyURL)
|
| 218 |
+
|
| 219 |
+
// 解析代理URL
|
| 220 |
+
u, err := url.Parse(standardURL)
|
| 221 |
+
if err != nil {
|
| 222 |
+
return nil, fmt.Errorf("解析代理URL失败: %v", err)
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
var transport *http.Transport
|
| 226 |
+
|
| 227 |
+
if u.Scheme == "socks5" {
|
| 228 |
+
// SOCKS5代理 - 使用转换后的标准URL
|
| 229 |
+
transport, err = createSOCKS5Transport(standardURL, timeout)
|
| 230 |
+
if err != nil {
|
| 231 |
+
return nil, err
|
| 232 |
+
}
|
| 233 |
+
} else {
|
| 234 |
+
// HTTP代理
|
| 235 |
+
transport = &http.Transport{
|
| 236 |
+
Proxy: http.ProxyURL(u),
|
| 237 |
+
DialContext: (&net.Dialer{
|
| 238 |
+
Timeout: 30 * time.Second,
|
| 239 |
+
KeepAlive: 30 * time.Second,
|
| 240 |
+
}).DialContext,
|
| 241 |
+
MaxIdleConns: 100,
|
| 242 |
+
IdleConnTimeout: 90 * time.Second,
|
| 243 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
return &http.Client{
|
| 248 |
+
Transport: transport,
|
| 249 |
+
Timeout: timeout,
|
| 250 |
+
}, nil
|
| 251 |
+
}
|
internal/service/proxy_client.go
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"net/http"
|
| 7 |
+
"strings"
|
| 8 |
+
"time"
|
| 9 |
+
|
| 10 |
+
"zencoder-2api/internal/service/provider"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
// ProxyRequestOptions 代理请求选项
|
| 14 |
+
type ProxyRequestOptions struct {
|
| 15 |
+
UseProxy bool // 是否使用代理
|
| 16 |
+
MaxRetries int // 最大重试次数
|
| 17 |
+
RetryDelay time.Duration // 重试延迟
|
| 18 |
+
OnError func(error) bool // 错误判断函数,返回true表示需要重试
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// DefaultProxyRequestOptions 默认代理请求选项
|
| 22 |
+
func DefaultProxyRequestOptions() ProxyRequestOptions {
|
| 23 |
+
return ProxyRequestOptions{
|
| 24 |
+
UseProxy: true,
|
| 25 |
+
MaxRetries: 3,
|
| 26 |
+
RetryDelay: time.Second,
|
| 27 |
+
OnError: isNetworkError,
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
// isNetworkError 判断是否为网络错误(可重试的错误)
|
| 32 |
+
func isNetworkError(err error) bool {
|
| 33 |
+
if err == nil {
|
| 34 |
+
return false
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
errStr := err.Error()
|
| 38 |
+
|
| 39 |
+
// 网络连接错误
|
| 40 |
+
if strings.Contains(errStr, "connection refused") ||
|
| 41 |
+
strings.Contains(errStr, "connection reset") ||
|
| 42 |
+
strings.Contains(errStr, "connection timed out") ||
|
| 43 |
+
strings.Contains(errStr, "timeout") ||
|
| 44 |
+
strings.Contains(errStr, "network is unreachable") ||
|
| 45 |
+
strings.Contains(errStr, "no such host") ||
|
| 46 |
+
strings.Contains(errStr, "dial tcp") ||
|
| 47 |
+
strings.Contains(errStr, "i/o timeout") ||
|
| 48 |
+
strings.Contains(errStr, "EOF") {
|
| 49 |
+
return true
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// SOCKS代理相关错误
|
| 53 |
+
if strings.Contains(errStr, "socks connect") ||
|
| 54 |
+
strings.Contains(errStr, "proxy") {
|
| 55 |
+
return true
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
return false
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// DoRequestWithProxyRetry 执行带代理重试的HTTP请求
|
| 62 |
+
func DoRequestWithProxyRetry(ctx context.Context, req *http.Request, originalProxy string, options ProxyRequestOptions) (*http.Response, error) {
|
| 63 |
+
// 首先尝试使用原始代理(如果有的话)
|
| 64 |
+
client := provider.NewHTTPClient(originalProxy, 0)
|
| 65 |
+
|
| 66 |
+
resp, err := client.Do(req)
|
| 67 |
+
if err == nil {
|
| 68 |
+
// 请求成功,返回结果
|
| 69 |
+
return resp, nil
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// 检查错误是否可重试
|
| 73 |
+
if !options.OnError(err) {
|
| 74 |
+
return nil, err
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// 如果不使用代理池,直接返回错误
|
| 78 |
+
if !options.UseProxy {
|
| 79 |
+
return nil, err
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
proxyPool := provider.GetProxyPool()
|
| 83 |
+
if !proxyPool.HasProxies() {
|
| 84 |
+
return nil, fmt.Errorf("原始请求失败且无可用代理: %v", err)
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
var lastErr error = err
|
| 88 |
+
|
| 89 |
+
// 使用代理池进行重试
|
| 90 |
+
for i := 0; i < options.MaxRetries; i++ {
|
| 91 |
+
// 获取下一个代理
|
| 92 |
+
proxyURL := proxyPool.GetNextProxy()
|
| 93 |
+
if proxyURL == "" {
|
| 94 |
+
break
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
// 创建使用代理的HTTP客户端
|
| 98 |
+
proxyClient, clientErr := provider.NewHTTPClientWithProxy(proxyURL, 0)
|
| 99 |
+
if clientErr != nil {
|
| 100 |
+
lastErr = clientErr
|
| 101 |
+
time.Sleep(options.RetryDelay)
|
| 102 |
+
continue
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// 克隆请求(因为request body可能已经被消费)
|
| 106 |
+
newReq := req.Clone(ctx)
|
| 107 |
+
|
| 108 |
+
resp, err := proxyClient.Do(newReq)
|
| 109 |
+
if err == nil {
|
| 110 |
+
// 请求成功
|
| 111 |
+
return resp, nil
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// 检查错误是否继续重试
|
| 115 |
+
if !options.OnError(err) {
|
| 116 |
+
return nil, err
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
lastErr = err
|
| 120 |
+
time.Sleep(options.RetryDelay)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
return nil, fmt.Errorf("所有代理重试均失败,最后错误: %v", lastErr)
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// CreateHTTPClientWithFallback 创建支持代理fallback的HTTP客户端
|
| 127 |
+
func CreateHTTPClientWithFallback(originalProxy string, useProxyPool bool) *http.Client {
|
| 128 |
+
// 如果不使用代理池,使用原始逻辑
|
| 129 |
+
if !useProxyPool {
|
| 130 |
+
return provider.NewHTTPClient(originalProxy, 0)
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// 如果有原始代理,先尝试原始代理
|
| 134 |
+
if originalProxy != "" {
|
| 135 |
+
client, err := provider.NewHTTPClientWithProxy(originalProxy, 0)
|
| 136 |
+
if err == nil {
|
| 137 |
+
return client
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// 使用代理池
|
| 142 |
+
return provider.NewHTTPClientWithPoolProxy(true, 0)
|
| 143 |
+
}
|
internal/service/refresh.go
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"log"
|
| 9 |
+
"net/http"
|
| 10 |
+
"strings"
|
| 11 |
+
"time"
|
| 12 |
+
|
| 13 |
+
"zencoder-2api/internal/model"
|
| 14 |
+
"zencoder-2api/internal/database"
|
| 15 |
+
"zencoder-2api/internal/service/provider"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// RefreshTokenRequest 请求刷新token的结构
|
| 19 |
+
type RefreshTokenRequest struct {
|
| 20 |
+
GrantType string `json:"grant_type"`
|
| 21 |
+
RefreshToken string `json:"refresh_token"`
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// RefreshTokenResponse 刷新token的响应结构
|
| 25 |
+
type RefreshTokenResponse struct {
|
| 26 |
+
TokenType string `json:"token_type"`
|
| 27 |
+
AccessToken string `json:"access_token"`
|
| 28 |
+
IDToken string `json:"id_token"`
|
| 29 |
+
RefreshToken string `json:"refresh_token"`
|
| 30 |
+
ExpiresIn int `json:"expires_in"`
|
| 31 |
+
Federated map[string]interface{} `json:"federated"`
|
| 32 |
+
|
| 33 |
+
// 这些字段可能不在响应中,但我们可以从JWT解析
|
| 34 |
+
UserID string `json:"-"`
|
| 35 |
+
Email string `json:"-"`
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// AccountLockoutError 表示账号被锁定的错误
|
| 39 |
+
type AccountLockoutError struct {
|
| 40 |
+
StatusCode int
|
| 41 |
+
Body string
|
| 42 |
+
AccountID string
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func (e *AccountLockoutError) Error() string {
|
| 46 |
+
return fmt.Sprintf("account %s is locked out: status %d, body: %s", e.AccountID, e.StatusCode, e.Body)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// isAccountLockoutError 检查是否是账号锁定错误
|
| 50 |
+
func isAccountLockoutError(statusCode int, body string) bool {
|
| 51 |
+
if statusCode == 400 {
|
| 52 |
+
// 检查响应体中是否包含锁定信息
|
| 53 |
+
return strings.Contains(body, "User is locked out") ||
|
| 54 |
+
strings.Contains(body, "user is locked out") ||
|
| 55 |
+
strings.Contains(body, "locked out")
|
| 56 |
+
}
|
| 57 |
+
return false
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// markAccountAsBanned 将账号标记为被封禁状态
|
| 61 |
+
func markAccountAsBanned(account *model.Account, reason string) error {
|
| 62 |
+
updates := map[string]interface{}{
|
| 63 |
+
"status": "banned",
|
| 64 |
+
"is_active": false,
|
| 65 |
+
"is_cooling": false,
|
| 66 |
+
"ban_reason": reason,
|
| 67 |
+
"updated_at": time.Now(),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
if err := database.GetDB().Model(&model.Account{}).
|
| 71 |
+
Where("id = ?", account.ID).
|
| 72 |
+
Updates(updates).Error; err != nil {
|
| 73 |
+
return fmt.Errorf("failed to update account status: %w", err)
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
log.Printf("[账号管理] 账号 %s (ID:%d) 已标记为封禁状态: %s", account.ClientID, account.ID, reason)
|
| 77 |
+
return nil
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// isRefreshTokenInvalidError 检查是否是refresh token无效错误
|
| 81 |
+
func isRefreshTokenInvalidError(statusCode int, body string) bool {
|
| 82 |
+
if statusCode == 401 {
|
| 83 |
+
return strings.Contains(body, "Refresh token is not valid") ||
|
| 84 |
+
strings.Contains(body, "refresh token is not valid") ||
|
| 85 |
+
strings.Contains(body, "invalid refresh token") ||
|
| 86 |
+
strings.Contains(body, "refresh_token is invalid")
|
| 87 |
+
}
|
| 88 |
+
return false
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// markTokenRecordAsBanned 将token记录标记为封禁状态
|
| 92 |
+
func markTokenRecordAsBanned(record *model.TokenRecord, reason string) error {
|
| 93 |
+
updates := map[string]interface{}{
|
| 94 |
+
"status": "banned",
|
| 95 |
+
"is_active": false,
|
| 96 |
+
"ban_reason": reason,
|
| 97 |
+
"updated_at": time.Now(),
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
if err := database.GetDB().Model(&model.TokenRecord{}).
|
| 101 |
+
Where("id = ?", record.ID).
|
| 102 |
+
Updates(updates).Error; err != nil {
|
| 103 |
+
return fmt.Errorf("failed to update token record status: %w", err)
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
log.Printf("[Token管理] Token记录 #%d 已标记为封禁状态: %s", record.ID, reason)
|
| 107 |
+
return nil
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
// markTokenRecordAsExpired 将token记录标记为过期状态
|
| 111 |
+
func markTokenRecordAsExpired(record *model.TokenRecord, reason string) error {
|
| 112 |
+
updates := map[string]interface{}{
|
| 113 |
+
"status": "expired",
|
| 114 |
+
"is_active": false,
|
| 115 |
+
"ban_reason": reason,
|
| 116 |
+
"updated_at": time.Now(),
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if err := database.GetDB().Model(&model.TokenRecord{}).
|
| 120 |
+
Where("id = ?", record.ID).
|
| 121 |
+
Updates(updates).Error; err != nil {
|
| 122 |
+
return fmt.Errorf("failed to update token record status: %w", err)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
log.Printf("[Token管理] Token记录 #%d 已标记为过期状态: %s", record.ID, reason)
|
| 126 |
+
return nil
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// disableTokenRecordsByEmail 根据邮箱禁用相关的token记录
|
| 130 |
+
func disableTokenRecordsByEmail(email string, reason string) error {
|
| 131 |
+
updates := map[string]interface{}{
|
| 132 |
+
"status": "banned",
|
| 133 |
+
"is_active": false,
|
| 134 |
+
"ban_reason": reason,
|
| 135 |
+
"updated_at": time.Now(),
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
result := database.GetDB().Model(&model.TokenRecord{}).
|
| 139 |
+
Where("email = ? AND status = ?", email, "active").
|
| 140 |
+
Updates(updates)
|
| 141 |
+
|
| 142 |
+
if result.Error != nil {
|
| 143 |
+
return fmt.Errorf("failed to disable token records: %w", result.Error)
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
if result.RowsAffected > 0 {
|
| 147 |
+
log.Printf("[Token管理] 已禁用邮箱 %s 相关的 %d 条token记录: %s", email, result.RowsAffected, reason)
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
return nil
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// RefreshAccessToken 使用 refresh_token 获取新的 access_token
|
| 154 |
+
func RefreshAccessToken(refreshToken string, proxy string) (*RefreshTokenResponse, error) {
|
| 155 |
+
url := "https://auth.zencoder.ai/api/frontegg/oauth/token"
|
| 156 |
+
|
| 157 |
+
// 打印调试日志
|
| 158 |
+
if IsDebugMode() {
|
| 159 |
+
log.Printf("[DEBUG] [RefreshToken] >>> 开始刷新Token")
|
| 160 |
+
log.Printf("[DEBUG] [RefreshToken] 请求URL: %s", url)
|
| 161 |
+
if len(refreshToken) > 20 {
|
| 162 |
+
log.Printf("[DEBUG] [RefreshToken] RefreshToken: %s...", refreshToken[:20])
|
| 163 |
+
} else {
|
| 164 |
+
log.Printf("[DEBUG] [RefreshToken] RefreshToken: %s", refreshToken)
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
reqBody := RefreshTokenRequest{
|
| 169 |
+
GrantType: "refresh_token",
|
| 170 |
+
RefreshToken: refreshToken,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
jsonData, err := json.Marshal(reqBody)
|
| 174 |
+
if err != nil {
|
| 175 |
+
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if IsDebugMode() {
|
| 179 |
+
log.Printf("[DEBUG] [RefreshToken] 请求Body: %s", string(jsonData))
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
| 183 |
+
if err != nil {
|
| 184 |
+
return nil, fmt.Errorf("failed to create request: %w", err)
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// 设置请求头
|
| 188 |
+
req.Header.Set("Accept", "*/*")
|
| 189 |
+
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
| 190 |
+
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8,zh-TW;q=0.7,ja;q=0.6")
|
| 191 |
+
req.Header.Set("Cache-Control", "no-cache")
|
| 192 |
+
req.Header.Set("Content-Type", "application/json")
|
| 193 |
+
req.Header.Set("Origin", "https://auth.zencoder.ai")
|
| 194 |
+
req.Header.Set("Pragma", "no-cache")
|
| 195 |
+
req.Header.Set("Priority", "u=1, i")
|
| 196 |
+
req.Header.Set("Sec-Ch-Ua", `"Google Chrome";v="143", "Chromium";v="143", "Not A(Brand";v="24"`)
|
| 197 |
+
req.Header.Set("Sec-Ch-Ua-Mobile", "?0")
|
| 198 |
+
req.Header.Set("Sec-Ch-Ua-Platform", `"Windows"`)
|
| 199 |
+
req.Header.Set("Sec-Fetch-Dest", "empty")
|
| 200 |
+
req.Header.Set("Sec-Fetch-Mode", "cors")
|
| 201 |
+
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
| 202 |
+
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/143.0.0.0 Safari/537.36")
|
| 203 |
+
req.Header.Set("X-Frontegg-Framework", "react@18.2.0")
|
| 204 |
+
req.Header.Set("X-Frontegg-Sdk", "@frontegg/react@7.12.14")
|
| 205 |
+
|
| 206 |
+
// 使用客户端执行请求
|
| 207 |
+
client := provider.NewHTTPClient(proxy, 30*time.Second)
|
| 208 |
+
|
| 209 |
+
if IsDebugMode() {
|
| 210 |
+
log.Printf("[DEBUG] [RefreshToken] → 发送请求...")
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
resp, err := client.Do(req)
|
| 214 |
+
if err != nil {
|
| 215 |
+
if IsDebugMode() {
|
| 216 |
+
log.Printf("[DEBUG] [RefreshToken] ✗ 请求失败: %v", err)
|
| 217 |
+
}
|
| 218 |
+
return nil, fmt.Errorf("failed to send request: %w", err)
|
| 219 |
+
}
|
| 220 |
+
defer resp.Body.Close()
|
| 221 |
+
|
| 222 |
+
if IsDebugMode() {
|
| 223 |
+
log.Printf("[DEBUG] [RefreshToken] ← 收到响应: status=%d", resp.StatusCode)
|
| 224 |
+
// 输出响应头
|
| 225 |
+
log.Printf("[DEBUG] [RefreshToken] 响应头:")
|
| 226 |
+
for k, v := range resp.Header {
|
| 227 |
+
log.Printf("[DEBUG] [RefreshToken] %s: %v", k, v)
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
body, err := io.ReadAll(resp.Body)
|
| 232 |
+
if err != nil {
|
| 233 |
+
return nil, fmt.Errorf("failed to read response: %w", err)
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
if IsDebugMode() {
|
| 237 |
+
log.Printf("[DEBUG] [RefreshToken] 响应Body: %s", string(body))
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
if resp.StatusCode != http.StatusOK {
|
| 241 |
+
if IsDebugMode() {
|
| 242 |
+
log.Printf("[DEBUG] [RefreshToken] ✗ API错误: %d - %s", resp.StatusCode, string(body))
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// 检查是否是账号锁定错误
|
| 246 |
+
if isAccountLockoutError(resp.StatusCode, string(body)) {
|
| 247 |
+
return nil, &AccountLockoutError{
|
| 248 |
+
StatusCode: resp.StatusCode,
|
| 249 |
+
Body: string(body),
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
// 检查是否是refresh token无效错误
|
| 254 |
+
if isRefreshTokenInvalidError(resp.StatusCode, string(body)) {
|
| 255 |
+
return nil, fmt.Errorf("refresh token expired or invalid: status %d, body: %s", resp.StatusCode, string(body))
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
return nil, fmt.Errorf("failed to refresh token: status %d, body: %s", resp.StatusCode, string(body))
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
var tokenResp RefreshTokenResponse
|
| 262 |
+
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
| 263 |
+
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
// 如果响应中没有UserID,尝试从access_token中解析
|
| 267 |
+
if tokenResp.UserID == "" && tokenResp.AccessToken != "" {
|
| 268 |
+
if payload, err := ParseJWT(tokenResp.AccessToken); err == nil {
|
| 269 |
+
// 优先使用 Email,没有则使用 Subject
|
| 270 |
+
if payload.Email != "" {
|
| 271 |
+
tokenResp.UserID = payload.Email
|
| 272 |
+
tokenResp.Email = payload.Email
|
| 273 |
+
} else if payload.Subject != "" {
|
| 274 |
+
tokenResp.UserID = payload.Subject
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
if IsDebugMode() {
|
| 278 |
+
log.Printf("[DEBUG] [RefreshToken] 从JWT解析UserID: %s", tokenResp.UserID)
|
| 279 |
+
log.Printf("[DEBUG] [RefreshToken] JWT Payload - Email: %s, Subject: %s",
|
| 280 |
+
payload.Email, payload.Subject)
|
| 281 |
+
}
|
| 282 |
+
} else {
|
| 283 |
+
if IsDebugMode() {
|
| 284 |
+
log.Printf("[DEBUG] [RefreshToken] 解析JWT失败: %v", err)
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
if IsDebugMode() {
|
| 290 |
+
accessTokenPreview := tokenResp.AccessToken
|
| 291 |
+
if len(accessTokenPreview) > 20 {
|
| 292 |
+
accessTokenPreview = accessTokenPreview[:20]
|
| 293 |
+
}
|
| 294 |
+
log.Printf("[DEBUG] [RefreshToken] <<< 刷新成功: UserID=%s, AccessToken=%s..., ExpiresIn=%d",
|
| 295 |
+
tokenResp.UserID,
|
| 296 |
+
accessTokenPreview,
|
| 297 |
+
tokenResp.ExpiresIn)
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return &tokenResp, nil
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// min 辅助函数
|
| 304 |
+
func min(a, b int) int {
|
| 305 |
+
if a < b {
|
| 306 |
+
return a
|
| 307 |
+
}
|
| 308 |
+
return b
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
// UpdateAccountToken 更新账号的 token
|
| 312 |
+
func UpdateAccountToken(account *model.Account) error {
|
| 313 |
+
if account.RefreshToken == "" {
|
| 314 |
+
return fmt.Errorf("account %s has no refresh token", account.ClientID)
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
// 调用刷新接口
|
| 318 |
+
tokenResp, err := RefreshAccessToken(account.RefreshToken, account.Proxy)
|
| 319 |
+
if err != nil {
|
| 320 |
+
// 检查是否是账号锁定错误
|
| 321 |
+
if lockoutErr, ok := err.(*AccountLockoutError); ok {
|
| 322 |
+
// 将账号标记为封禁状态
|
| 323 |
+
if markErr := markAccountAsBanned(account, "用户被锁定: "+lockoutErr.Body); markErr != nil {
|
| 324 |
+
log.Printf("[账号管理] 标记账号封禁状态失败: %v", markErr)
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
return fmt.Errorf("failed to refresh token for account %s: %w", account.ClientID, err)
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// 计算过期时间
|
| 331 |
+
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 332 |
+
|
| 333 |
+
// 更新数据库
|
| 334 |
+
updates := map[string]interface{}{
|
| 335 |
+
"access_token": tokenResp.AccessToken,
|
| 336 |
+
"refresh_token": tokenResp.RefreshToken, // 更新新的 refresh_token
|
| 337 |
+
"token_expiry": expiry,
|
| 338 |
+
"updated_at": time.Now(),
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
if err := database.DB.Model(&model.Account{}).
|
| 342 |
+
Where("id = ?", account.ID).
|
| 343 |
+
Updates(updates).Error; err != nil {
|
| 344 |
+
return fmt.Errorf("failed to update account token: %w", err)
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// 更新内存中的值
|
| 348 |
+
account.AccessToken = tokenResp.AccessToken
|
| 349 |
+
account.RefreshToken = tokenResp.RefreshToken
|
| 350 |
+
account.TokenExpiry = expiry
|
| 351 |
+
|
| 352 |
+
debugLogf("✅ Refreshed token for account %s, expires at %s", account.ClientID, expiry.Format(time.RFC3339))
|
| 353 |
+
|
| 354 |
+
return nil
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
// UpdateTokenRecordToken 更新 TokenRecord 的 token
|
| 358 |
+
func UpdateTokenRecordToken(record *model.TokenRecord) error {
|
| 359 |
+
if record.RefreshToken == "" {
|
| 360 |
+
return fmt.Errorf("token record %d has no refresh token", record.ID)
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
// 调用刷新接口
|
| 364 |
+
tokenResp, err := RefreshAccessToken(record.RefreshToken, "")
|
| 365 |
+
if err != nil {
|
| 366 |
+
// 检查是否是账号锁定错误
|
| 367 |
+
if lockoutErr, ok := err.(*AccountLockoutError); ok {
|
| 368 |
+
// 将token记录标记为封禁状态
|
| 369 |
+
if markErr := markTokenRecordAsBanned(record, "账号被锁定: "+lockoutErr.Body); markErr != nil {
|
| 370 |
+
log.Printf("[Token管理] 标记token记录封禁状态失败: %v", markErr)
|
| 371 |
+
}
|
| 372 |
+
// 根据邮箱禁用相关的token记录
|
| 373 |
+
if record.Email != "" {
|
| 374 |
+
if disableErr := disableTokenRecordsByEmail(record.Email, "关联账号被锁定"); disableErr != nil {
|
| 375 |
+
log.Printf("[Token管理] 禁用相关token记录失败: %v", disableErr)
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
return fmt.Errorf("token record %d account locked out: %w", record.ID, err)
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
// 检查是否是refresh token过期错误
|
| 382 |
+
if strings.Contains(err.Error(), "refresh token expired or invalid") {
|
| 383 |
+
// 将token记录标记为过期状态
|
| 384 |
+
if markErr := markTokenRecordAsExpired(record, "Refresh token过期或无效"); markErr != nil {
|
| 385 |
+
log.Printf("[Token管理] 标记token记录过期状态失败: %v", markErr)
|
| 386 |
+
}
|
| 387 |
+
return fmt.Errorf("token record %d refresh token expired: %w", record.ID, err)
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
return fmt.Errorf("failed to refresh token for record %d: %w", record.ID, err)
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
// 计算过期时间
|
| 394 |
+
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 395 |
+
|
| 396 |
+
// 更新数据库
|
| 397 |
+
updates := map[string]interface{}{
|
| 398 |
+
"token": tokenResp.AccessToken,
|
| 399 |
+
"refresh_token": tokenResp.RefreshToken, // 更新新的 refresh_token
|
| 400 |
+
"token_expiry": expiry,
|
| 401 |
+
"status": "active", // 刷新成功时重新激活
|
| 402 |
+
"updated_at": time.Now(),
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
if err := database.DB.Model(&model.TokenRecord{}).
|
| 406 |
+
Where("id = ?", record.ID).
|
| 407 |
+
Updates(updates).Error; err != nil {
|
| 408 |
+
return fmt.Errorf("failed to update token record: %w", err)
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
// 更新内存中的值
|
| 412 |
+
record.Token = tokenResp.AccessToken
|
| 413 |
+
record.RefreshToken = tokenResp.RefreshToken
|
| 414 |
+
record.TokenExpiry = expiry
|
| 415 |
+
record.Status = "active"
|
| 416 |
+
|
| 417 |
+
debugLogf("✅ Refreshed token for record %d, expires at %s", record.ID, expiry.Format(time.RFC3339))
|
| 418 |
+
|
| 419 |
+
return nil
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
// CheckAndRefreshToken 检查并刷新即将过期的 token
|
| 423 |
+
func CheckAndRefreshToken(account *model.Account) error {
|
| 424 |
+
// 如果没有 RefreshToken,跳过
|
| 425 |
+
if account.RefreshToken == "" {
|
| 426 |
+
return nil
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
// 如果 token 在一小时内过期,则刷新
|
| 430 |
+
if time.Until(account.TokenExpiry) < time.Hour {
|
| 431 |
+
debugLogf("⚠️ Token for account %s expires in %v, refreshing...",
|
| 432 |
+
account.ClientID, time.Until(account.TokenExpiry))
|
| 433 |
+
return UpdateAccountToken(account)
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
return nil
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
// CheckAndRefreshTokenRecord 检查并刷新即将过期的 TokenRecord
|
| 440 |
+
func CheckAndRefreshTokenRecord(record *model.TokenRecord) error {
|
| 441 |
+
// 如果没有 RefreshToken,跳过
|
| 442 |
+
if record.RefreshToken == "" {
|
| 443 |
+
return nil
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
// 如果 token 在一小时内过期,则刷新
|
| 447 |
+
if time.Until(record.TokenExpiry) < time.Hour {
|
| 448 |
+
debugLogf("⚠️ Token for record %d expires in %v, refreshing...",
|
| 449 |
+
record.ID, time.Until(record.TokenExpiry))
|
| 450 |
+
return UpdateTokenRecordToken(record)
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
return nil
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
// StartTokenRefreshScheduler 启动定时刷新 token 的调度器
|
| 457 |
+
func StartTokenRefreshScheduler() {
|
| 458 |
+
go func() {
|
| 459 |
+
// 立即执行一次
|
| 460 |
+
refreshExpiredTokens()
|
| 461 |
+
|
| 462 |
+
// 然后每分钟检查一次
|
| 463 |
+
ticker := time.NewTicker(1 * time.Minute)
|
| 464 |
+
defer ticker.Stop()
|
| 465 |
+
|
| 466 |
+
for range ticker.C {
|
| 467 |
+
refreshExpiredTokens()
|
| 468 |
+
}
|
| 469 |
+
}()
|
| 470 |
+
|
| 471 |
+
log.Printf("🔄 Token refresh scheduler started - checking every minute")
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
// refreshExpiredTokens 刷新即将过期的 tokens
|
| 475 |
+
func refreshExpiredTokens() {
|
| 476 |
+
now := time.Now()
|
| 477 |
+
threshold := now.Add(time.Hour) // 1小时内即将过期的token
|
| 478 |
+
|
| 479 |
+
// 查询所有即将过期的账号(排除banned状态)
|
| 480 |
+
var accounts []model.Account
|
| 481 |
+
if err := database.DB.Where("token_expiry < ?", threshold).
|
| 482 |
+
Where("status != ?", "banned").
|
| 483 |
+
Find(&accounts).Error; err == nil {
|
| 484 |
+
|
| 485 |
+
for _, account := range accounts {
|
| 486 |
+
// 根据账号类型选择不同的刷新方式
|
| 487 |
+
if account.ClientSecret == "refresh-token-login" {
|
| 488 |
+
// refresh-token-login 账号使用 refresh_token 刷新
|
| 489 |
+
if account.RefreshToken != "" {
|
| 490 |
+
if err := UpdateAccountToken(&account); err != nil {
|
| 491 |
+
log.Printf("[Token刷新] ❌ refresh-token账号 %s 刷新失败: %v", account.ClientID, err)
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
} else {
|
| 495 |
+
// 普通账号使用 OAuth client credentials 刷新
|
| 496 |
+
if account.ClientID != "" && account.ClientSecret != "" {
|
| 497 |
+
if err := refreshAccountToken(&account); err != nil {
|
| 498 |
+
log.Printf("[Token刷新] ❌ 账号 %s OAuth刷新失败: %v", account.ClientID, err)
|
| 499 |
+
}
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
}
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
// 刷新 TokenRecord 的 tokens - 只排除banned状态的记录
|
| 506 |
+
var records []model.TokenRecord
|
| 507 |
+
if err := database.DB.Where("refresh_token != '' AND token_expiry < ?", threshold).
|
| 508 |
+
Where("status != ?", "banned").
|
| 509 |
+
Find(&records).Error; err == nil {
|
| 510 |
+
|
| 511 |
+
for _, record := range records {
|
| 512 |
+
if err := UpdateTokenRecordToken(&record); err != nil {
|
| 513 |
+
log.Printf("[Token刷新] ❌ 生成token #%d 刷新失败: %v", record.ID, err)
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
}
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
// debugLogf 简单的调试日志函数
|
| 520 |
+
func debugLogf(format string, args ...interface{}) {
|
| 521 |
+
if IsDebugMode() {
|
| 522 |
+
log.Printf("[DEBUG] "+format, args...)
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// RefreshTokenAndAccounts 刷新token记录并异步刷新相同邮箱的账号
|
| 527 |
+
func RefreshTokenAndAccounts(tokenRecordID uint) error {
|
| 528 |
+
// 获取token记录
|
| 529 |
+
var record model.TokenRecord
|
| 530 |
+
if err := database.GetDB().First(&record, tokenRecordID).Error; err != nil {
|
| 531 |
+
return fmt.Errorf("获取token记录失败: %w", err)
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
if record.RefreshToken == "" {
|
| 535 |
+
return fmt.Errorf("token记录没有refresh_token")
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
// 1. 刷新token记录的token
|
| 539 |
+
log.Printf("[Token刷新] 开始刷新token记录 #%d", tokenRecordID)
|
| 540 |
+
|
| 541 |
+
// 调用刷新接口
|
| 542 |
+
tokenResp, err := RefreshAccessToken(record.RefreshToken, "")
|
| 543 |
+
if err != nil {
|
| 544 |
+
// 检查是否是账号锁定错误
|
| 545 |
+
if lockoutErr, ok := err.(*AccountLockoutError); ok {
|
| 546 |
+
// 将token记录标记为封禁状态
|
| 547 |
+
if markErr := markTokenRecordAsBanned(&record, "账号被锁定: "+lockoutErr.Body); markErr != nil {
|
| 548 |
+
log.Printf("[Token管理] 标记token记录封禁状态失败: %v", markErr)
|
| 549 |
+
}
|
| 550 |
+
// 根据邮箱禁用相关的token记录
|
| 551 |
+
if record.Email != "" {
|
| 552 |
+
if disableErr := disableTokenRecordsByEmail(record.Email, "关联账号被锁定"); disableErr != nil {
|
| 553 |
+
log.Printf("[Token管理] 禁用相关token记录失败: %v", disableErr)
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
// 检查是否是refresh token过期错误
|
| 559 |
+
if strings.Contains(err.Error(), "refresh token expired or invalid") {
|
| 560 |
+
// 将token记录标记为过期状态
|
| 561 |
+
if markErr := markTokenRecordAsExpired(&record, "Refresh token过期或无效"); markErr != nil {
|
| 562 |
+
log.Printf("[Token管理] 标记token记录过期状态失败: %v", markErr)
|
| 563 |
+
}
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
return fmt.Errorf("刷新token失败: %w", err)
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
// 计算过期时间
|
| 570 |
+
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 571 |
+
|
| 572 |
+
// 更新数据库
|
| 573 |
+
updates := map[string]interface{}{
|
| 574 |
+
"token": tokenResp.AccessToken,
|
| 575 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 576 |
+
"token_expiry": expiry,
|
| 577 |
+
"status": "active", // 刷新成功时重置为活跃状态
|
| 578 |
+
"ban_reason": "", // 清除封禁原因
|
| 579 |
+
"updated_at": time.Now(),
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
if err := database.GetDB().Model(&model.TokenRecord{}).
|
| 583 |
+
Where("id = ?", tokenRecordID).
|
| 584 |
+
Updates(updates).Error; err != nil {
|
| 585 |
+
return fmt.Errorf("更新token记录失败: %w", err)
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
// 2. 解析新token获取邮箱
|
| 589 |
+
email := ""
|
| 590 |
+
if payload, err := ParseJWT(tokenResp.AccessToken); err == nil {
|
| 591 |
+
email = payload.Email
|
| 592 |
+
log.Printf("[Token刷新] 解析到邮箱: %s", email)
|
| 593 |
+
} else {
|
| 594 |
+
log.Printf("[Token刷新] 无法解析JWT获取邮箱: %v", err)
|
| 595 |
+
return nil // 不影响token记录的刷新
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
if email == "" {
|
| 599 |
+
log.Printf("[Token刷新] 邮箱为空,跳过账号刷新")
|
| 600 |
+
return nil
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
// 3. 异步刷新相同邮箱的账号
|
| 604 |
+
go refreshAccountsByEmail(email)
|
| 605 |
+
|
| 606 |
+
return nil
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
// refreshAccountsByEmail 刷新指定邮箱的所有账号
|
| 610 |
+
func refreshAccountsByEmail(email string) {
|
| 611 |
+
log.Printf("[账号刷新] 开始刷新邮箱 %s 的所有账号", email)
|
| 612 |
+
|
| 613 |
+
// 查询所有相同邮箱的账号
|
| 614 |
+
var accounts []model.Account
|
| 615 |
+
if err := database.GetDB().Where("email = ?", email).Find(&accounts).Error; err != nil {
|
| 616 |
+
log.Printf("[账号刷新] 查询邮箱 %s 的账号失败: %v", email, err)
|
| 617 |
+
return
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
if len(accounts) == 0 {
|
| 621 |
+
log.Printf("[账号刷新] 没有找到邮箱 %s 的账号", email)
|
| 622 |
+
return
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
log.Printf("[账号刷新] 找到 %d 个账号需要刷新", len(accounts))
|
| 626 |
+
|
| 627 |
+
// 逐个刷新账号
|
| 628 |
+
successCount := 0
|
| 629 |
+
failCount := 0
|
| 630 |
+
|
| 631 |
+
for _, account := range accounts {
|
| 632 |
+
// 如果账号没有client_id和client_secret,跳过
|
| 633 |
+
if account.ClientID == "" || account.ClientSecret == "" {
|
| 634 |
+
log.Printf("[账号刷新] 账号 ID:%d 缺少client_id或client_secret,跳过", account.ID)
|
| 635 |
+
continue
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
log.Printf("[账号刷新] 正在刷新账号 ID:%d (ClientID: %s)", account.ID, account.ClientID)
|
| 639 |
+
|
| 640 |
+
// 使用OAuth方式刷新token
|
| 641 |
+
if err := refreshAccountToken(&account); err != nil {
|
| 642 |
+
log.Printf("[账号刷新] 账号 ID:%d 刷新失败: %v", account.ID, err)
|
| 643 |
+
failCount++
|
| 644 |
+
} else {
|
| 645 |
+
log.Printf("[账号刷新] 账号 ID:%d 刷新成功", account.ID)
|
| 646 |
+
successCount++
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
// 添加短暂延迟,避免请求过快
|
| 650 |
+
time.Sleep(100 * time.Millisecond)
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
log.Printf("[账号刷新] 邮箱 %s 的账号刷新完成 - 成功: %d, 失败: %d",
|
| 654 |
+
email, successCount, failCount)
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
// RefreshAccountToken 使用client credentials刷新账号token(导出函数)
|
| 658 |
+
func RefreshAccountToken(account *model.Account) error {
|
| 659 |
+
return refreshAccountToken(account)
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
// refreshAccountToken 使用client credentials刷新账号token
|
| 663 |
+
func refreshAccountToken(account *model.Account) error {
|
| 664 |
+
// 构建OAuth token请求
|
| 665 |
+
url := "https://fe.zencoder.ai/oauth/token"
|
| 666 |
+
|
| 667 |
+
reqBody := map[string]string{
|
| 668 |
+
"grant_type": "client_credentials",
|
| 669 |
+
"client_id": account.ClientID,
|
| 670 |
+
"client_secret": account.ClientSecret,
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
jsonData, err := json.Marshal(reqBody)
|
| 674 |
+
if err != nil {
|
| 675 |
+
return fmt.Errorf("序列化请求失败: %w", err)
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
| 679 |
+
if err != nil {
|
| 680 |
+
return fmt.Errorf("创建请求失败: %w", err)
|
| 681 |
+
}
|
| 682 |
+
|
| 683 |
+
// 设置请求头
|
| 684 |
+
req.Header.Set("Content-Type", "application/json")
|
| 685 |
+
req.Header.Set("Accept", "application/json")
|
| 686 |
+
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
| 687 |
+
|
| 688 |
+
// 使用代理(如果有)
|
| 689 |
+
client := provider.NewHTTPClient(account.Proxy, 30*time.Second)
|
| 690 |
+
|
| 691 |
+
resp, err := client.Do(req)
|
| 692 |
+
if err != nil {
|
| 693 |
+
return fmt.Errorf("发送请求失败: %w", err)
|
| 694 |
+
}
|
| 695 |
+
defer resp.Body.Close()
|
| 696 |
+
|
| 697 |
+
body, err := io.ReadAll(resp.Body)
|
| 698 |
+
if err != nil {
|
| 699 |
+
return fmt.Errorf("读取响应失败: %w", err)
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
if resp.StatusCode != http.StatusOK {
|
| 703 |
+
// 检查是否是账号锁定错误
|
| 704 |
+
if isAccountLockoutError(resp.StatusCode, string(body)) {
|
| 705 |
+
// 将账号标记为封禁状态
|
| 706 |
+
if markErr := markAccountAsBanned(account, "OAuth认证失败-用户被锁定: "+string(body)); markErr != nil {
|
| 707 |
+
log.Printf("[账号管理] 标记账号封禁状态失败: %v", markErr)
|
| 708 |
+
}
|
| 709 |
+
return &AccountLockoutError{
|
| 710 |
+
StatusCode: resp.StatusCode,
|
| 711 |
+
Body: string(body),
|
| 712 |
+
AccountID: account.ClientID,
|
| 713 |
+
}
|
| 714 |
+
}
|
| 715 |
+
return fmt.Errorf("API返回错误: %d - %s", resp.StatusCode, string(body))
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
// 解析响应
|
| 719 |
+
var tokenResp struct {
|
| 720 |
+
AccessToken string `json:"access_token"`
|
| 721 |
+
TokenType string `json:"token_type"`
|
| 722 |
+
ExpiresIn int `json:"expires_in"`
|
| 723 |
+
RefreshToken string `json:"refresh_token"`
|
| 724 |
+
}
|
| 725 |
+
|
| 726 |
+
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
| 727 |
+
return fmt.Errorf("解析响应失败: %w", err)
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
// 计算过期时间
|
| 731 |
+
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 732 |
+
|
| 733 |
+
// 解析token获取更多信息
|
| 734 |
+
planType := account.PlanType // 保留原有计划类型
|
| 735 |
+
dailyUsed := account.DailyUsed // 保留原有使用量
|
| 736 |
+
totalUsed := account.TotalUsed // 保留原有总使用量
|
| 737 |
+
|
| 738 |
+
if payload, err := ParseJWT(tokenResp.AccessToken); err == nil {
|
| 739 |
+
// 更新计划类型(如果有)
|
| 740 |
+
if payload.CustomClaims.Plan != "" {
|
| 741 |
+
planType = model.PlanType(payload.CustomClaims.Plan)
|
| 742 |
+
}
|
| 743 |
+
// 验证邮箱
|
| 744 |
+
if account.Email != "" && payload.Email != account.Email {
|
| 745 |
+
log.Printf("[账号刷新] 警告: 账号 ID:%d 邮箱不匹配 (期望: %s, 实际: %s)",
|
| 746 |
+
account.ID, account.Email, payload.Email)
|
| 747 |
+
}
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
// 更新数据库
|
| 751 |
+
updates := map[string]interface{}{
|
| 752 |
+
"access_token": tokenResp.AccessToken,
|
| 753 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 754 |
+
"token_expiry": expiry,
|
| 755 |
+
"plan_type": planType,
|
| 756 |
+
"daily_used": dailyUsed, // 保持原有使用量
|
| 757 |
+
"total_used": totalUsed, // 保持原有总使用量
|
| 758 |
+
"updated_at": time.Now(),
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
return database.GetDB().Model(&model.Account{}).
|
| 762 |
+
Where("id = ?", account.ID).
|
| 763 |
+
Updates(updates).Error
|
| 764 |
+
}
|
internal/service/request.go
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
// 注意:ReplaceModelInBody 函数已被删除,不再进行模型重定向/替换
|
internal/service/scheduler.go
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
"time"
|
| 6 |
+
|
| 7 |
+
"zencoder-2api/internal/database"
|
| 8 |
+
"zencoder-2api/internal/model"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
func StartCreditResetScheduler() {
|
| 12 |
+
go func() {
|
| 13 |
+
for {
|
| 14 |
+
now := time.Now()
|
| 15 |
+
next := time.Date(now.Year(), now.Month(), now.Day(), 9, 9, 0, 0, now.Location())
|
| 16 |
+
if now.After(next) {
|
| 17 |
+
next = next.Add(24 * time.Hour)
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
time.Sleep(time.Until(next))
|
| 21 |
+
ResetAllCredits()
|
| 22 |
+
}
|
| 23 |
+
}()
|
| 24 |
+
log.Println("Credit reset scheduler started (daily at 09:09)")
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
func ResetAllCredits() {
|
| 28 |
+
today := time.Now().Format("2006-01-02")
|
| 29 |
+
|
| 30 |
+
database.GetDB().Model(&model.Account{}).
|
| 31 |
+
Where("last_reset_date != ? OR last_reset_date IS NULL", today).
|
| 32 |
+
Updates(map[string]interface{}{
|
| 33 |
+
"daily_used": 0,
|
| 34 |
+
"is_cooling": false,
|
| 35 |
+
"last_reset_date": today,
|
| 36 |
+
})
|
| 37 |
+
|
| 38 |
+
log.Printf("Credits reset completed at %s", time.Now().Format("2006-01-02 15:04:05"))
|
| 39 |
+
}
|
internal/service/stream.go
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"io"
|
| 6 |
+
"net/http"
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
// StreamResponse 流式传输响应到客户端
|
| 10 |
+
func StreamResponse(w http.ResponseWriter, resp *http.Response) error {
|
| 11 |
+
// 复制响应头
|
| 12 |
+
for k, v := range resp.Header {
|
| 13 |
+
for _, vv := range v {
|
| 14 |
+
w.Header().Add(k, vv)
|
| 15 |
+
}
|
| 16 |
+
}
|
| 17 |
+
w.WriteHeader(resp.StatusCode)
|
| 18 |
+
|
| 19 |
+
// 获取Flusher接口
|
| 20 |
+
flusher, ok := w.(http.Flusher)
|
| 21 |
+
if !ok {
|
| 22 |
+
// 如果不支持Flusher,直接复制
|
| 23 |
+
_, err := io.Copy(w, resp.Body)
|
| 24 |
+
return err
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// 使用bufio读取并逐行刷新
|
| 28 |
+
reader := bufio.NewReader(resp.Body)
|
| 29 |
+
for {
|
| 30 |
+
line, err := reader.ReadBytes('\n')
|
| 31 |
+
if len(line) > 0 {
|
| 32 |
+
_, writeErr := w.Write(line)
|
| 33 |
+
if writeErr != nil {
|
| 34 |
+
return writeErr
|
| 35 |
+
}
|
| 36 |
+
flusher.Flush()
|
| 37 |
+
}
|
| 38 |
+
if err != nil {
|
| 39 |
+
if err == io.EOF {
|
| 40 |
+
return nil
|
| 41 |
+
}
|
| 42 |
+
return err
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// CopyResponse 普通响应复制
|
| 48 |
+
func CopyResponse(w http.ResponseWriter, resp *http.Response) error {
|
| 49 |
+
for k, v := range resp.Header {
|
| 50 |
+
for _, vv := range v {
|
| 51 |
+
w.Header().Add(k, vv)
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
w.WriteHeader(resp.StatusCode)
|
| 55 |
+
_, err := io.Copy(w, resp.Body)
|
| 56 |
+
return err
|
| 57 |
+
}
|
internal/service/token.go
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"io"
|
| 7 |
+
"log"
|
| 8 |
+
"net/http"
|
| 9 |
+
"net/url"
|
| 10 |
+
"strings"
|
| 11 |
+
"time"
|
| 12 |
+
|
| 13 |
+
"zencoder-2api/internal/database"
|
| 14 |
+
"zencoder-2api/internal/model"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
type TokenResponse struct {
|
| 18 |
+
AccessToken string `json:"access_token"`
|
| 19 |
+
TokenType string `json:"token_type"`
|
| 20 |
+
ExpiresIn int `json:"expires_in"`
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
const (
|
| 24 |
+
ZencoderTokenURL = "https://fe.zencoder.ai/oauth/token"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
func GetToken(account *model.Account) (string, error) {
|
| 28 |
+
if account.AccessToken != "" && time.Now().Before(account.TokenExpiry) {
|
| 29 |
+
return account.AccessToken, nil
|
| 30 |
+
}
|
| 31 |
+
return RefreshToken(account)
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
func RefreshToken(account *model.Account) (string, error) {
|
| 35 |
+
// 每次创建新的 HTTP 客户端,禁用连接复用
|
| 36 |
+
transport := &http.Transport{
|
| 37 |
+
DisableKeepAlives: true, // 禁用 Keep-Alive
|
| 38 |
+
DisableCompression: false,
|
| 39 |
+
MaxIdleConns: 0, // 不保持空闲连接
|
| 40 |
+
MaxIdleConnsPerHost: 0,
|
| 41 |
+
IdleConnTimeout: 0,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
if account.Proxy != "" {
|
| 45 |
+
proxyURL, err := url.Parse(account.Proxy)
|
| 46 |
+
if err == nil {
|
| 47 |
+
transport.Proxy = http.ProxyURL(proxyURL)
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
client := &http.Client{
|
| 52 |
+
Transport: transport,
|
| 53 |
+
Timeout: 30 * time.Second,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
data := url.Values{}
|
| 57 |
+
data.Set("grant_type", "client_credentials")
|
| 58 |
+
data.Set("client_id", account.ClientID)
|
| 59 |
+
data.Set("client_secret", account.ClientSecret)
|
| 60 |
+
|
| 61 |
+
req, err := http.NewRequest("POST", ZencoderTokenURL, strings.NewReader(data.Encode()))
|
| 62 |
+
if err != nil {
|
| 63 |
+
return "", err
|
| 64 |
+
}
|
| 65 |
+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
| 66 |
+
req.Header.Set("Connection", "close") // 明确要求关闭连接
|
| 67 |
+
|
| 68 |
+
resp, err := client.Do(req)
|
| 69 |
+
if err != nil {
|
| 70 |
+
return "", err
|
| 71 |
+
}
|
| 72 |
+
defer resp.Body.Close()
|
| 73 |
+
|
| 74 |
+
body, err := io.ReadAll(resp.Body)
|
| 75 |
+
if err != nil {
|
| 76 |
+
return "", err
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
if resp.StatusCode != http.StatusOK {
|
| 80 |
+
// 检查是否是账号锁定错误
|
| 81 |
+
if isAccountLockoutError(resp.StatusCode, string(body)) {
|
| 82 |
+
// 将账号标记为封禁状态
|
| 83 |
+
if markErr := markAccountAsBanned(account, "OAuth认证失败-用户被锁定: "+string(body)); markErr != nil {
|
| 84 |
+
log.Printf("[账号管理] 标记账号封禁状态失败: %v", markErr)
|
| 85 |
+
}
|
| 86 |
+
return "", &AccountLockoutError{
|
| 87 |
+
StatusCode: resp.StatusCode,
|
| 88 |
+
Body: string(body),
|
| 89 |
+
AccountID: account.ClientID,
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
return "", fmt.Errorf("token request failed: %s", string(body))
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
var tokenResp TokenResponse
|
| 96 |
+
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
| 97 |
+
return "", err
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
account.AccessToken = tokenResp.AccessToken
|
| 101 |
+
account.TokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-60) * time.Second)
|
| 102 |
+
|
| 103 |
+
// 只有已存在的账号才保存到数据库
|
| 104 |
+
if account.ID > 0 {
|
| 105 |
+
database.GetDB().Save(account)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// 显式关闭传输层,确保连接被清理
|
| 109 |
+
transport.CloseIdleConnections()
|
| 110 |
+
|
| 111 |
+
return account.AccessToken, nil
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
func createHTTPClient(proxy string) *http.Client {
|
| 115 |
+
transport := &http.Transport{}
|
| 116 |
+
|
| 117 |
+
if proxy != "" {
|
| 118 |
+
proxyURL, err := url.Parse(proxy)
|
| 119 |
+
if err == nil {
|
| 120 |
+
transport.Proxy = http.ProxyURL(proxyURL)
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
return &http.Client{
|
| 125 |
+
Transport: transport,
|
| 126 |
+
Timeout: 30 * time.Second,
|
| 127 |
+
}
|
| 128 |
+
}
|
internal/service/zencoder.go
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package service
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"bytes"
|
| 6 |
+
"encoding/json"
|
| 7 |
+
"fmt"
|
| 8 |
+
"io"
|
| 9 |
+
"net/http"
|
| 10 |
+
"time"
|
| 11 |
+
|
| 12 |
+
"github.com/google/uuid"
|
| 13 |
+
"zencoder-2api/internal/model"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
const (
|
| 17 |
+
ZencoderChatURL = "https://api.zencoder.ai/v1/chat/completions"
|
| 18 |
+
MaxRetries = 3
|
| 19 |
+
ZencoderVersion = "3.24.0"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
type ZencoderService struct{}
|
| 23 |
+
|
| 24 |
+
func NewZencoderService() *ZencoderService {
|
| 25 |
+
return &ZencoderService{}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
func setZencoderHeaders(req *http.Request, token, modelID string) {
|
| 29 |
+
req.Header.Set("Accept", "application/json")
|
| 30 |
+
req.Header.Set("Content-Type", "application/json")
|
| 31 |
+
req.Header.Set("Authorization", "Bearer "+token)
|
| 32 |
+
req.Header.Set("User-Agent", "zen-cli/0.9.0-windows-x64")
|
| 33 |
+
req.Header.Set("zen-model-id", modelID)
|
| 34 |
+
req.Header.Set("zencoder-arch", "x64")
|
| 35 |
+
req.Header.Set("zencoder-os", "windows")
|
| 36 |
+
req.Header.Set("zencoder-version", ZencoderVersion)
|
| 37 |
+
req.Header.Set("zencoder-client-type", "vscode")
|
| 38 |
+
req.Header.Set("zencoder-operation-id", uuid.New().String())
|
| 39 |
+
req.Header.Set("zencoder-operation-type", "agent_call")
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
func (s *ZencoderService) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 43 |
+
// 检查模型是否存在于模型字典中
|
| 44 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 45 |
+
if !exists {
|
| 46 |
+
return nil, ErrNoAvailableAccount
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
var lastErr error
|
| 50 |
+
for i := 0; i < MaxRetries; i++ {
|
| 51 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 52 |
+
if err != nil {
|
| 53 |
+
return nil, err
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
resp, err := s.doRequest(account, req)
|
| 57 |
+
if err != nil {
|
| 58 |
+
MarkAccountError(account)
|
| 59 |
+
lastErr = err
|
| 60 |
+
continue
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
ResetAccountError(account)
|
| 64 |
+
|
| 65 |
+
// ZenCoder服务没有HTTP响应,只能使用模型倍率
|
| 66 |
+
UseCredit(account, zenModel.Multiplier)
|
| 67 |
+
|
| 68 |
+
return resp, nil
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
return nil, fmt.Errorf("all retries failed: %w", lastErr)
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
func (s *ZencoderService) doRequest(account *model.Account, req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
|
| 75 |
+
token, err := GetToken(account)
|
| 76 |
+
if err != nil {
|
| 77 |
+
return nil, err
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// 获取模型映射
|
| 81 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 82 |
+
if !exists {
|
| 83 |
+
return nil, ErrNoAvailableAccount
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
body, err := json.Marshal(req)
|
| 87 |
+
if err != nil {
|
| 88 |
+
return nil, err
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
client := createHTTPClient(account.Proxy)
|
| 92 |
+
httpReq, err := http.NewRequest("POST", ZencoderChatURL, bytes.NewReader(body))
|
| 93 |
+
if err != nil {
|
| 94 |
+
return nil, err
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
setZencoderHeaders(httpReq, token, zenModel.ID)
|
| 98 |
+
|
| 99 |
+
resp, err := client.Do(httpReq)
|
| 100 |
+
if err != nil {
|
| 101 |
+
return nil, err
|
| 102 |
+
}
|
| 103 |
+
defer resp.Body.Close()
|
| 104 |
+
|
| 105 |
+
respBody, err := io.ReadAll(resp.Body)
|
| 106 |
+
if err != nil {
|
| 107 |
+
return nil, err
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
if resp.StatusCode != http.StatusOK {
|
| 111 |
+
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
var chatResp model.ChatCompletionResponse
|
| 115 |
+
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
| 116 |
+
return nil, err
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
return &chatResp, nil
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
func (s *ZencoderService) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 123 |
+
// 检查模型是否存在于模型字典中
|
| 124 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 125 |
+
if !exists {
|
| 126 |
+
return ErrNoAvailableAccount
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
var lastErr error
|
| 130 |
+
for i := 0; i < MaxRetries; i++ {
|
| 131 |
+
account, err := GetNextAccountForModel(req.Model)
|
| 132 |
+
if err != nil {
|
| 133 |
+
return err
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
err = s.doStreamRequest(account, req, writer)
|
| 137 |
+
if err != nil {
|
| 138 |
+
MarkAccountError(account)
|
| 139 |
+
lastErr = err
|
| 140 |
+
continue
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
ResetAccountError(account)
|
| 144 |
+
|
| 145 |
+
// 流式响应,使用模型倍率
|
| 146 |
+
UseCredit(account, zenModel.Multiplier)
|
| 147 |
+
|
| 148 |
+
return nil
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
return fmt.Errorf("all retries failed: %w", lastErr)
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
func (s *ZencoderService) doStreamRequest(account *model.Account, req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
|
| 155 |
+
token, err := GetToken(account)
|
| 156 |
+
if err != nil {
|
| 157 |
+
return err
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
// 获取模型映射
|
| 161 |
+
zenModel, exists := model.GetZenModel(req.Model)
|
| 162 |
+
if !exists {
|
| 163 |
+
return ErrNoAvailableAccount
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
req.Stream = true
|
| 167 |
+
body, err := json.Marshal(req)
|
| 168 |
+
if err != nil {
|
| 169 |
+
return err
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
client := createHTTPClient(account.Proxy)
|
| 173 |
+
client.Timeout = 5 * time.Minute
|
| 174 |
+
|
| 175 |
+
httpReq, err := http.NewRequest("POST", ZencoderChatURL, bytes.NewReader(body))
|
| 176 |
+
if err != nil {
|
| 177 |
+
return err
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
setZencoderHeaders(httpReq, token, zenModel.ID)
|
| 181 |
+
|
| 182 |
+
resp, err := client.Do(httpReq)
|
| 183 |
+
if err != nil {
|
| 184 |
+
return err
|
| 185 |
+
}
|
| 186 |
+
defer resp.Body.Close()
|
| 187 |
+
|
| 188 |
+
if resp.StatusCode != http.StatusOK {
|
| 189 |
+
respBody, _ := io.ReadAll(resp.Body)
|
| 190 |
+
return fmt.Errorf("status %d: %s", resp.StatusCode, string(respBody))
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
return s.streamResponse(resp.Body, writer)
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
func (s *ZencoderService) streamResponse(body io.Reader, writer http.ResponseWriter) error {
|
| 197 |
+
flusher, ok := writer.(http.Flusher)
|
| 198 |
+
if !ok {
|
| 199 |
+
return fmt.Errorf("streaming not supported")
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
writer.Header().Set("Content-Type", "text/event-stream")
|
| 203 |
+
writer.Header().Set("Cache-Control", "no-cache")
|
| 204 |
+
writer.Header().Set("Connection", "keep-alive")
|
| 205 |
+
|
| 206 |
+
scanner := bufio.NewScanner(body)
|
| 207 |
+
for scanner.Scan() {
|
| 208 |
+
line := scanner.Text()
|
| 209 |
+
if line == "" {
|
| 210 |
+
continue
|
| 211 |
+
}
|
| 212 |
+
fmt.Fprintf(writer, "%s\n\n", line)
|
| 213 |
+
flusher.Flush()
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return scanner.Err()
|
| 217 |
+
}
|
main.go
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
"os"
|
| 6 |
+
|
| 7 |
+
"github.com/gin-gonic/gin"
|
| 8 |
+
"github.com/joho/godotenv"
|
| 9 |
+
"zencoder-2api/internal/database"
|
| 10 |
+
"zencoder-2api/internal/handler"
|
| 11 |
+
"zencoder-2api/internal/middleware"
|
| 12 |
+
"zencoder-2api/internal/service"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
func main() {
|
| 16 |
+
// 加载 .env 文件
|
| 17 |
+
if err := godotenv.Load(); err != nil {
|
| 18 |
+
log.Println("No .env file found or error loading it, using system environment variables or defaults")
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
port := os.Getenv("PORT")
|
| 22 |
+
if port == "" {
|
| 23 |
+
port = "8080"
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
dbPath := os.Getenv("DB_PATH")
|
| 27 |
+
if dbPath == "" {
|
| 28 |
+
dbPath = "data.db"
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
if err := database.Init(dbPath); err != nil {
|
| 32 |
+
log.Fatal("Failed to init database:", err)
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// 启动积分重置定时任务
|
| 36 |
+
service.StartCreditResetScheduler()
|
| 37 |
+
|
| 38 |
+
// 启动Token刷新定时任务
|
| 39 |
+
service.StartTokenRefreshScheduler()
|
| 40 |
+
|
| 41 |
+
// 初始化账号池
|
| 42 |
+
service.InitAccountPool()
|
| 43 |
+
|
| 44 |
+
// 初始化自动生成服务
|
| 45 |
+
service.InitAutoGenerationService()
|
| 46 |
+
|
| 47 |
+
r := gin.Default()
|
| 48 |
+
setupRoutes(r)
|
| 49 |
+
|
| 50 |
+
log.Printf("Server starting on :%s", port)
|
| 51 |
+
if err := r.Run(":" + port); err != nil {
|
| 52 |
+
log.Fatal(err)
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func setupRoutes(r *gin.Engine) {
|
| 57 |
+
r.Static("/static", "./web/static")
|
| 58 |
+
r.LoadHTMLGlob("web/templates/*")
|
| 59 |
+
|
| 60 |
+
r.GET("/", func(c *gin.Context) {
|
| 61 |
+
c.HTML(200, "index.html", nil)
|
| 62 |
+
})
|
| 63 |
+
|
| 64 |
+
// Anthropic API - /v1/messages
|
| 65 |
+
anthropicHandler := handler.NewAnthropicHandler()
|
| 66 |
+
r.POST("/v1/messages", middleware.LoggerMiddleware(), middleware.AuthMiddleware(), anthropicHandler.Messages)
|
| 67 |
+
|
| 68 |
+
// OpenAI API - /v1/chat/completions, /v1/responses
|
| 69 |
+
openaiHandler := handler.NewOpenAIHandler()
|
| 70 |
+
r.POST("/v1/chat/completions", middleware.LoggerMiddleware(), middleware.AuthMiddleware(), openaiHandler.ChatCompletions)
|
| 71 |
+
r.POST("/v1/responses", middleware.LoggerMiddleware(), middleware.AuthMiddleware(), openaiHandler.Responses)
|
| 72 |
+
|
| 73 |
+
// Gemini API - /v1beta/models/*path
|
| 74 |
+
geminiHandler := handler.NewGeminiHandler()
|
| 75 |
+
r.POST("/v1beta/models/*path", middleware.LoggerMiddleware(), middleware.AuthMiddleware(), geminiHandler.HandleRequest)
|
| 76 |
+
|
| 77 |
+
// OAuth处理器 - 不需要管理密码验证(公开访问)
|
| 78 |
+
oauthHandler := handler.NewOAuthHandler()
|
| 79 |
+
r.GET("/api/oauth/start-rt", oauthHandler.StartOAuthForRT)
|
| 80 |
+
r.GET("/api/oauth/callback-rt", oauthHandler.CallbackOAuthForRT)
|
| 81 |
+
|
| 82 |
+
// External API - 用于注册机提交OAuth token(公开访问)
|
| 83 |
+
externalHandler := handler.NewExternalHandler()
|
| 84 |
+
r.POST("/api/external/submit-tokens", externalHandler.SubmitTokens)
|
| 85 |
+
|
| 86 |
+
// Account management API - 需要后台管理密码验证
|
| 87 |
+
accountHandler := handler.NewAccountHandler()
|
| 88 |
+
tokenHandler := handler.NewTokenHandler()
|
| 89 |
+
api := r.Group("/api")
|
| 90 |
+
api.Use(middleware.AdminAuthMiddleware()) // 应用后台管理密码验证中间件
|
| 91 |
+
{
|
| 92 |
+
// 账号管理
|
| 93 |
+
api.GET("/accounts", accountHandler.List)
|
| 94 |
+
api.POST("/accounts", accountHandler.Create)
|
| 95 |
+
api.PUT("/accounts/:id", accountHandler.Update)
|
| 96 |
+
api.DELETE("/accounts/:id", accountHandler.Delete)
|
| 97 |
+
api.POST("/accounts/:id/toggle", accountHandler.Toggle)
|
| 98 |
+
api.POST("/accounts/batch/category", accountHandler.BatchUpdateCategory)
|
| 99 |
+
api.POST("/accounts/batch/move-all", accountHandler.BatchMoveAll)
|
| 100 |
+
api.POST("/accounts/batch/refresh-token", accountHandler.BatchRefreshToken)
|
| 101 |
+
api.POST("/accounts/batch/delete", accountHandler.BatchDelete)
|
| 102 |
+
|
| 103 |
+
// Token记录管理
|
| 104 |
+
api.GET("/tokens", tokenHandler.ListTokenRecords)
|
| 105 |
+
api.PUT("/tokens/:id", tokenHandler.UpdateTokenRecord)
|
| 106 |
+
api.DELETE("/tokens/:id", tokenHandler.DeleteTokenRecord)
|
| 107 |
+
api.POST("/tokens/:id/trigger", tokenHandler.TriggerGeneration)
|
| 108 |
+
api.POST("/tokens/:id/refresh", tokenHandler.RefreshTokenRecord)
|
| 109 |
+
api.GET("/tokens/tasks", tokenHandler.GetGenerationTasks)
|
| 110 |
+
api.GET("/tokens/pool-status", tokenHandler.GetPoolStatus)
|
| 111 |
+
}
|
| 112 |
+
}
|