Upload 10 files
Browse files- internal/service/provider/anthropic.go +164 -0
- internal/service/provider/client.go +69 -0
- internal/service/provider/errors.go +10 -0
- internal/service/provider/factory.go +27 -0
- internal/service/provider/gemini.go +124 -0
- internal/service/provider/grok.go +149 -0
- internal/service/provider/manager.go +73 -0
- internal/service/provider/openai.go +149 -0
- internal/service/provider/proxy.go +251 -0
internal/service/provider/anthropic.go
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
	"zencoder-2api/internal/model"
)

// DefaultAnthropicBaseURL is the API endpoint used when Config.BaseURL is empty.
const DefaultAnthropicBaseURL = "https://api.anthropic.com"

// defaultAnthropicMaxTokens is the fallback when the request does not set MaxTokens.
const defaultAnthropicMaxTokens = int64(4096)

// AnthropicProvider adapts the Anthropic Messages API to the Provider interface.
type AnthropicProvider struct {
	client *anthropic.Client
	config Config
}

// NewAnthropicProvider builds a provider from cfg, applying the default
// base URL, any configured extra headers, and an optional proxy-backed
// HTTP client.
func NewAnthropicProvider(cfg Config) *AnthropicProvider {
	if cfg.BaseURL == "" {
		cfg.BaseURL = DefaultAnthropicBaseURL
	}

	opts := []option.RequestOption{
		option.WithAPIKey(cfg.APIKey),
		option.WithBaseURL(cfg.BaseURL),
	}

	for k, v := range cfg.ExtraHeaders {
		opts = append(opts, option.WithHeader(k, v))
	}

	if cfg.Proxy != "" {
		opts = append(opts, option.WithHTTPClient(NewHTTPClient(cfg.Proxy, 0)))
	}

	return &AnthropicProvider{
		client: anthropic.NewClient(opts...),
		config: cfg,
	}
}

// Name identifies this provider.
func (p *AnthropicProvider) Name() string {
	return "anthropic"
}

// ValidateToken verifies the API key by issuing a minimal one-token request.
func (p *AnthropicProvider) ValidateToken() error {
	_, err := p.client.Messages.New(context.Background(), anthropic.MessageNewParams{
		Model:     anthropic.F(anthropic.ModelClaude3_5SonnetLatest),
		MaxTokens: anthropic.F(int64(1)),
		Messages:  anthropic.F([]anthropic.MessageParam{{Role: anthropic.F(anthropic.MessageParamRoleUser), Content: anthropic.F([]anthropic.ContentBlockParamUnion{anthropic.NewTextBlock("hi")})}}),
	})
	return err
}

// buildParams converts an OpenAI-style request into Anthropic message
// parameters. System messages are forwarded through the dedicated System
// field; previously they were silently dropped from the conversation.
func (p *AnthropicProvider) buildParams(req *model.ChatCompletionRequest) anthropic.MessageNewParams {
	maxTokens := defaultAnthropicMaxTokens
	if req.MaxTokens > 0 {
		maxTokens = int64(req.MaxTokens)
	}

	params := anthropic.MessageNewParams{
		Model:     anthropic.F(req.Model),
		MaxTokens: anthropic.F(maxTokens),
		Messages:  anthropic.F(p.convertMessages(req.Messages)),
	}

	if sys := systemPrompt(req.Messages); sys != "" {
		params.System = anthropic.F([]anthropic.TextBlockParam{{
			Type: anthropic.F(anthropic.TextBlockParamTypeText),
			Text: anthropic.F(sys),
		}})
	}
	return params
}

// systemPrompt concatenates the content of all system messages,
// newline-separated, or returns "" when there are none.
func systemPrompt(msgs []model.ChatMessage) string {
	sys := ""
	for _, msg := range msgs {
		if msg.Role != "system" {
			continue
		}
		if sys != "" {
			sys += "\n"
		}
		sys += msg.Content
	}
	return sys
}

// Chat performs a blocking (non-streaming) completion.
func (p *AnthropicProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	resp, err := p.client.Messages.New(context.Background(), p.buildParams(req))
	if err != nil {
		return nil, err
	}
	return p.convertResponse(resp), nil
}

// convertMessages maps user/assistant messages into Anthropic message
// params. System messages are excluded here on purpose: buildParams
// routes them through the System field instead.
func (p *AnthropicProvider) convertMessages(msgs []model.ChatMessage) []anthropic.MessageParam {
	var messages []anthropic.MessageParam
	for _, msg := range msgs {
		if msg.Role == "system" {
			continue
		}
		role := anthropic.MessageParamRoleUser
		if msg.Role == "assistant" {
			role = anthropic.MessageParamRoleAssistant
		}
		messages = append(messages, anthropic.MessageParam{
			Role:    anthropic.F(role),
			Content: anthropic.F([]anthropic.ContentBlockParamUnion{anthropic.NewTextBlock(msg.Content)}),
		})
	}
	return messages
}

