| package service |
|
|
| import ( |
| "context" |
| "fmt" |
| "net" |
| "net/http" |
| "net/url" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/setting/system_setting" |
|
|
| "golang.org/x/net/proxy" |
| ) |
|
|
| var ( |
| httpClient *http.Client |
| proxyClientLock sync.Mutex |
| proxyClients = make(map[string]*http.Client) |
| ) |
|
|
| func checkRedirect(req *http.Request, via []*http.Request) error { |
| fetchSetting := system_setting.GetFetchSetting() |
| urlStr := req.URL.String() |
| if err := common.ValidateURLWithFetchSetting(urlStr, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { |
| return fmt.Errorf("redirect to %s blocked: %v", urlStr, err) |
| } |
| if len(via) >= 10 { |
| return fmt.Errorf("stopped after 10 redirects") |
| } |
| return nil |
| } |
|
|
| func InitHttpClient() { |
| if common.RelayTimeout == 0 { |
| httpClient = &http.Client{ |
| CheckRedirect: checkRedirect, |
| } |
| } else { |
| httpClient = &http.Client{ |
| Timeout: time.Duration(common.RelayTimeout) * time.Second, |
| CheckRedirect: checkRedirect, |
| } |
| } |
| } |
|
|
| func GetHttpClient() *http.Client { |
| return httpClient |
| } |
|
|
| |
| func ResetProxyClientCache() { |
| proxyClientLock.Lock() |
| defer proxyClientLock.Unlock() |
| for _, client := range proxyClients { |
| if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { |
| transport.CloseIdleConnections() |
| } |
| } |
| proxyClients = make(map[string]*http.Client) |
| } |
|
|
| |
| func NewProxyHttpClient(proxyURL string) (*http.Client, error) { |
| if proxyURL == "" { |
| return http.DefaultClient, nil |
| } |
|
|
| proxyClientLock.Lock() |
| if client, ok := proxyClients[proxyURL]; ok { |
| proxyClientLock.Unlock() |
| return client, nil |
| } |
| proxyClientLock.Unlock() |
|
|
| parsedURL, err := url.Parse(proxyURL) |
| if err != nil { |
| return nil, err |
| } |
|
|
| switch parsedURL.Scheme { |
| case "http", "https": |
| client := &http.Client{ |
| Transport: &http.Transport{ |
| Proxy: http.ProxyURL(parsedURL), |
| }, |
| CheckRedirect: checkRedirect, |
| } |
| client.Timeout = time.Duration(common.RelayTimeout) * time.Second |
| proxyClientLock.Lock() |
| proxyClients[proxyURL] = client |
| proxyClientLock.Unlock() |
| return client, nil |
|
|
| case "socks5", "socks5h": |
| |
| var auth *proxy.Auth |
| if parsedURL.User != nil { |
| auth = &proxy.Auth{ |
| User: parsedURL.User.Username(), |
| Password: "", |
| } |
| if password, ok := parsedURL.User.Password(); ok { |
| auth.Password = password |
| } |
| } |
|
|
| |
| |
| dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) |
| if err != nil { |
| return nil, err |
| } |
|
|
| client := &http.Client{ |
| Transport: &http.Transport{ |
| DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
| return dialer.Dial(network, addr) |
| }, |
| }, |
| CheckRedirect: checkRedirect, |
| } |
| client.Timeout = time.Duration(common.RelayTimeout) * time.Second |
| proxyClientLock.Lock() |
| proxyClients[proxyURL] = client |
| proxyClientLock.Unlock() |
| return client, nil |
|
|
| default: |
| return nil, fmt.Errorf("unsupported proxy scheme: %s, must be http, https, socks5 or socks5h", parsedURL.Scheme) |
| } |
| } |
|
|