| package channel |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| common2 "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/logger" |
| "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/relay/helper" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/setting/operation_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
| "github.com/gin-gonic/gin" |
| "github.com/gorilla/websocket" |
| ) |
|
|
| func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { |
| if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { |
| |
| } else if info.RelayMode == constant.RelayModeImagesEdits { |
| |
| } else if info.RelayMode == constant.RelayModeRealtime { |
| |
| } else { |
| req.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
| req.Set("Accept", c.Request.Header.Get("Accept")) |
| if info.IsStream && c.Request.Header.Get("Accept") == "" { |
| req.Set("Accept", "text/event-stream") |
| } |
| } |
| } |
|
|
| |
| |
| func processHeaderOverride(info *common.RelayInfo) (map[string]string, error) { |
| headerOverride := make(map[string]string) |
| for k, v := range info.HeadersOverride { |
| str, ok := v.(string) |
| if !ok { |
| return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid) |
| } |
|
|
| |
| if strings.Contains(str, "{api_key}") { |
| str = strings.ReplaceAll(str, "{api_key}", info.ApiKey) |
| } |
|
|
| headerOverride[k] = str |
| } |
| return headerOverride, nil |
| } |
|
|
| func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
| fullRequestURL, err := a.GetRequestURL(info) |
| if err != nil { |
| return nil, fmt.Errorf("get request url failed: %w", err) |
| } |
| if common2.DebugEnabled { |
| println("fullRequestURL:", fullRequestURL) |
| } |
| req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
| if err != nil { |
| return nil, fmt.Errorf("new request failed: %w", err) |
| } |
| headers := req.Header |
| headerOverride, err := processHeaderOverride(info) |
| if err != nil { |
| return nil, err |
| } |
| for key, value := range headerOverride { |
| headers.Set(key, value) |
| } |
| err = a.SetupRequestHeader(c, &headers, info) |
| if err != nil { |
| return nil, fmt.Errorf("setup request header failed: %w", err) |
| } |
| resp, err := doRequest(c, req, info) |
| if err != nil { |
| return nil, fmt.Errorf("do request failed: %w", err) |
| } |
| return resp, nil |
| } |
|
|
| func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
| fullRequestURL, err := a.GetRequestURL(info) |
| if err != nil { |
| return nil, fmt.Errorf("get request url failed: %w", err) |
| } |
| if common2.DebugEnabled { |
| println("fullRequestURL:", fullRequestURL) |
| } |
| req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
| if err != nil { |
| return nil, fmt.Errorf("new request failed: %w", err) |
| } |
| |
| req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
| headers := req.Header |
| headerOverride, err := processHeaderOverride(info) |
| if err != nil { |
| return nil, err |
| } |
| for key, value := range headerOverride { |
| headers.Set(key, value) |
| } |
| err = a.SetupRequestHeader(c, &headers, info) |
| if err != nil { |
| return nil, fmt.Errorf("setup request header failed: %w", err) |
| } |
| resp, err := doRequest(c, req, info) |
| if err != nil { |
| return nil, fmt.Errorf("do request failed: %w", err) |
| } |
| return resp, nil |
| } |
|
|
| func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) { |
| fullRequestURL, err := a.GetRequestURL(info) |
| if err != nil { |
| return nil, fmt.Errorf("get request url failed: %w", err) |
| } |
| targetHeader := http.Header{} |
| headerOverride, err := processHeaderOverride(info) |
| if err != nil { |
| return nil, err |
| } |
| for key, value := range headerOverride { |
| targetHeader.Set(key, value) |
| } |
| err = a.SetupRequestHeader(c, &targetHeader, info) |
| if err != nil { |
| return nil, fmt.Errorf("setup request header failed: %w", err) |
| } |
| targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type")) |
| targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader) |
| if err != nil { |
| return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err) |
| } |
| |
| |
| |
| return targetConn, nil |
| } |
|
|
| func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc { |
| pingerCtx, stopPinger := context.WithCancel(context.Background()) |
|
|
| gopool.Go(func() { |
| defer func() { |
| |
| if r := recover(); r != nil { |
| if common2.DebugEnabled { |
| println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r)) |
| } |
| } |
| if common2.DebugEnabled { |
| println("SSE ping goroutine stopped.") |
| } |
| }() |
|
|
| if pingInterval <= 0 { |
| pingInterval = helper.DefaultPingInterval |
| } |
|
|
| ticker := time.NewTicker(pingInterval) |
| |
| defer func() { |
| ticker.Stop() |
| if common2.DebugEnabled { |
| println("SSE ping ticker stopped") |
| } |
| }() |
|
|
| var pingMutex sync.Mutex |
| if common2.DebugEnabled { |
| println("SSE ping goroutine started") |
| } |
|
|
| |
| maxPingDuration := 120 * time.Minute |
| pingTimeout := time.NewTimer(maxPingDuration) |
| defer pingTimeout.Stop() |
|
|
| for { |
| select { |
| |
| case <-ticker.C: |
| if err := sendPingData(c, &pingMutex); err != nil { |
| if common2.DebugEnabled { |
| println("SSE ping error, stopping goroutine:", err.Error()) |
| } |
| return |
| } |
| |
| case <-pingerCtx.Done(): |
| return |
| |
| case <-c.Request.Context().Done(): |
| return |
| |
| case <-pingTimeout.C: |
| if common2.DebugEnabled { |
| println("SSE ping goroutine timeout, stopping") |
| } |
| return |
| } |
| } |
| }) |
|
|
| return stopPinger |
| } |
|
|
| func sendPingData(c *gin.Context, mutex *sync.Mutex) error { |
| |
| done := make(chan error, 1) |
| go func() { |
| mutex.Lock() |
| defer mutex.Unlock() |
|
|
| err := helper.PingData(c) |
| if err != nil { |
| logger.LogError(c, "SSE ping error: "+err.Error()) |
| done <- err |
| return |
| } |
|
|
| if common2.DebugEnabled { |
| println("SSE ping data sent.") |
| } |
| done <- nil |
| }() |
|
|
| |
| select { |
| case err := <-done: |
| return err |
| case <-time.After(10 * time.Second): |
| return errors.New("SSE ping data send timeout") |
| case <-c.Request.Context().Done(): |
| return errors.New("request context cancelled during ping") |
| } |
| } |
|
|
| func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { |
| return doRequest(c, req, info) |
| } |
| func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) { |
| var client *http.Client |
| var err error |
| if info.ChannelSetting.Proxy != "" { |
| client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy) |
| if err != nil { |
| return nil, fmt.Errorf("new proxy http client failed: %w", err) |
| } |
| } else { |
| client = service.GetHttpClient() |
| } |
|
|
| var stopPinger context.CancelFunc |
| if info.IsStream { |
| helper.SetEventStreamHeaders(c) |
| |
| generalSettings := operation_setting.GetGeneralSetting() |
| if generalSettings.PingIntervalEnabled && !info.DisablePing { |
| pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second |
| stopPinger = startPingKeepAlive(c, pingInterval) |
| |
| defer func() { |
| if stopPinger != nil { |
| stopPinger() |
| if common2.DebugEnabled { |
| println("SSE ping goroutine stopped by defer") |
| } |
| } |
| }() |
| } |
| } |
|
|
| resp, err := client.Do(req) |
| if err != nil { |
| logger.LogError(c, "do request failed: "+err.Error()) |
| return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) |
| } |
| if resp == nil { |
| return nil, errors.New("resp is nil") |
| } |
|
|
| _ = req.Body.Close() |
| _ = c.Request.Body.Close() |
| return resp, nil |
| } |
|
|
| func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { |
| fullRequestURL, err := a.BuildRequestURL(info) |
| if err != nil { |
| return nil, err |
| } |
| req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) |
| if err != nil { |
| return nil, fmt.Errorf("new request failed: %w", err) |
| } |
| req.GetBody = func() (io.ReadCloser, error) { |
| return io.NopCloser(requestBody), nil |
| } |
|
|
| err = a.BuildRequestHeader(c, req, info) |
| if err != nil { |
| return nil, fmt.Errorf("setup request header failed: %w", err) |
| } |
| resp, err := doRequest(c, req, info) |
| if err != nil { |
| return nil, fmt.Errorf("do request failed: %w", err) |
| } |
| return resp, nil |
| } |
|
|