|
|
package channel |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"errors" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
common2 "github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/logger" |
|
|
"github.com/QuantumNous/new-api/relay/common" |
|
|
"github.com/QuantumNous/new-api/relay/constant" |
|
|
"github.com/QuantumNous/new-api/relay/helper" |
|
|
"github.com/QuantumNous/new-api/service" |
|
|
"github.com/QuantumNous/new-api/setting/operation_setting" |
|
|
"github.com/QuantumNous/new-api/types" |
|
|
|
|
|
"github.com/bytedance/gopkg/util/gopool" |
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/gorilla/websocket" |
|
|
) |
|
|
|
|
|
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { |
|
|
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { |
|
|
|
|
|
} else if info.RelayMode == constant.RelayModeImagesEdits { |
|
|
|
|
|
} else if info.RelayMode == constant.RelayModeRealtime { |
|
|
|
|
|
} else { |
|
|
req.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
|
|
req.Set("Accept", c.Request.Header.Get("Accept")) |
|
|
if info.IsStream && c.Request.Header.Get("Accept") == "" { |
|
|
req.Set("Accept", "text/event-stream") |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) { |
|
|
headerOverride := make(map[string]string) |
|
|
for k, v := range info.HeadersOverride { |
|
|
str, ok := v.(string) |
|
|
if !ok { |
|
|
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid) |
|
|
} |
|
|
|
|
|
|
|
|
if strings.Contains(str, "{api_key}") { |
|
|
str = strings.ReplaceAll(str, "{api_key}", info.ApiKey) |
|
|
} |
|
|
|
|
|
headerOverride[k] = str |
|
|
} |
|
|
return headerOverride, nil |
|
|
} |
|
|
|
|
|
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
|
|
fullRequestURL, err := a.GetRequestURL(info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("get request url failed: %w", err) |
|
|
} |
|
|
if common2.DebugEnabled { |
|
|
println("fullRequestURL:", fullRequestURL) |
|
|
} |
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("new request failed: %w", err) |
|
|
} |
|
|
headers := req.Header |
|
|
headerOverride, err := processHeaderOverride(info) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
for key, value := range headerOverride { |
|
|
headers.Set(key, value) |
|
|
} |
|
|
err = a.SetupRequestHeader(c, &headers, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("setup request header failed: %w", err) |
|
|
} |
|
|
resp, err := doRequest(c, req, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("do request failed: %w", err) |
|
|
} |
|
|
return resp, nil |
|
|
} |
|
|
|
|
|
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
|
|
fullRequestURL, err := a.GetRequestURL(info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("get request url failed: %w", err) |
|
|
} |
|
|
if common2.DebugEnabled { |
|
|
println("fullRequestURL:", fullRequestURL) |
|
|
} |
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("new request failed: %w", err) |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
|
|
headers := req.Header |
|
|
headerOverride, err := processHeaderOverride(info) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
for key, value := range headerOverride { |
|
|
headers.Set(key, value) |
|
|
} |
|
|
err = a.SetupRequestHeader(c, &headers, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("setup request header failed: %w", err) |
|
|
} |
|
|
resp, err := doRequest(c, req, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("do request failed: %w", err) |
|
|
} |
|
|
return resp, nil |
|
|
} |
|
|
|
|
|
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) { |
|
|
fullRequestURL, err := a.GetRequestURL(info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("get request url failed: %w", err) |
|
|
} |
|
|
targetHeader := http.Header{} |
|
|
headerOverride, err := processHeaderOverride(info) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
for key, value := range headerOverride { |
|
|
targetHeader.Set(key, value) |
|
|
} |
|
|
err = a.SetupRequestHeader(c, &targetHeader, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("setup request header failed: %w", err) |
|
|
} |
|
|
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
|
|
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return targetConn, nil |
|
|
} |
|
|
|
|
|
func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc { |
|
|
pingerCtx, stopPinger := context.WithCancel(context.Background()) |
|
|
|
|
|
gopool.Go(func() { |
|
|
defer func() { |
|
|
|
|
|
if r := recover(); r != nil { |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r)) |
|
|
} |
|
|
} |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping goroutine stopped.") |
|
|
} |
|
|
}() |
|
|
|
|
|
if pingInterval <= 0 { |
|
|
pingInterval = helper.DefaultPingInterval |
|
|
} |
|
|
|
|
|
ticker := time.NewTicker(pingInterval) |
|
|
|
|
|
defer func() { |
|
|
ticker.Stop() |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping ticker stopped") |
|
|
} |
|
|
}() |
|
|
|
|
|
var pingMutex sync.Mutex |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping goroutine started") |
|
|
} |
|
|
|
|
|
|
|
|
maxPingDuration := 120 * time.Minute |
|
|
pingTimeout := time.NewTimer(maxPingDuration) |
|
|
defer pingTimeout.Stop() |
|
|
|
|
|
for { |
|
|
select { |
|
|
|
|
|
case <-ticker.C: |
|
|
if err := sendPingData(c, &pingMutex); err != nil { |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping error, stopping goroutine:", err.Error()) |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
case <-pingerCtx.Done(): |
|
|
return |
|
|
|
|
|
case <-c.Request.Context().Done(): |
|
|
return |
|
|
|
|
|
case <-pingTimeout.C: |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping goroutine timeout, stopping") |
|
|
} |
|
|
return |
|
|
} |
|
|
} |
|
|
}) |
|
|
|
|
|
return stopPinger |
|
|
} |
|
|
|
|
|
func sendPingData(c *gin.Context, mutex *sync.Mutex) error { |
|
|
|
|
|
done := make(chan error, 1) |
|
|
go func() { |
|
|
mutex.Lock() |
|
|
defer mutex.Unlock() |
|
|
|
|
|
err := helper.PingData(c) |
|
|
if err != nil { |
|
|
logger.LogError(c, "SSE ping error: "+err.Error()) |
|
|
done <- err |
|
|
return |
|
|
} |
|
|
|
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping data sent.") |
|
|
} |
|
|
done <- nil |
|
|
}() |
|
|
|
|
|
|
|
|
select { |
|
|
case err := <-done: |
|
|
return err |
|
|
case <-time.After(10 * time.Second): |
|
|
return errors.New("SSE ping data send timeout") |
|
|
case <-c.Request.Context().Done(): |
|
|
return errors.New("request context cancelled during ping") |
|
|
} |
|
|
} |
|
|
|
|
|
func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { |
|
|
return doRequest(c, req, info) |
|
|
} |
|
|
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { |
|
|
var client *http.Client |
|
|
var err error |
|
|
if info.ChannelSetting.Proxy != "" { |
|
|
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("new proxy http client failed: %w", err) |
|
|
} |
|
|
} else { |
|
|
client = service.GetHttpClient() |
|
|
} |
|
|
|
|
|
var stopPinger context.CancelFunc |
|
|
if info.IsStream { |
|
|
helper.SetEventStreamHeaders(c) |
|
|
|
|
|
generalSettings := operation_setting.GetGeneralSetting() |
|
|
if generalSettings.PingIntervalEnabled && !info.DisablePing { |
|
|
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second |
|
|
stopPinger = startPingKeepAlive(c, pingInterval) |
|
|
|
|
|
defer func() { |
|
|
if stopPinger != nil { |
|
|
stopPinger() |
|
|
if common2.DebugEnabled { |
|
|
println("SSE ping goroutine stopped by defer") |
|
|
} |
|
|
} |
|
|
}() |
|
|
} |
|
|
} |
|
|
|
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
logger.LogError(c, "do request failed: "+err.Error()) |
|
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) |
|
|
} |
|
|
if resp == nil { |
|
|
return nil, errors.New("resp is nil") |
|
|
} |
|
|
|
|
|
_ = req.Body.Close() |
|
|
_ = c.Request.Body.Close() |
|
|
return resp, nil |
|
|
} |
|
|
|
|
|
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
|
|
fullRequestURL, err := a.BuildRequestURL(info) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("new request failed: %w", err) |
|
|
} |
|
|
req.GetBody = func() (io.ReadCloser, error) { |
|
|
return io.NopCloser(requestBody), nil |
|
|
} |
|
|
|
|
|
err = a.BuildRequestHeader(c, req, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("setup request header failed: %w", err) |
|
|
} |
|
|
resp, err := doRequest(c, req, info) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("do request failed: %w", err) |
|
|
} |
|
|
return resp, nil |
|
|
} |
|
|
|