// convertResponse flattens the Anthropic message's text blocks into an
// OpenAI-style chat completion response.
func (p *AnthropicProvider) convertResponse(resp *anthropic.Message) *model.ChatCompletionResponse {
	content := ""
	for _, block := range resp.Content {
		if block.Type == anthropic.ContentBlockTypeText {
			content += block.Text
		}
	}

	return &model.ChatCompletionResponse{
		ID:     resp.ID,
		Object: "chat.completion",
		Model:  string(resp.Model),
		Choices: []model.Choice{{
			Index:        0,
			Message:      model.ChatMessage{Role: "assistant", Content: content},
			FinishReason: string(resp.StopReason),
		}},
		Usage: model.Usage{
			PromptTokens:     int(resp.Usage.InputTokens),
			CompletionTokens: int(resp.Usage.OutputTokens),
			TotalTokens:      int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
		},
	}
}

// ChatStream performs a streaming completion, relaying events to writer as SSE.
func (p *AnthropicProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
	stream := p.client.Messages.NewStreaming(context.Background(), p.buildParams(req))
	return p.handleStream(stream, writer)
}

// handleStream forwards each stream event as an SSE "data:" frame and
// terminates with the conventional [DONE] sentinel.
// NOTE(review): events are relayed in Anthropic's native event shape, not
// converted to OpenAI chunk format — confirm clients expect this.
func (p *AnthropicProvider) handleStream(stream *ssestream.Stream[anthropic.MessageStreamEvent], writer http.ResponseWriter) error {
	flusher, ok := writer.(http.Flusher)
	if !ok {
		// Use the package-level sentinel for consistency with errors.go.
		return ErrStreamNotSupported
	}

	writer.Header().Set("Content-Type", "text/event-stream")
	writer.Header().Set("Cache-Control", "no-cache")
	writer.Header().Set("Connection", "keep-alive")

	for stream.Next() {
		event := stream.Current()
		data, _ := json.Marshal(event) // marshal of SDK event types does not fail in practice
		fmt.Fprintf(writer, "data: %s\n\n", data)
		flusher.Flush()
	}

	if err := stream.Err(); err != nil {
		return err
	}

	fmt.Fprintf(writer, "data: [DONE]\n\n")
	flusher.Flush()
	return nil
}
|
internal/service/provider/client.go
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"
)

// NewHTTPClient creates an HTTP client that optionally routes traffic
// through the given proxy. Both http(s):// and socks5:// proxy URLs are
// supported; an empty proxy string yields a direct client. A timeout of 0
// selects the 10-minute default used for long streaming responses.
func NewHTTPClient(proxy string, timeout time.Duration) *http.Client {
	// SOCKS5 proxies are handled by the dedicated constructor (proxy.go).
	if strings.HasPrefix(proxy, "socks5://") {
		client, err := NewHTTPClientWithProxy(proxy, timeout)
		if err != nil {
			log.Printf("创建SOCKS5代理客户端失败: %v, 使用默认客户端", err)
			fallback, _ := NewHTTPClientWithProxy("", timeout)
			return fallback
		}
		return client
	}

	// Plain HTTP(S) proxy path.
	transport := &http.Transport{}

	if proxy != "" {
		proxyURL, err := url.Parse(proxy)
		if err != nil {
			// Previously a malformed proxy URL was silently ignored, which
			// made misconfiguration invisible; log it and fall back to a
			// direct connection (same effective behavior as before).
			log.Printf("invalid proxy URL %q: %v, proceeding without proxy", proxy, err)
		} else {
			transport.Proxy = http.ProxyURL(proxyURL)
		}
	}

	if timeout == 0 {
		timeout = 600 * time.Second // 10-minute default for long streaming responses
	}

	return &http.Client{
		Transport: transport,
		Timeout:   timeout,
	}
}

// NewHTTPClientWithPoolProxy creates an HTTP client using the shared proxy
// pool. When useProxy is false, or the pool has no proxies, or the selected
// proxy cannot be used, it falls back to a direct (no-proxy) client.
func NewHTTPClientWithPoolProxy(useProxy bool, timeout time.Duration) *http.Client {
	if !useProxy {
		// Proxying disabled: direct client.
		client, _ := NewHTTPClientWithProxy("", timeout)
		return client
	}

	pool := GetProxyPool()
	if !pool.HasProxies() {
		// No proxies configured: direct client.
		client, _ := NewHTTPClientWithProxy("", timeout)
		return client
	}

	// Rotate to the next proxy in the pool.
	proxyURL := pool.GetNextProxy()
	client, err := NewHTTPClientWithProxy(proxyURL, timeout)
	if err != nil {
		log.Printf("使用代理 %s 创建客户端失败: %v, 使用默认客户端", proxyURL, err)
		fallback, _ := NewHTTPClientWithProxy("", timeout)
		return fallback
	}

	return client
}
|
internal/service/provider/errors.go
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import "errors"

