|
|
package amp |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"compress/gzip" |
|
|
"context" |
|
|
"errors" |
|
|
"fmt" |
|
|
"io" |
|
|
"net" |
|
|
"net/http" |
|
|
"net/http/httputil" |
|
|
"net/url" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
log "github.com/sirupsen/logrus" |
|
|
) |
|
|
|
|
|
// removeQueryValuesMatching deletes every value of the query parameter key
// that is exactly equal to match, leaving all other values and parameters
// intact. If the parameter is then empty it is removed entirely.
//
// When no value of key equals match, req.URL.RawQuery is left untouched so
// the original parameter order and escaping are preserved. Otherwise the
// query is re-encoded, which sorts parameters by key.
//
// A nil request, nil URL, or empty match is a no-op.
func removeQueryValuesMatching(req *http.Request, key string, match string) {
	if req == nil || req.URL == nil || match == "" {
		return
	}

	q := req.URL.Query()
	values, ok := q[key]
	if !ok || len(values) == 0 {
		return
	}

	kept := make([]string, 0, len(values))
	for _, v := range values {
		if v != match {
			kept = append(kept, v)
		}
	}

	// Nothing matched: re-encoding would only reorder, re-escape, or drop
	// malformed pairs of an otherwise unchanged query, so leave it alone.
	if len(kept) == len(values) {
		return
	}

	if len(kept) == 0 {
		q.Del(key)
	} else {
		q[key] = kept
	}
	req.URL.RawQuery = q.Encode()
}
|
|
|
|
|
|
|
|
|
|
|
// readCloser combines an arbitrary io.Reader with a separate io.Closer.
// Read pulls bytes from r while Close releases c. This lets a body whose
// leading bytes were already consumed be replayed (for example via
// io.MultiReader) while still closing the original stream.
type readCloser struct {
	r io.Reader // source for Read
	c io.Closer // target of Close
}

// Compile-time check that *readCloser can stand in for an http body.
var _ io.ReadCloser = (*readCloser)(nil)

// Read forwards to the wrapped reader.
func (w *readCloser) Read(buf []byte) (int, error) {
	return w.r.Read(buf)
}