// Sentinel errors shared by all provider implementations.
// Callers should test for them with errors.Is.
var (
	// ErrStreamNotSupported indicates the http.ResponseWriter does not
	// implement http.Flusher, so SSE streaming cannot be performed.
	ErrStreamNotSupported = errors.New("streaming not supported")
	// ErrInvalidToken indicates the upstream API rejected the credential.
	ErrInvalidToken = errors.New("invalid token")
	// ErrRequestFailed is a generic upstream request failure.
	ErrRequestFailed = errors.New("request failed")
	// ErrUnknownProvider indicates a provider type the factory does not know.
	ErrUnknownProvider = errors.New("unknown provider")
)
|
internal/service/provider/factory.go
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import "fmt"

// ProviderType identifies a concrete Provider implementation.
type ProviderType string

// Supported provider types.
const (
	ProviderOpenAI    ProviderType = "openai"
	ProviderAnthropic ProviderType = "anthropic"
	ProviderGemini    ProviderType = "gemini"
	ProviderGrok      ProviderType = "grok"
)

// NewProvider constructs the provider for providerType, or returns an
// error wrapping ErrUnknownProvider for unrecognized types.
func NewProvider(providerType ProviderType, cfg Config) (Provider, error) {
	switch providerType {
	case ProviderOpenAI:
		return NewOpenAIProvider(cfg), nil
	case ProviderAnthropic:
		return NewAnthropicProvider(cfg), nil
	case ProviderGemini:
		p, err := NewGeminiProvider(cfg)
		if err != nil {
			// Return a literal nil Provider. Forwarding NewGeminiProvider's
			// (*GeminiProvider)(nil) directly would box the typed nil into a
			// non-nil interface value, so callers checking `p != nil` would
			// be misled on the error path.
			return nil, err
		}
		return p, nil
	case ProviderGrok:
		return NewGrokProvider(cfg), nil
	default:
		return nil, fmt.Errorf("%w: %s", ErrUnknownProvider, providerType)
	}
}
|
internal/service/provider/gemini.go
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/google/generative-ai-go/genai"
	"google.golang.org/api/option"
	"zencoder-2api/internal/model"
)

// GeminiProvider adapts Google's Generative AI (Gemini) SDK to the
// Provider interface.
type GeminiProvider struct {
	client *genai.Client // SDK client, created once in NewGeminiProvider
	config Config        // provider configuration (API key, optional endpoint)
}

// NewGeminiProvider creates the genai client using cfg.APIKey and, when
// set, cfg.BaseURL as a custom endpoint.
// NOTE(review): unlike the other providers in this package, cfg.Proxy and
// cfg.ExtraHeaders are not applied here — confirm whether that is intentional.
func NewGeminiProvider(cfg Config) (*GeminiProvider, error) {
	ctx := context.Background()

	opts := []option.ClientOption{
		option.WithAPIKey(cfg.APIKey),
	}

	if cfg.BaseURL != "" {
		opts = append(opts, option.WithEndpoint(cfg.BaseURL))
	}

	client, err := genai.NewClient(ctx, opts...)
	if err != nil {
		return nil, err
	}

	return &GeminiProvider{
		client: client,
		config: cfg,
	}, nil
}

// Name identifies this provider.
func (p *GeminiProvider) Name() string {
	return "gemini"
}

// ValidateToken verifies the API key by issuing a tiny generation request
// against a fixed model.
func (p *GeminiProvider) ValidateToken() error {
	model := p.client.GenerativeModel("gemini-1.5-flash")
	_, err := model.GenerateContent(context.Background(), genai.Text("hi"))
	return err
}

// Chat performs a blocking completion.
// NOTE(review): every message's content is flattened into one single-turn
// part list; role information (system/user/assistant) is discarded —
// confirm whether a multi-turn chat session should be used instead.
func (p *GeminiProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	geminiModel := p.client.GenerativeModel(req.Model)

	var parts []genai.Part
	for _, msg := range req.Messages {
		parts = append(parts, genai.Text(msg.Content))
	}

	resp, err := geminiModel.GenerateContent(context.Background(), parts...)
	if err != nil {
		return nil, err
	}

	return p.convertResponse(resp, req.Model), nil
}

// convertResponse extracts the text parts of the first candidate into an
// OpenAI-style response. Usage is left zero (token counts are not mapped)
// and FinishReason is always reported as "stop".
func (p *GeminiProvider) convertResponse(resp *genai.GenerateContentResponse, modelName string) *model.ChatCompletionResponse {
	content := ""
	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		for _, part := range resp.Candidates[0].Content.Parts {
			// Only plain text parts are collected; other part kinds are skipped.
			if text, ok := part.(genai.Text); ok {
				content += string(text)
			}
		}
	}

	return &model.ChatCompletionResponse{
		ID:     "gemini-" + modelName, // synthetic ID; no response ID is mapped from the SDK here
		Object: "chat.completion",
		Model:  modelName,
		Choices: []model.Choice{{
			Index:        0,
			Message:      model.ChatMessage{Role: "assistant", Content: content},
			FinishReason: "stop",
		}},
	}
}

// ChatStream performs a streaming completion, relaying SDK chunks as SSE.
// Message flattening matches Chat (roles discarded; see note there).
func (p *GeminiProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
	geminiModel := p.client.GenerativeModel(req.Model)

	var parts []genai.Part
	for _, msg := range req.Messages {
		parts = append(parts, genai.Text(msg.Content))
	}

	iter := geminiModel.GenerateContentStream(context.Background(), parts...)
	return p.handleStream(iter, writer)
}

// handleStream writes each streamed response as an SSE "data:" frame in the
// SDK's native JSON shape, terminated by the [DONE] sentinel.
// NOTE(review): any error from iter.Next — including genuine failures, not
// just the iterator-exhausted sentinel — silently ends the loop and the
// function still reports success; confirm whether real stream errors
// should be surfaced to the caller.
func (p *GeminiProvider) handleStream(iter *genai.GenerateContentResponseIterator, writer http.ResponseWriter) error {
	flusher, ok := writer.(http.Flusher)
	if !ok {
		return fmt.Errorf("streaming not supported")
	}

	writer.Header().Set("Content-Type", "text/event-stream")
	writer.Header().Set("Cache-Control", "no-cache")
	writer.Header().Set("Connection", "keep-alive")

	for {
		resp, err := iter.Next()
		if err != nil {
			// End of stream OR a real error; the two are not distinguished here.
			break
		}
		data, _ := json.Marshal(resp)
		fmt.Fprintf(writer, "data: %s\n\n", data)
		flusher.Flush()
	}

	fmt.Fprintf(writer, "data: [DONE]\n\n")
	flusher.Flush()
	return nil
}
|
internal/service/provider/grok.go
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/ssestream"
	"zencoder-2api/internal/model"
)

// DefaultGrokBaseURL is the xAI endpoint used when Config.BaseURL is empty.
const DefaultGrokBaseURL = "https://api.x.ai/v1"

// GrokProvider talks to xAI's OpenAI-compatible API through the OpenAI SDK.
type GrokProvider struct {
	client *openai.Client
	config Config
}

// NewGrokProvider builds a provider from cfg, applying the default base
// URL, any extra headers, and an optional proxy-backed HTTP client.
func NewGrokProvider(cfg Config) *GrokProvider {
	if cfg.BaseURL == "" {
		cfg.BaseURL = DefaultGrokBaseURL
	}

	opts := []option.RequestOption{
		option.WithAPIKey(cfg.APIKey),
		option.WithBaseURL(cfg.BaseURL),
	}

	for k, v := range cfg.ExtraHeaders {
		opts = append(opts, option.WithHeader(k, v))
	}

	if cfg.Proxy != "" {
		opts = append(opts, option.WithHTTPClient(NewHTTPClient(cfg.Proxy, 0)))
	}

	return &GrokProvider{
		client: openai.NewClient(opts...),
		config: cfg,
	}
}

// Name identifies this provider.
func (p *GrokProvider) Name() string {
	return "grok"
}

// ValidateToken verifies the API key by listing available models.
func (p *GrokProvider) ValidateToken() error {
	_, err := p.client.Models.List(context.Background())
	return err
}

// convertMessages maps request messages onto SDK message params. This
// replaces the conversion loops previously duplicated in Chat and
// ChatStream. Unrecognized roles fall back to user messages; previously
// they left a nil union entry that would serialize as null upstream.
func (p *GrokProvider) convertMessages(msgs []model.ChatMessage) []openai.ChatCompletionMessageParamUnion {
	messages := make([]openai.ChatCompletionMessageParamUnion, len(msgs))
	for i, msg := range msgs {
		switch msg.Role {
		case "system":
			messages[i] = openai.SystemMessage(msg.Content)
		case "assistant":
			messages[i] = openai.AssistantMessage(msg.Content)
		default: // "user" and any unknown role
			messages[i] = openai.UserMessage(msg.Content)
		}
	}
	return messages
}

// Chat performs a blocking (non-streaming) completion.
func (p *GrokProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	resp, err := p.client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Model:    openai.F(req.Model),
		Messages: openai.F(p.convertMessages(req.Messages)),
	})
	if err != nil {
		return nil, err
	}

	return p.convertResponse(resp), nil
}

// convertResponse maps the SDK completion onto the project response type.
func (p *GrokProvider) convertResponse(resp *openai.ChatCompletion) *model.ChatCompletionResponse {
	choices := make([]model.Choice, len(resp.Choices))
	for i, c := range resp.Choices {
		choices[i] = model.Choice{
			Index:        int(c.Index),
			Message:      model.ChatMessage{Role: string(c.Message.Role), Content: c.Message.Content},
			FinishReason: string(c.FinishReason),
		}
	}

	return &model.ChatCompletionResponse{
		ID:      resp.ID,
		Object:  string(resp.Object),
		Created: resp.Created,
		Model:   resp.Model,
		Choices: choices,
		Usage: model.Usage{
			PromptTokens:     int(resp.Usage.PromptTokens),
			CompletionTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:      int(resp.Usage.TotalTokens),
		},
	}
}