// Close forwards to the wrapped closer.
func (w *readCloser) Close() error {
	return w.c.Close()
}
|
|
|
|
|
|
|
|
|
|
|
func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) { |
|
|
parsed, err := url.Parse(upstreamURL) |
|
|
if err != nil { |
|
|
return nil, fmt.Errorf("invalid amp upstream url: %w", err) |
|
|
} |
|
|
|
|
|
proxy := httputil.NewSingleHostReverseProxy(parsed) |
|
|
originalDirector := proxy.Director |
|
|
|
|
|
|
|
|
proxy.Director = func(req *http.Request) { |
|
|
originalDirector(req) |
|
|
req.Host = parsed.Host |
|
|
|
|
|
|
|
|
|
|
|
req.Header.Del("Authorization") |
|
|
req.Header.Del("X-Api-Key") |
|
|
req.Header.Del("X-Goog-Api-Key") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clientKey := getClientAPIKeyFromContext(req.Context()) |
|
|
removeQueryValuesMatching(req, "key", clientKey) |
|
|
removeQueryValuesMatching(req, "auth_token", clientKey) |
|
|
|
|
|
|
|
|
if req.Header.Get("X-Request-ID") == "" { |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if key, err := secretSource.Get(req.Context()); err == nil && key != "" { |
|
|
req.Header.Set("X-Api-Key", key) |
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key)) |
|
|
} else if err != nil { |
|
|
log.Warnf("amp secret source error (continuing without auth): %v", err) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
proxy.ModifyResponse = func(resp *http.Response) error { |
|
|
|
|
|
|
|
|
if resp.StatusCode >= 500 { |
|
|
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) |
|
|
} else if resp.StatusCode >= 400 { |
|
|
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) |
|
|
} |
|
|
|
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 { |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
if resp.Header.Get("Content-Encoding") != "" { |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
if isStreamingResponse(resp) { |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
originalBody := resp.Body |
|
|
|
|
|
|
|
|
header := make([]byte, 2) |
|
|
n, _ := io.ReadFull(originalBody, header) |
|
|
|
|
|
|
|
|
|
|
|
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b { |
|
|
|
|
|
rest, err := io.ReadAll(originalBody) |
|
|
if err != nil { |
|
|
|
|
|
resp.Body = &readCloser{ |
|
|
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), |
|
|
c: originalBody, |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
gzippedData := append(header[:n], rest...) |
|
|
|
|
|
|
|
|
gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData)) |
|
|
if err != nil { |
|
|
log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err) |
|
|
|
|
|
_ = originalBody.Close() |
|
|
resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) |
|
|
return nil |
|
|
} |
|
|
|
|
|
decompressed, err := io.ReadAll(gzipReader) |
|
|
_ = gzipReader.Close() |
|
|
if err != nil { |
|
|
log.Warnf("amp proxy: gzip decompress error: %v", err) |
|
|
|
|
|
_ = originalBody.Close() |
|
|
resp.Body = io.NopCloser(bytes.NewReader(gzippedData)) |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
_ = originalBody.Close() |
|
|
|
|
|
|
|
|
resp.Body = io.NopCloser(bytes.NewReader(decompressed)) |
|
|
resp.ContentLength = int64(len(decompressed)) |
|
|
|
|
|
|
|
|
resp.Header.Del("Content-Encoding") |
|
|
resp.Header.Del("Content-Length") |
|
|
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) |
|
|
|
|
|
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed)) |
|
|
} else { |
|
|
|
|
|
|
|
|
resp.Body = &readCloser{ |
|
|
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody), |
|
|
c: originalBody, |
|
|
} |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { |
|
|
|
|
|
var errType string |
|
|
if errors.Is(err, context.DeadlineExceeded) { |
|
|
errType = "timeout" |
|
|
} else if errors.Is(err, context.Canceled) { |
|
|
errType = "canceled" |
|
|
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { |
|
|
errType = "dial_timeout" |
|
|
} else if _, ok := err.(net.Error); ok { |
|
|
errType = "network_error" |
|
|
} else { |
|
|
errType = "connection_error" |
|
|
} |
|
|
|
|
|
|
|
|
if errors.Is(err, context.Canceled) { |
|
|
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path) |
|
|
} else { |
|
|
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err) |
|
|
} |
|
|
|
|
|
rw.Header().Set("Content-Type", "application/json") |
|
|
rw.WriteHeader(http.StatusBadGateway) |
|
|
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) |
|
|
} |
|
|
|
|
|
return proxy, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// isStreamingResponse reports whether resp is a Server-Sent Events stream,
// i.e. its Content-Type mentions text/event-stream. The amp proxy's response
// hook skips body buffering for such responses.
//
// Media types are case-insensitive (RFC 9110 §8.3.1), so the match ignores
// case. Parameters such as "; charset=utf-8" are tolerated.
func isStreamingResponse(resp *http.Response) bool {
	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	return strings.Contains(contentType, "text/event-stream")
}
|
|
|
|
|
|
|
|
func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
proxy.ServeHTTP(c.Writer, c.Request) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// filterBetaFeatures takes a comma-separated header value and returns it
// with every entry equal to featureToRemove dropped.
//
// Entries are whitespace-trimmed, blank entries are discarded, and the
// survivors are rejoined with "," (no spaces). An empty or all-blank header
// yields "".
func filterBetaFeatures(header, featureToRemove string) string {
	var out strings.Builder
	for _, part := range strings.Split(header, ",") {
		name := strings.TrimSpace(part)
		if name == "" || name == featureToRemove {
			continue
		}
		// Every written name is non-empty, so a non-empty builder means a
		// separator is needed.
		if out.Len() > 0 {
			out.WriteByte(',')
		}
		out.WriteString(name)
	}
	return out.String()
}
|
|
|