// ChatStream performs a streaming completion, relaying chunks to writer as SSE.
func (p *GrokProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
	stream := p.client.Chat.Completions.NewStreaming(context.Background(), openai.ChatCompletionNewParams{
		Model:    openai.F(req.Model),
		Messages: openai.F(p.convertMessages(req.Messages)),
	})

	return p.handleStream(stream, writer)
}

// handleStream forwards each chunk as an SSE "data:" frame and terminates
// with the conventional [DONE] sentinel.
func (p *GrokProvider) handleStream(stream *ssestream.Stream[openai.ChatCompletionChunk], writer http.ResponseWriter) error {
	flusher, ok := writer.(http.Flusher)
	if !ok {
		// Use the package-level sentinel for consistency with errors.go.
		return ErrStreamNotSupported
	}

	writer.Header().Set("Content-Type", "text/event-stream")
	writer.Header().Set("Cache-Control", "no-cache")
	writer.Header().Set("Connection", "keep-alive")

	for stream.Next() {
		chunk := stream.Current()
		data, _ := json.Marshal(chunk) // marshal of SDK chunk types does not fail in practice
		fmt.Fprintf(writer, "data: %s\n\n", data)
		flusher.Flush()
	}

	if err := stream.Err(); err != nil {
		return err
	}

	fmt.Fprintf(writer, "data: [DONE]\n\n")
	flusher.Flush()
	return nil
}
|
internal/service/provider/manager.go
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"fmt"
	"sync"

	"zencoder-2api/internal/model"
)

// Manager caches Provider instances per (account, provider) pair so that
// repeated requests reuse the same underlying SDK client.
type Manager struct {
	mu        sync.RWMutex
	providers map[string]Provider
}

var defaultManager = &Manager{
	providers: make(map[string]Provider),
}

// GetManager returns the process-wide default manager.
func GetManager() *Manager {
	return defaultManager
}

// GetProvider returns the cached provider for the given account and model,
// creating and caching a new one on first use.
func (m *Manager) GetProvider(accountID uint, zenModel model.ZenModel, cfg Config) (Provider, error) {
	key := m.buildKey(accountID, zenModel.ProviderID)

	m.mu.RLock()
	cached, found := m.providers[key]
	m.mu.RUnlock()
	if found {
		return cached, nil
	}

	return m.createProvider(key, zenModel.ProviderID, cfg)
}

// buildKey derives the cache key for an account/provider pair.
func (m *Manager) buildKey(accountID uint, providerID string) string {
	return fmt.Sprintf("%d:%s", accountID, providerID)
}

// resolveProviderType maps a stored provider ID onto a ProviderType.
// Unknown IDs fall back to Anthropic, preserving existing behavior.
func resolveProviderType(providerID string) ProviderType {
	switch providerID {
	case "openai":
		return ProviderOpenAI
	case "gemini":
		return ProviderGemini
	case "xai":
		return ProviderGrok
	default:
		// Covers "anthropic" and any unrecognized ID alike.
		return ProviderAnthropic
	}
}

// createProvider builds, caches, and returns a provider under the write lock.
func (m *Manager) createProvider(key, providerID string, cfg Config) (Provider, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Double-checked: another goroutine may have created the provider
	// while we were waiting for the write lock.
	if existing, found := m.providers[key]; found {
		return existing, nil
	}

	created, err := NewProvider(resolveProviderType(providerID), cfg)
	if err != nil {
		return nil, err
	}

	m.providers[key] = created
	return created, nil
}
|
internal/service/provider/openai.go
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/ssestream"
	"zencoder-2api/internal/model"
)

// DefaultOpenAIBaseURL is the API endpoint used when Config.BaseURL is empty.
const DefaultOpenAIBaseURL = "https://api.openai.com/v1"

// OpenAIProvider adapts the OpenAI Chat Completions API to the Provider interface.
type OpenAIProvider struct {
	client *openai.Client
	config Config
}

// NewOpenAIProvider builds a provider from cfg, applying the default base
// URL, any extra headers, and an optional proxy-backed HTTP client.
func NewOpenAIProvider(cfg Config) *OpenAIProvider {
	if cfg.BaseURL == "" {
		cfg.BaseURL = DefaultOpenAIBaseURL
	}

	opts := []option.RequestOption{
		option.WithAPIKey(cfg.APIKey),
		option.WithBaseURL(cfg.BaseURL),
	}

	for k, v := range cfg.ExtraHeaders {
		opts = append(opts, option.WithHeader(k, v))
	}

	if cfg.Proxy != "" {
		opts = append(opts, option.WithHTTPClient(NewHTTPClient(cfg.Proxy, 0)))
	}

	return &OpenAIProvider{
		client: openai.NewClient(opts...),
		config: cfg,
	}
}

// Name identifies this provider.
func (p *OpenAIProvider) Name() string {
	return "openai"
}

// ValidateToken verifies the API key by listing available models.
func (p *OpenAIProvider) ValidateToken() error {
	_, err := p.client.Models.List(context.Background())
	return err
}

// convertMessages maps request messages onto SDK message params. This
// replaces the conversion loops previously duplicated in Chat and
// ChatStream. Unrecognized roles fall back to user messages; previously
// they left a nil union entry that would serialize as null upstream.
func (p *OpenAIProvider) convertMessages(msgs []model.ChatMessage) []openai.ChatCompletionMessageParamUnion {
	messages := make([]openai.ChatCompletionMessageParamUnion, len(msgs))
	for i, msg := range msgs {
		switch msg.Role {
		case "system":
			messages[i] = openai.SystemMessage(msg.Content)
		case "assistant":
			messages[i] = openai.AssistantMessage(msg.Content)
		default: // "user" and any unknown role
			messages[i] = openai.UserMessage(msg.Content)
		}
	}
	return messages
}

// Chat performs a blocking (non-streaming) completion.
func (p *OpenAIProvider) Chat(req *model.ChatCompletionRequest) (*model.ChatCompletionResponse, error) {
	resp, err := p.client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Model:    openai.F(req.Model),
		Messages: openai.F(p.convertMessages(req.Messages)),
	})
	if err != nil {
		return nil, err
	}

	return p.convertResponse(resp), nil
}

// convertResponse maps the SDK completion onto the project response type.
func (p *OpenAIProvider) convertResponse(resp *openai.ChatCompletion) *model.ChatCompletionResponse {
	choices := make([]model.Choice, len(resp.Choices))
	for i, c := range resp.Choices {
		choices[i] = model.Choice{
			Index:        int(c.Index),
			Message:      model.ChatMessage{Role: string(c.Message.Role), Content: c.Message.Content},
			FinishReason: string(c.FinishReason),
		}
	}

	return &model.ChatCompletionResponse{
		ID:      resp.ID,
		Object:  string(resp.Object),
		Created: resp.Created,
		Model:   resp.Model,
		Choices: choices,
		Usage: model.Usage{
			PromptTokens:     int(resp.Usage.PromptTokens),
			CompletionTokens: int(resp.Usage.CompletionTokens),
			TotalTokens:      int(resp.Usage.TotalTokens),
		},
	}
}

// ChatStream performs a streaming completion, relaying chunks to writer as SSE.
func (p *OpenAIProvider) ChatStream(req *model.ChatCompletionRequest, writer http.ResponseWriter) error {
	stream := p.client.Chat.Completions.NewStreaming(context.Background(), openai.ChatCompletionNewParams{
		Model:    openai.F(req.Model),
		Messages: openai.F(p.convertMessages(req.Messages)),
	})

	return p.handleStream(stream, writer)
}

// handleStream forwards each chunk as an SSE "data:" frame and terminates
// with the conventional [DONE] sentinel.
func (p *OpenAIProvider) handleStream(stream *ssestream.Stream[openai.ChatCompletionChunk], writer http.ResponseWriter) error {
	flusher, ok := writer.(http.Flusher)
	if !ok {
		// Use the package-level sentinel for consistency with errors.go.
		return ErrStreamNotSupported
	}

	writer.Header().Set("Content-Type", "text/event-stream")
	writer.Header().Set("Cache-Control", "no-cache")
	writer.Header().Set("Connection", "keep-alive")

	for stream.Next() {
		chunk := stream.Current()
		data, _ := json.Marshal(chunk) // marshal of SDK chunk types does not fail in practice
		fmt.Fprintf(writer, "data: %s\n\n", data)
		flusher.Flush()
	}

	if err := stream.Err(); err != nil {
		return err
	}

	fmt.Fprintf(writer, "data: [DONE]\n\n")
	flusher.Flush()
	return nil
}
|
internal/service/provider/proxy.go
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package provider
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"math/rand"
|
| 7 |
+
"net"
|
| 8 |
+
"net/http"
|
| 9 |
+
"net/url"
|
| 10 |
+
"os"
|
| 11 |
+
"strings"
|
| 12 |
+
"sync"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"golang.org/x/net/proxy"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// ProxyPool manages a pool of outbound proxy URLs, handing them out
// either round-robin (GetNextProxy) or at random (GetRandomProxy).
// All methods are safe for concurrent use.
type ProxyPool struct {
	proxies []string     // proxy URLs loaded from the SOCKS_PROXY_POOL env var
	mu      sync.RWMutex // guards proxies and index
	index   int          // next round-robin position used by GetNextProxy
}
|
| 24 |
+
|
| 25 |
+
var (
	globalProxyPool *ProxyPool // process-wide singleton; access via GetProxyPool
	once            sync.Once  // guards one-time creation of globalProxyPool
)
|
| 29 |
+
|
| 30 |
+
// GetProxyPool 获取全局代理池实例
|
| 31 |
+
func GetProxyPool() *ProxyPool {
|
| 32 |
+
once.Do(func() {
|
| 33 |
+
globalProxyPool = NewProxyPool()
|
| 34 |
+
})
|
| 35 |
+
return globalProxyPool
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// NewProxyPool 创建新的代理池
|
| 39 |
+
func NewProxyPool() *ProxyPool {
|
| 40 |
+
pool := &ProxyPool{
|
| 41 |
+
proxies: make([]string, 0),
|
| 42 |
+
}
|
| 43 |
+
pool.loadProxiesFromEnv()
|
| 44 |
+
return pool
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// loadProxiesFromEnv 从环境变量加载代理列表
|
| 48 |
+
func (p *ProxyPool) loadProxiesFromEnv() {
|
| 49 |
+
p.mu.Lock()
|
| 50 |
+
defer p.mu.Unlock()
|
| 51 |
+
|
| 52 |
+
proxyEnv := os.Getenv("SOCKS_PROXY_POOL")
|
| 53 |
+
if proxyEnv == "" {
|
| 54 |
+
return
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// 解析逗号分隔的代理列表
|
| 58 |
+
proxiesStr := strings.Split(proxyEnv, ",")
|
| 59 |
+
for _, proxyStr := range proxiesStr {
|
| 60 |
+
proxyStr = strings.TrimSpace(proxyStr)
|
| 61 |
+
if proxyStr != "" {
|
| 62 |
+
p.proxies = append(p.proxies, proxyStr)
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// GetNextProxy 获取下一个代理(轮询方式)
|
| 68 |
+
func (p *ProxyPool) GetNextProxy() string {
|
| 69 |
+
p.mu.Lock()
|
| 70 |
+
defer p.mu.Unlock()
|
| 71 |
+
|
| 72 |
+
if len(p.proxies) == 0 {
|
| 73 |
+
return ""
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
proxy := p.proxies[p.index]
|
| 77 |
+
p.index = (p.index + 1) % len(p.proxies)
|
| 78 |
+
return proxy
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// GetRandomProxy 获取随机代理
|
| 82 |
+
func (p *ProxyPool) GetRandomProxy() string {
|
| 83 |
+
p.mu.RLock()
|
| 84 |
+
defer p.mu.RUnlock()
|
| 85 |
+
|
| 86 |
+
if len(p.proxies) == 0 {
|
| 87 |
+
return ""
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
index := rand.Intn(len(p.proxies))
|
| 91 |
+
return p.proxies[index]
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
// HasProxies 检查是否有可用代理
|
| 95 |
+
func (p *ProxyPool) HasProxies() bool {
|
| 96 |
+
p.mu.RLock()
|
| 97 |
+
defer p.mu.RUnlock()
|
| 98 |
+
return len(p.proxies) > 0
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Count 返回代理数量
|
| 102 |
+
func (p *ProxyPool) Count() int {
|
| 103 |
+
p.mu.RLock()
|
| 104 |
+
defer p.mu.RUnlock()
|
| 105 |
+
return len(p.proxies)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// GetAllProxies 获取所有代理列表(用于测试)
|
| 109 |
+
func (p *ProxyPool) GetAllProxies() []string {
|
| 110 |
+
p.mu.RLock()
|
| 111 |
+
defer p.mu.RUnlock()
|
| 112 |
+
result := make([]string, len(p.proxies))
|
| 113 |
+
copy(result, p.proxies)
|
| 114 |
+
return result
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// createSOCKS5Transport 创建SOCKS5代理传输层
|
| 118 |
+
func createSOCKS5Transport(proxyURL string, timeout time.Duration) (*http.Transport, error) {
|
| 119 |
+
// 处理自定义格式:socks5://host:port:username:password
|
| 120 |
+
// 转换为标准格式:socks5://username:password@host:port
|
| 121 |
+
if strings.Contains(proxyURL, "socks5://") && strings.Count(proxyURL, ":") == 4 {
|
| 122 |
+
// 解析自定义格式
|
| 123 |
+
parts := strings.Split(proxyURL, ":")
|
| 124 |
+
if len(parts) == 5 {
|
| 125 |
+
// parts[0] = "socks5", parts[1] = "//host", parts[2] = "port", parts[3] = "username", parts[4] = "password"
|
| 126 |
+
host := strings.TrimPrefix(parts[1], "//")
|
| 127 |
+
port := parts[2]
|
| 128 |
+
username := parts[3]
|
| 129 |
+
password := parts[4]
|
| 130 |
+
|
| 131 |
+
// 重构为标准URL格式
|
| 132 |
+
proxyURL = fmt.Sprintf("socks5://%s:%s@%s:%s", username, password, host, port)
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
u, err := url.Parse(proxyURL)
|
| 137 |
+
if err != nil {
|
| 138 |
+
return nil, fmt.Errorf("解析代理URL失败: %v", err)
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
if u.Scheme != "socks5" {
|
| 142 |
+
return nil, fmt.Errorf("仅支持SOCKS5代理")
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
// 解析用户名和密码
|
| 146 |
+
var auth *proxy.Auth
|
| 147 |
+
if u.User != nil {
|
| 148 |
+
password, _ := u.User.Password()
|
| 149 |
+
auth = &proxy.Auth{
|
| 150 |
+
User: u.User.Username(),
|
| 151 |
+
Password: password,
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// 创建SOCKS5拨号器
|
| 156 |
+
dialer, err := proxy.SOCKS5("tcp", u.Host, auth, proxy.Direct)
|
| 157 |
+
if err != nil {
|
| 158 |
+
return nil, fmt.Errorf("创建SOCKS5拨号器失败: %v", err)
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
transport := &http.Transport{
|
| 162 |
+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
| 163 |
+
return dialer.Dial(network, addr)
|
| 164 |
+
},
|
| 165 |
+
MaxIdleConns: 100,
|
| 166 |
+
IdleConnTimeout: 90 * time.Second,
|
| 167 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
return transport, nil
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// parseCustomProxyURL normalizes the custom proxy notation
// socks5://host:port:username:password into the standard URL form
// socks5://username:password@host:port. Inputs already in standard form
// (three colons or fewer), or with any other scheme, are returned
// unchanged.
//
// The URL is rebuilt via url.URL/url.UserPassword so credentials
// containing characters such as '@' or '/' are percent-escaped instead
// of corrupting the resulting URL (the previous fmt.Sprintf emitted
// them verbatim).
func parseCustomProxyURL(proxyURL string) string {
	// The custom form has exactly four colons: after the scheme, after
	// the host, after the port, and after the username.
	if !strings.HasPrefix(proxyURL, "socks5://") || strings.Count(proxyURL, ":") != 4 {
		return proxyURL
	}

	// With exactly four colons, Split always yields five parts:
	// ["socks5", "//host", port, username, password].
	parts := strings.Split(proxyURL, ":")
	host := strings.TrimPrefix(parts[1], "//")

	u := url.URL{
		Scheme: "socks5",
		User:   url.UserPassword(parts[3], parts[4]),
		Host:   host + ":" + parts[2],
	}
	return u.String()
}
|
| 193 |
+
|
| 194 |
+
// NewHTTPClientWithProxy 创建带指定代理的HTTP客户端
|
| 195 |
+
func NewHTTPClientWithProxy(proxyURL string, timeout time.Duration) (*http.Client, error) {
|
| 196 |
+
if timeout == 0 {
|
| 197 |
+
timeout = 600 * time.Second // 10分钟超时,支持长时间流式响应
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
if proxyURL == "" {
|
| 201 |
+
// 没有代理,使用默认客户端
|
| 202 |
+
return &http.Client{
|
| 203 |
+
Transport: &http.Transport{
|
| 204 |
+
DialContext: (&net.Dialer{
|
| 205 |
+
Timeout: 30 * time.Second,
|
| 206 |
+
KeepAlive: 30 * time.Second,
|
| 207 |
+
}).DialContext,
|
| 208 |
+
MaxIdleConns: 100,
|
| 209 |
+
IdleConnTimeout: 90 * time.Second,
|
| 210 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 211 |
+
},
|
| 212 |
+
Timeout: timeout,
|
| 213 |
+
}, nil
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// 转换自定义格式到标准格式
|
| 217 |
+
standardURL := parseCustomProxyURL(proxyURL)
|
| 218 |
+
|
| 219 |
+
// 解析代理URL
|
| 220 |
+
u, err := url.Parse(standardURL)
|
| 221 |
+
if err != nil {
|
| 222 |
+
return nil, fmt.Errorf("解析代理URL失败: %v", err)
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
var transport *http.Transport
|
| 226 |
+
|
| 227 |
+
if u.Scheme == "socks5" {
|
| 228 |
+
// SOCKS5代理 - 使用转换后的标准URL
|
| 229 |
+
transport, err = createSOCKS5Transport(standardURL, timeout)
|
| 230 |
+
if err != nil {
|
| 231 |
+
return nil, err
|
| 232 |
+
}
|
| 233 |
+
} else {
|
| 234 |
+
// HTTP代理
|
| 235 |
+
transport = &http.Transport{
|
| 236 |
+
Proxy: http.ProxyURL(u),
|
| 237 |
+
DialContext: (&net.Dialer{
|
| 238 |
+
Timeout: 30 * time.Second,
|
| 239 |
+
KeepAlive: 30 * time.Second,
|
| 240 |
+
}).DialContext,
|
| 241 |
+
MaxIdleConns: 100,
|
| 242 |
+
IdleConnTimeout: 90 * time.Second,
|
| 243 |
+
TLSHandshakeTimeout: 10 * time.Second,
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
return &http.Client{
|
| 248 |
+
Transport: transport,
|
| 249 |
+
Timeout: timeout,
|
| 250 |
+
}, nil
|
| 251 |
+
}
|