| | |
| | |
| | |
| |
|
| | |
| |
|
| | package httputil |
| |
|
| | import ( |
| | "bufio" |
| | "bytes" |
| | "context" |
| | "errors" |
| | "fmt" |
| | "io" |
| | "log" |
| | "net" |
| | "net/http" |
| | "net/http/httptest" |
| | "net/http/httptrace" |
| | "net/http/internal/ascii" |
| | "net/textproto" |
| | "net/url" |
| | "os" |
| | "reflect" |
| | "runtime" |
| | "slices" |
| | "strconv" |
| | "strings" |
| | "sync" |
| | "testing" |
| | "time" |
| | ) |
| |
|
| | const fakeHopHeader = "X-Fake-Hop-Header-For-Test" |
| |
|
| | func init() { |
| | inOurTests = true |
| | hopHeaders = append(hopHeaders, fakeHopHeader) |
| | } |
| |
|
| | func TestReverseProxy(t *testing.T) { |
| | const backendResponse = "I am the backend" |
| | const backendStatus = 404 |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if r.Method == "GET" && r.FormValue("mode") == "hangup" { |
| | c, _, _ := w.(http.Hijacker).Hijack() |
| | c.Close() |
| | return |
| | } |
| | if len(r.TransferEncoding) > 0 { |
| | t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) |
| | } |
| | if r.Header.Get("X-Forwarded-For") == "" { |
| | t.Errorf("didn't get X-Forwarded-For header") |
| | } |
| | if c := r.Header.Get("Connection"); c != "" { |
| | t.Errorf("handler got Connection header value %q", c) |
| | } |
| | if c := r.Header.Get("Te"); c != "trailers" { |
| | t.Errorf("handler got Te header value %q; want 'trailers'", c) |
| | } |
| | if c := r.Header.Get("Upgrade"); c != "" { |
| | t.Errorf("handler got Upgrade header value %q", c) |
| | } |
| | if c := r.Header.Get("Proxy-Connection"); c != "" { |
| | t.Errorf("handler got Proxy-Connection header value %q", c) |
| | } |
| | if g, e := r.Host, "some-name"; g != e { |
| | t.Errorf("backend got Host header %q, want %q", g, e) |
| | } |
| | w.Header().Set("Trailers", "not a special header field name") |
| | w.Header().Set("Trailer", "X-Trailer") |
| | w.Header().Set("X-Foo", "bar") |
| | w.Header().Set("Upgrade", "foo") |
| | w.Header().Set(fakeHopHeader, "foo") |
| | w.Header().Add("X-Multi-Value", "foo") |
| | w.Header().Add("X-Multi-Value", "bar") |
| | http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) |
| | w.WriteHeader(backendStatus) |
| | w.Write([]byte(backendResponse)) |
| | w.Header().Set("X-Trailer", "trailer_value") |
| | w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Host = "some-name" |
| | getReq.Header.Set("Connection", "close, TE") |
| | getReq.Header.Add("Te", "foo") |
| | getReq.Header.Add("Te", "bar, trailers") |
| | getReq.Header.Set("Proxy-Connection", "should be deleted") |
| | getReq.Header.Set("Upgrade", "foo") |
| | getReq.Close = true |
| | res, err := frontendClient.Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | if g, e := res.StatusCode, backendStatus; g != e { |
| | t.Errorf("got res.StatusCode %d; expected %d", g, e) |
| | } |
| | if g, e := res.Header.Get("X-Foo"), "bar"; g != e { |
| | t.Errorf("got X-Foo %q; expected %q", g, e) |
| | } |
| | if c := res.Header.Get(fakeHopHeader); c != "" { |
| | t.Errorf("got %s header value %q", fakeHopHeader, c) |
| | } |
| | if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { |
| | t.Errorf("header Trailers = %q; want %q", g, e) |
| | } |
| | if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { |
| | t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) |
| | } |
| | if g, e := len(res.Header["Set-Cookie"]), 1; g != e { |
| | t.Fatalf("got %d SetCookies, want %d", g, e) |
| | } |
| | if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { |
| | t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) |
| | } |
| | if cookie := res.Cookies()[0]; cookie.Name != "flavor" { |
| | t.Errorf("unexpected cookie %q", cookie.Name) |
| | } |
| | bodyBytes, _ := io.ReadAll(res.Body) |
| | if g, e := string(bodyBytes), backendResponse; g != e { |
| | t.Errorf("got body %q; expected %q", g, e) |
| | } |
| | if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { |
| | t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) |
| | } |
| | if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { |
| | t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) |
| | } |
| | res.Body.Close() |
| |
|
| | |
| | |
| | getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) |
| | getReq.Close = true |
| | res, err = frontendClient.Do(getReq) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | res.Body.Close() |
| | if res.StatusCode != http.StatusBadGateway { |
| | t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) |
| | } |
| |
|
| | } |
| |
|
| | |
| | |
| | func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { |
| | const fakeConnectionToken = "X-Fake-Connection-Token" |
| | const backendResponse = "I am the backend" |
| |
|
| | |
| | |
| | const someConnHeader = "X-Some-Conn-Header" |
| |
|
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if c := r.Header.Get("Connection"); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", "Connection", c) |
| | } |
| | if c := r.Header.Get(fakeConnectionToken); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) |
| | } |
| | if c := r.Header.Get(someConnHeader); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) |
| | } |
| | w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) |
| | w.Header().Add("Connection", someConnHeader) |
| | w.Header().Set(someConnHeader, "should be deleted") |
| | w.Header().Set(fakeConnectionToken, "should be deleted") |
| | io.WriteString(w, backendResponse) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | proxyHandler.ServeHTTP(w, r) |
| | if c := r.Header.Get(someConnHeader); c != "should be deleted" { |
| | t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") |
| | } |
| | if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { |
| | t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") |
| | } |
| | c := r.Header["Connection"] |
| | var cf []string |
| | for _, f := range c { |
| | for sf := range strings.SplitSeq(f, ",") { |
| | if sf = strings.TrimSpace(sf); sf != "" { |
| | cf = append(cf, sf) |
| | } |
| | } |
| | } |
| | slices.Sort(cf) |
| | expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} |
| | slices.Sort(expectedValues) |
| | if !slices.Equal(cf, expectedValues) { |
| | t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) |
| | } |
| | })) |
| | defer frontend.Close() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) |
| | getReq.Header.Add("Connection", someConnHeader) |
| | getReq.Header.Set(someConnHeader, "should be deleted") |
| | getReq.Header.Set(fakeConnectionToken, "should be deleted") |
| | res, err := frontend.Client().Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | bodyBytes, err := io.ReadAll(res.Body) |
| | if err != nil { |
| | t.Fatalf("reading body: %v", err) |
| | } |
| | if got, want := string(bodyBytes), backendResponse; got != want { |
| | t.Errorf("got body %q; want %q", got, want) |
| | } |
| | if c := res.Header.Get("Connection"); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", "Connection", c) |
| | } |
| | if c := res.Header.Get(someConnHeader); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) |
| | } |
| | if c := res.Header.Get(fakeConnectionToken); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) |
| | } |
| | } |
| |
|
| | func TestReverseProxyStripEmptyConnection(t *testing.T) { |
| | |
| | const backendResponse = "I am the backend" |
| |
|
| | |
| | |
| | const someConnHeader = "X-Some-Conn-Header" |
| |
|
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if c := r.Header.Values("Connection"); len(c) != 0 { |
| | t.Errorf("handler got header %q = %v; want empty", "Connection", c) |
| | } |
| | if c := r.Header.Get(someConnHeader); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) |
| | } |
| | w.Header().Add("Connection", "") |
| | w.Header().Add("Connection", someConnHeader) |
| | w.Header().Set(someConnHeader, "should be deleted") |
| | io.WriteString(w, backendResponse) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | proxyHandler.ServeHTTP(w, r) |
| | if c := r.Header.Get(someConnHeader); c != "should be deleted" { |
| | t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") |
| | } |
| | })) |
| | defer frontend.Close() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Header.Add("Connection", "") |
| | getReq.Header.Add("Connection", someConnHeader) |
| | getReq.Header.Set(someConnHeader, "should be deleted") |
| | res, err := frontend.Client().Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | bodyBytes, err := io.ReadAll(res.Body) |
| | if err != nil { |
| | t.Fatalf("reading body: %v", err) |
| | } |
| | if got, want := string(bodyBytes), backendResponse; got != want { |
| | t.Errorf("got body %q; want %q", got, want) |
| | } |
| | if c := res.Header.Get("Connection"); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", "Connection", c) |
| | } |
| | if c := res.Header.Get(someConnHeader); c != "" { |
| | t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) |
| | } |
| | } |
| |
|
| | func TestXForwardedFor(t *testing.T) { |
| | const prevForwardedFor = "client ip" |
| | const backendResponse = "I am the backend" |
| | const backendStatus = 404 |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if r.Header.Get("X-Forwarded-For") == "" { |
| | t.Errorf("didn't get X-Forwarded-For header") |
| | } |
| | if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { |
| | t.Errorf("X-Forwarded-For didn't contain prior data") |
| | } |
| | w.WriteHeader(backendStatus) |
| | w.Write([]byte(backendResponse)) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Header.Set("Connection", "close") |
| | getReq.Header.Set("X-Forwarded-For", prevForwardedFor) |
| | getReq.Close = true |
| | res, err := frontend.Client().Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | if g, e := res.StatusCode, backendStatus; g != e { |
| | t.Errorf("got res.StatusCode %d; expected %d", g, e) |
| | } |
| | bodyBytes, _ := io.ReadAll(res.Body) |
| | if g, e := string(bodyBytes), backendResponse; g != e { |
| | t.Errorf("got body %q; expected %q", g, e) |
| | } |
| | } |
| |
|
| | |
| | func TestXForwardedFor_Omit(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if v := r.Header.Get("X-Forwarded-For"); v != "" { |
| | t.Errorf("got X-Forwarded-For header: %q", v) |
| | } |
| | w.Write([]byte("hi")) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | oldDirector := proxyHandler.Director |
| | proxyHandler.Director = func(r *http.Request) { |
| | r.Header["X-Forwarded-For"] = nil |
| | oldDirector(r) |
| | } |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Host = "some-name" |
| | getReq.Close = true |
| | res, err := frontend.Client().Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | res.Body.Close() |
| | } |
| |
|
| | func TestReverseProxyRewriteStripsForwarded(t *testing.T) { |
| | headers := []string{ |
| | "Forwarded", |
| | "X-Forwarded-For", |
| | "X-Forwarded-Host", |
| | "X-Forwarded-Proto", |
| | } |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | for _, h := range headers { |
| | if v := r.Header.Get(h); v != "" { |
| | t.Errorf("got %v header: %q", h, v) |
| | } |
| | } |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.SetURL(backendURL) |
| | }, |
| | } |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Host = "some-name" |
| | getReq.Close = true |
| | for _, h := range headers { |
| | getReq.Header.Set(h, "x") |
| | } |
| | res, err := frontend.Client().Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | res.Body.Close() |
| | } |
| |
|
// proxyQueryTests describes how a query on the proxy's target URL combines
// with a query on the incoming request, as observed by the backend.
var proxyQueryTests = []struct {
	baseSuffix string // appended to the backend URL handed to the proxy
	reqSuffix  string // appended to the frontend URL the client requests
	want       string // RawQuery the backend is expected to see
}{
	{baseSuffix: "", reqSuffix: "", want: ""},
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{baseSuffix: "", reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", reqSuffix: "", want: "sta=tic"},
}
| |
|
| | func TestReverseProxyQuery(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Set("X-Got-Query", r.URL.RawQuery) |
| | w.Write([]byte("hi")) |
| | })) |
| | defer backend.Close() |
| |
|
| | for i, tt := range proxyQueryTests { |
| | backendURL, err := url.Parse(backend.URL + tt.baseSuffix) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) |
| | req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) |
| | req.Close = true |
| | res, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("%d. Get: %v", i, err) |
| | } |
| | if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { |
| | t.Errorf("%d. got query %q; expected %q", i, g, e) |
| | } |
| | res.Body.Close() |
| | frontend.Close() |
| | } |
| | } |
| |
|
| | func TestReverseProxyFlushInterval(t *testing.T) { |
| | const expected = "hi" |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte(expected)) |
| | })) |
| | defer backend.Close() |
| |
|
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.FlushInterval = time.Microsecond |
| |
|
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontend.URL, nil) |
| | req.Close = true |
| | res, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { |
| | t.Errorf("got body %q; expected %q", bodyBytes, expected) |
| | } |
| | } |
| |
|
// mockFlusher embeds an http.ResponseWriter and satisfies http.Flusher by
// only noting that Flush was invoked; nothing is actually flushed.
type mockFlusher struct {
	http.ResponseWriter
	flushed bool // true once Flush has been called
}

// Flush sets m.flushed. It deliberately does not call the embedded writer.
func (m *mockFlusher) Flush() { m.flushed = true }
| |
|
// wrappedRW hides the concrete type of the ResponseWriter it embeds. Only the
// http.ResponseWriter methods are promoted, so optional interfaces such as
// http.Flusher are reachable only through Unwrap.
type wrappedRW struct {
	http.ResponseWriter
}

// Unwrap returns the embedded ResponseWriter.
func (w *wrappedRW) Unwrap() http.ResponseWriter { return w.ResponseWriter }
| |
|
| | func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { |
| | const expected = "hi" |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte(expected)) |
| | })) |
| | defer backend.Close() |
| |
|
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | mf := &mockFlusher{} |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.FlushInterval = -1 |
| | proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | mf.ResponseWriter = w |
| | w = &wrappedRW{mf} |
| | proxyHandler.ServeHTTP(w, r) |
| | }) |
| |
|
| | frontend := httptest.NewServer(proxyWithMiddleware) |
| | defer frontend.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontend.URL, nil) |
| | req.Close = true |
| | res, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { |
| | t.Errorf("got body %q; expected %q", bodyBytes, expected) |
| | } |
| | if !mf.flushed { |
| | t.Errorf("response writer was not flushed") |
| | } |
| | } |
| |
|
| | func TestReverseProxyFlushIntervalHeaders(t *testing.T) { |
| | const expected = "hi" |
| | stopCh := make(chan struct{}) |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Add("MyHeader", expected) |
| | w.WriteHeader(200) |
| | w.(http.Flusher).Flush() |
| | <-stopCh |
| | })) |
| | defer backend.Close() |
| | defer close(stopCh) |
| |
|
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.FlushInterval = time.Microsecond |
| |
|
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontend.URL, nil) |
| | req.Close = true |
| |
|
| | ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) |
| | defer cancel() |
| | req = req.WithContext(ctx) |
| |
|
| | res, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| |
|
| | if res.Header.Get("MyHeader") != expected { |
| | t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) |
| | } |
| | } |
| |
|
| | func TestReverseProxyCancellation(t *testing.T) { |
| | const backendResponse = "I am the backend" |
| |
|
| | reqInFlight := make(chan struct{}) |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | close(reqInFlight) |
| |
|
| | select { |
| | case <-time.After(10 * time.Second): |
| | |
| | |
| | t.Error("Handler never saw CloseNotify") |
| | return |
| | case <-w.(http.CloseNotifier).CloseNotify(): |
| | } |
| |
|
| | w.WriteHeader(http.StatusOK) |
| | w.Write([]byte(backendResponse)) |
| | })) |
| |
|
| | defer backend.Close() |
| |
|
| | backend.Config.ErrorLog = log.New(io.Discard, "", 0) |
| |
|
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| |
|
| | |
| | |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| |
|
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | go func() { |
| | <-reqInFlight |
| | frontendClient.Transport.(*http.Transport).CancelRequest(getReq) |
| | }() |
| | res, err := frontendClient.Do(getReq) |
| | if res != nil { |
| | t.Errorf("got response %v; want nil", res.Status) |
| | } |
| | if err == nil { |
| | |
| | |
| | |
| | t.Error("Server.Client().Do() returned nil error; want non-nil error") |
| | } |
| | } |
| |
|
// req parses v as a raw HTTP/1.x request (request line, headers, optional
// body) and returns the result, failing the test if parsing fails.
//
// Because it may call t.Fatal, req must only be called from the goroutine
// running the test, not from handler or other helper goroutines.
func req(t *testing.T, v string) *http.Request {
	t.Helper() // report failures at the caller's line
	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	if err != nil {
		t.Fatal(err)
	}
	return req
}
| |
|
| | |
| | func TestNilBody(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte("hi")) |
| | })) |
| | defer backend.Close() |
| |
|
| | frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { |
| | backURL, _ := url.Parse(backend.URL) |
| | rp := NewSingleHostReverseProxy(backURL) |
| | r := req(t, "GET / HTTP/1.0\r\n\r\n") |
| | r.Body = nil |
| | rp.ServeHTTP(w, r) |
| | })) |
| | defer frontend.Close() |
| |
|
| | res, err := http.Get(frontend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | defer res.Body.Close() |
| | slurp, err := io.ReadAll(res.Body) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | if string(slurp) != "hi" { |
| | t.Errorf("Got %q; want %q", slurp, "hi") |
| | } |
| | } |
| |
|
| | |
| | func TestUserAgentHeader(t *testing.T) { |
| | var gotUA string |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | gotUA = r.Header.Get("User-Agent") |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | proxyHandler := new(ReverseProxy) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | proxyHandler.Director = func(req *http.Request) { |
| | req.URL = backendURL |
| | } |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | for _, sentUA := range []string{"explicit UA", ""} { |
| | getReq, _ := http.NewRequest("GET", frontend.URL, nil) |
| | getReq.Header.Set("User-Agent", sentUA) |
| | getReq.Close = true |
| | res, err := frontendClient.Do(getReq) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | res.Body.Close() |
| | if got, want := gotUA, sentUA; got != want { |
| | t.Errorf("got forwarded User-Agent %q, want %q", got, want) |
| | } |
| | } |
| | } |
| |
|
// bufferPool provides Get/Put buffer methods by delegating to the get and put
// callbacks, so tests can observe how the proxy uses its BufferPool.
type bufferPool struct {
	get func() []byte // called by Get
	put func([]byte)  // called by Put
}

// Get returns the buffer produced by the get callback.
func (bp bufferPool) Get() []byte {
	return bp.get()
}

// Put passes v to the put callback.
func (bp bufferPool) Put(v []byte) {
	bp.put(v)
}
| |
|
| | func TestReverseProxyGetPutBuffer(t *testing.T) { |
| | const msg = "hi" |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | io.WriteString(w, msg) |
| | })) |
| | defer backend.Close() |
| |
|
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | var ( |
| | mu sync.Mutex |
| | log []string |
| | ) |
| | addLog := func(event string) { |
| | mu.Lock() |
| | defer mu.Unlock() |
| | log = append(log, event) |
| | } |
| | rp := NewSingleHostReverseProxy(backendURL) |
| | const size = 1234 |
| | rp.BufferPool = bufferPool{ |
| | get: func() []byte { |
| | addLog("getBuf") |
| | return make([]byte, size) |
| | }, |
| | put: func(p []byte) { |
| | addLog("putBuf-" + strconv.Itoa(len(p))) |
| | }, |
| | } |
| | frontend := httptest.NewServer(rp) |
| | defer frontend.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontend.URL, nil) |
| | req.Close = true |
| | res, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | slurp, err := io.ReadAll(res.Body) |
| | res.Body.Close() |
| | if err != nil { |
| | t.Fatalf("reading body: %v", err) |
| | } |
| | if string(slurp) != msg { |
| | t.Errorf("msg = %q; want %q", slurp, msg) |
| | } |
| | wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} |
| | mu.Lock() |
| | defer mu.Unlock() |
| | if !slices.Equal(log, wantLog) { |
| | t.Errorf("Log events = %q; want %q", log, wantLog) |
| | } |
| | } |
| |
|
| | func TestReverseProxy_Post(t *testing.T) { |
| | const backendResponse = "I am the backend" |
| | const backendStatus = 200 |
| | var requestBody = bytes.Repeat([]byte("a"), 1<<20) |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | slurp, err := io.ReadAll(r.Body) |
| | if err != nil { |
| | t.Errorf("Backend body read = %v", err) |
| | } |
| | if len(slurp) != len(requestBody) { |
| | t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) |
| | } |
| | if !bytes.Equal(slurp, requestBody) { |
| | t.Error("Backend read wrong request body.") |
| | } |
| | w.Write([]byte(backendResponse)) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) |
| | res, err := frontend.Client().Do(postReq) |
| | if err != nil { |
| | t.Fatalf("Do: %v", err) |
| | } |
| | defer res.Body.Close() |
| | if g, e := res.StatusCode, backendStatus; g != e { |
| | t.Errorf("got res.StatusCode %d; expected %d", g, e) |
| | } |
| | bodyBytes, _ := io.ReadAll(res.Body) |
| | if g, e := string(bodyBytes), backendResponse; g != e { |
| | t.Errorf("got body %q; expected %q", g, e) |
| | } |
| | } |
| |
|
// RoundTripperFunc lets an ordinary function act as an http.RoundTripper.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls f with r and returns its results unchanged.
func (f RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
| |
|
| | |
| | func TestReverseProxy_NilBody(t *testing.T) { |
| | backendURL, _ := url.Parse("http://fake.tld/") |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
| | if req.Body != nil { |
| | t.Error("Body != nil; want a nil Body") |
| | } |
| | return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") |
| | }) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | res, err := frontend.Client().Get(frontend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | defer res.Body.Close() |
| | if res.StatusCode != 502 { |
| | t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) |
| | } |
| | } |
| |
|
| | |
| | func TestReverseProxy_AllocatedHeader(t *testing.T) { |
| | proxyHandler := new(ReverseProxy) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | proxyHandler.Director = func(*http.Request) {} |
| | proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { |
| | if req.Header == nil { |
| | t.Error("Header == nil; want a non-nil Header") |
| | } |
| | return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") |
| | }) |
| |
|
| | proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ |
| | Method: "GET", |
| | URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, |
| | Proto: "HTTP/1.0", |
| | ProtoMajor: 1, |
| | }) |
| | } |
| |
|
| | |
| | |
| | func TestReverseProxyModifyResponse(t *testing.T) { |
| | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) |
| | })) |
| | defer backendServer.Close() |
| |
|
| | rpURL, _ := url.Parse(backendServer.URL) |
| | rproxy := NewSingleHostReverseProxy(rpURL) |
| | rproxy.ErrorLog = log.New(io.Discard, "", 0) |
| | rproxy.ModifyResponse = func(resp *http.Response) error { |
| | if resp.Header.Get("X-Hit-Mod") != "true" { |
| | return fmt.Errorf("tried to by-pass proxy") |
| | } |
| | return nil |
| | } |
| |
|
| | frontendProxy := httptest.NewServer(rproxy) |
| | defer frontendProxy.Close() |
| |
|
| | tests := []struct { |
| | url string |
| | wantCode int |
| | }{ |
| | {frontendProxy.URL + "/mod", http.StatusOK}, |
| | {frontendProxy.URL + "/schedule", http.StatusBadGateway}, |
| | } |
| |
|
| | for i, tt := range tests { |
| | resp, err := http.Get(tt.url) |
| | if err != nil { |
| | t.Fatalf("failed to reach proxy: %v", err) |
| | } |
| | if g, e := resp.StatusCode, tt.wantCode; g != e { |
| | t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) |
| | } |
| | resp.Body.Close() |
| | } |
| | } |
| |
|
// failingRoundTripper is an http.RoundTripper whose every round trip fails
// with the error "some error".
type failingRoundTripper struct{}

// RoundTrip ignores the request and returns a nil response with an error.
func (failingRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	err := errors.New("some error")
	return nil, err
}
| |
|
// staticResponseRoundTripper is an http.RoundTripper that answers every
// request with the same *http.Response and a nil error.
type staticResponseRoundTripper struct {
	res *http.Response // returned by every RoundTrip call
}

// RoundTrip ignores the request and returns rt.res.
func (rt staticResponseRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	res := rt.res
	return res, nil
}
| |
|
| | func TestReverseProxyErrorHandler(t *testing.T) { |
| | tests := []struct { |
| | name string |
| | wantCode int |
| | errorHandler func(http.ResponseWriter, *http.Request, error) |
| | transport http.RoundTripper |
| | modifyResponse func(*http.Response) error |
| | }{ |
| | { |
| | name: "default", |
| | wantCode: http.StatusBadGateway, |
| | }, |
| | { |
| | name: "errorhandler", |
| | wantCode: http.StatusTeapot, |
| | errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, |
| | }, |
| | { |
| | name: "modifyresponse_noerr", |
| | transport: staticResponseRoundTripper{ |
| | &http.Response{StatusCode: 345, Body: http.NoBody}, |
| | }, |
| | modifyResponse: func(res *http.Response) error { |
| | res.StatusCode++ |
| | return nil |
| | }, |
| | errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, |
| | wantCode: 346, |
| | }, |
| | { |
| | name: "modifyresponse_err", |
| | transport: staticResponseRoundTripper{ |
| | &http.Response{StatusCode: 345, Body: http.NoBody}, |
| | }, |
| | modifyResponse: func(res *http.Response) error { |
| | res.StatusCode++ |
| | return errors.New("some error to trigger errorHandler") |
| | }, |
| | errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, |
| | wantCode: http.StatusTeapot, |
| | }, |
| | } |
| |
|
| | for _, tt := range tests { |
| | t.Run(tt.name, func(t *testing.T) { |
| | target := &url.URL{ |
| | Scheme: "http", |
| | Host: "dummy.tld", |
| | Path: "/", |
| | } |
| | rproxy := NewSingleHostReverseProxy(target) |
| | rproxy.Transport = tt.transport |
| | rproxy.ModifyResponse = tt.modifyResponse |
| | if rproxy.Transport == nil { |
| | rproxy.Transport = failingRoundTripper{} |
| | } |
| | rproxy.ErrorLog = log.New(io.Discard, "", 0) |
| | if tt.errorHandler != nil { |
| | rproxy.ErrorHandler = tt.errorHandler |
| | } |
| | frontendProxy := httptest.NewServer(rproxy) |
| | defer frontendProxy.Close() |
| |
|
| | resp, err := http.Get(frontendProxy.URL + "/test") |
| | if err != nil { |
| | t.Fatalf("failed to reach proxy: %v", err) |
| | } |
| | if g, e := resp.StatusCode, tt.wantCode; g != e { |
| | t.Errorf("got res.StatusCode %d; expected %d", g, e) |
| | } |
| | resp.Body.Close() |
| | }) |
| | } |
| | } |
| |
|
| | |
| | func TestReverseProxy_CopyBuffer(t *testing.T) { |
| | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | out := "this call was relayed by the reverse proxy" |
| | |
| | w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) |
| | fmt.Fprintln(w, out) |
| | })) |
| | defer backendServer.Close() |
| |
|
| | rpURL, err := url.Parse(backendServer.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | var proxyLog bytes.Buffer |
| | rproxy := NewSingleHostReverseProxy(rpURL) |
| | rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) |
| | donec := make(chan bool, 1) |
| | frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | defer func() { donec <- true }() |
| | rproxy.ServeHTTP(w, r) |
| | })) |
| | defer frontendProxy.Close() |
| |
|
| | if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { |
| | t.Fatalf("want non-nil error") |
| | } |
| | |
| | |
| | |
| | |
| | <-donec |
| |
|
| | expected := []string{ |
| | "EOF", |
| | "read", |
| | } |
| | for _, phrase := range expected { |
| | if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { |
| | t.Errorf("expected log to contain phrase %q", phrase) |
| | } |
| | } |
| | } |
| |
|
// staticTransport is an http.RoundTripper that ignores the request it is
// given and always hands back the same preconfigured response with a nil
// error.
type staticTransport struct {
	res *http.Response // returned verbatim by every RoundTrip call
}

// RoundTrip implements http.RoundTripper by returning st.res and no error.
func (st *staticTransport) RoundTrip(*http.Request) (*http.Response, error) {
	canned := st.res
	return canned, nil
}
| |
|
| | func BenchmarkServeHTTP(b *testing.B) { |
| | res := &http.Response{ |
| | StatusCode: 200, |
| | Body: io.NopCloser(strings.NewReader("")), |
| | } |
| | proxy := &ReverseProxy{ |
| | Director: func(*http.Request) {}, |
| | Transport: &staticTransport{res}, |
| | } |
| |
|
| | w := httptest.NewRecorder() |
| | r := httptest.NewRequest("GET", "/", nil) |
| |
|
| | b.ReportAllocs() |
| | for i := 0; i < b.N; i++ { |
| | proxy.ServeHTTP(w, r) |
| | } |
| | } |
| |
|
| | func TestServeHTTPDeepCopy(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte("Hello Gopher!")) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | type result struct { |
| | before, after string |
| | } |
| |
|
| | resultChan := make(chan result, 1) |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | before := r.URL.String() |
| | proxyHandler.ServeHTTP(w, r) |
| | after := r.URL.String() |
| | resultChan <- result{before: before, after: after} |
| | })) |
| | defer frontend.Close() |
| |
|
| | want := result{before: "/", after: "/"} |
| |
|
| | res, err := frontend.Client().Get(frontend.URL) |
| | if err != nil { |
| | t.Fatalf("Do: %v", err) |
| | } |
| | res.Body.Close() |
| |
|
| | got := <-resultChan |
| | if got != want { |
| | t.Errorf("got = %+v; want = %+v", got, want) |
| | } |
| | } |
| |
|
| | |
| | |
| | func TestClonesRequestHeaders(t *testing.T) { |
| | log.SetOutput(io.Discard) |
| | defer log.SetOutput(os.Stderr) |
| | req, _ := http.NewRequest("GET", "http://foo.tld/", nil) |
| | req.RemoteAddr = "1.2.3.4:56789" |
| | rp := &ReverseProxy{ |
| | Director: func(req *http.Request) { |
| | req.Header.Set("From-Director", "1") |
| | }, |
| | Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { |
| | if v := req.Header.Get("From-Director"); v != "1" { |
| | t.Errorf("From-Directory value = %q; want 1", v) |
| | } |
| | return nil, io.EOF |
| | }), |
| | } |
| | rp.ServeHTTP(httptest.NewRecorder(), req) |
| |
|
| | for _, h := range []string{ |
| | "From-Director", |
| | "X-Forwarded-For", |
| | } { |
| | if req.Header.Get(h) != "" { |
| | t.Errorf("%v header mutation modified caller's request", h) |
| | } |
| | } |
| | } |
| |
|
// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc does for handlers.
type roundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling f with req and
// returning whatever it returns.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := f(req)
	return res, err
}
| |
|
| | func TestModifyResponseClosesBody(t *testing.T) { |
| | req, _ := http.NewRequest("GET", "http://foo.tld/", nil) |
| | req.RemoteAddr = "1.2.3.4:56789" |
| | closeCheck := new(checkCloser) |
| | logBuf := new(strings.Builder) |
| | outErr := errors.New("ModifyResponse error") |
| | rp := &ReverseProxy{ |
| | Director: func(req *http.Request) {}, |
| | Transport: &staticTransport{&http.Response{ |
| | StatusCode: 200, |
| | Body: closeCheck, |
| | }}, |
| | ErrorLog: log.New(logBuf, "", 0), |
| | ModifyResponse: func(*http.Response) error { |
| | return outErr |
| | }, |
| | } |
| | rec := httptest.NewRecorder() |
| | rp.ServeHTTP(rec, req) |
| | res := rec.Result() |
| | if g, e := res.StatusCode, http.StatusBadGateway; g != e { |
| | t.Errorf("got res.StatusCode %d; expected %d", g, e) |
| | } |
| | if !closeCheck.closed { |
| | t.Errorf("body should have been closed") |
| | } |
| | if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { |
| | t.Errorf("ErrorLog %q does not contain %q", g, e) |
| | } |
| | } |
| |
|
// checkCloser is an io.ReadCloser test double that records whether Close
// was called. Its reads never fail and never reach EOF.
type checkCloser struct {
	closed bool // set to true by Close
}

// Close marks cc as closed and reports success.
func (cc *checkCloser) Close() error {
	cc.closed = true
	return nil
}

// Read reports that all of b was filled without touching its contents,
// so the stream behaves as if it were endless.
func (cc *checkCloser) Read(b []byte) (n int, err error) {
	n = len(b)
	return
}
| |
|
| | |
| | func TestReverseProxy_PanicBodyError(t *testing.T) { |
| | log.SetOutput(io.Discard) |
| | defer log.SetOutput(os.Stderr) |
| | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | out := "this call was relayed by the reverse proxy" |
| | |
| | w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) |
| | fmt.Fprintln(w, out) |
| | })) |
| | defer backendServer.Close() |
| |
|
| | rpURL, err := url.Parse(backendServer.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | rproxy := NewSingleHostReverseProxy(rpURL) |
| |
|
| | |
| | |
| | defer func() { |
| | err := recover() |
| | if err == nil { |
| | t.Fatal("handler should have panicked") |
| | } |
| | if err != http.ErrAbortHandler { |
| | t.Fatal("expected ErrAbortHandler, got", err) |
| | } |
| | }() |
| | req, _ := http.NewRequest("GET", "http://foo.tld/", nil) |
| | rproxy.ServeHTTP(httptest.NewRecorder(), req) |
| | } |
| |
|
| | |
| | func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | out := "this call was relayed by the reverse proxy" |
| | |
| | w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) |
| | fmt.Fprintln(w, out) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | var wg sync.WaitGroup |
| | for i := 0; i < 2; i++ { |
| | wg.Add(1) |
| | go func() { |
| | defer wg.Done() |
| | for j := 0; j < 10; j++ { |
| | const reqLen = 6 * 1024 * 1024 |
| | req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) |
| | req.ContentLength = reqLen |
| | resp, _ := frontendClient.Transport.RoundTrip(req) |
| | if resp != nil { |
| | io.Copy(io.Discard, resp.Body) |
| | resp.Body.Close() |
| | } |
| | } |
| | }() |
| | } |
| | wg.Wait() |
| | } |
| |
|
| | func TestSelectFlushInterval(t *testing.T) { |
| | tests := []struct { |
| | name string |
| | p *ReverseProxy |
| | res *http.Response |
| | want time.Duration |
| | }{ |
| | { |
| | name: "default", |
| | res: &http.Response{}, |
| | p: &ReverseProxy{FlushInterval: 123}, |
| | want: 123, |
| | }, |
| | { |
| | name: "server-sent events overrides non-zero", |
| | res: &http.Response{ |
| | Header: http.Header{ |
| | "Content-Type": {"text/event-stream"}, |
| | }, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 123}, |
| | want: -1, |
| | }, |
| | { |
| | name: "server-sent events overrides zero", |
| | res: &http.Response{ |
| | Header: http.Header{ |
| | "Content-Type": {"text/event-stream"}, |
| | }, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 0}, |
| | want: -1, |
| | }, |
| | { |
| | name: "server-sent events with media-type parameters overrides non-zero", |
| | res: &http.Response{ |
| | Header: http.Header{ |
| | "Content-Type": {"text/event-stream;charset=utf-8"}, |
| | }, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 123}, |
| | want: -1, |
| | }, |
| | { |
| | name: "server-sent events with media-type parameters overrides zero", |
| | res: &http.Response{ |
| | Header: http.Header{ |
| | "Content-Type": {"text/event-stream;charset=utf-8"}, |
| | }, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 0}, |
| | want: -1, |
| | }, |
| | { |
| | name: "Content-Length: -1, overrides non-zero", |
| | res: &http.Response{ |
| | ContentLength: -1, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 123}, |
| | want: -1, |
| | }, |
| | { |
| | name: "Content-Length: -1, overrides zero", |
| | res: &http.Response{ |
| | ContentLength: -1, |
| | }, |
| | p: &ReverseProxy{FlushInterval: 0}, |
| | want: -1, |
| | }, |
| | } |
| | for _, tt := range tests { |
| | t.Run(tt.name, func(t *testing.T) { |
| | got := tt.p.flushInterval(tt.res) |
| | if got != tt.want { |
| | t.Errorf("flushLatency = %v; want %v", got, tt.want) |
| | } |
| | }) |
| | } |
| | } |
| |
|
| | func TestReverseProxyWebSocket(t *testing.T) { |
| | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if upgradeType(r.Header) != "websocket" { |
| | t.Error("unexpected backend request") |
| | http.Error(w, "unexpected request", 400) |
| | return |
| | } |
| | c, _, err := w.(http.Hijacker).Hijack() |
| | if err != nil { |
| | t.Error(err) |
| | return |
| | } |
| | defer c.Close() |
| | io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") |
| | bs := bufio.NewScanner(c) |
| | if !bs.Scan() { |
| | t.Errorf("backend failed to read line from client: %v", bs.Err()) |
| | return |
| | } |
| | fmt.Fprintf(c, "backend got %q\n", bs.Text()) |
| | })) |
| | defer backendServer.Close() |
| |
|
| | backURL, _ := url.Parse(backendServer.URL) |
| | rproxy := NewSingleHostReverseProxy(backURL) |
| | rproxy.ErrorLog = log.New(io.Discard, "", 0) |
| | rproxy.ModifyResponse = func(res *http.Response) error { |
| | res.Header.Add("X-Modified", "true") |
| | return nil |
| | } |
| |
|
| | handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { |
| | rw.Header().Set("X-Header", "X-Value") |
| | rproxy.ServeHTTP(rw, req) |
| | if got, want := rw.Header().Get("X-Modified"), "true"; got != want { |
| | t.Errorf("response writer X-Modified header = %q; want %q", got, want) |
| | } |
| | }) |
| |
|
| | frontendProxy := httptest.NewServer(handler) |
| | defer frontendProxy.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontendProxy.URL, nil) |
| | req.Header.Set("Connection", "Upgrade") |
| | req.Header.Set("Upgrade", "websocket") |
| |
|
| | c := frontendProxy.Client() |
| | res, err := c.Do(req) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | if res.StatusCode != 101 { |
| | t.Fatalf("status = %v; want 101", res.Status) |
| | } |
| |
|
| | got := res.Header.Get("X-Header") |
| | want := "X-Value" |
| | if got != want { |
| | t.Errorf("Header(XHeader) = %q; want %q", got, want) |
| | } |
| |
|
| | if !ascii.EqualFold(upgradeType(res.Header), "websocket") { |
| | t.Fatalf("not websocket upgrade; got %#v", res.Header) |
| | } |
| | rwc, ok := res.Body.(io.ReadWriteCloser) |
| | if !ok { |
| | t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) |
| | } |
| | defer rwc.Close() |
| |
|
| | if got, want := res.Header.Get("X-Modified"), "true"; got != want { |
| | t.Errorf("response X-Modified header = %q; want %q", got, want) |
| | } |
| |
|
| | io.WriteString(rwc, "Hello\n") |
| | bs := bufio.NewScanner(rwc) |
| | if !bs.Scan() { |
| | t.Fatalf("Scan: %v", bs.Err()) |
| | } |
| | got = bs.Text() |
| | want = `backend got "Hello"` |
| | if got != want { |
| | t.Errorf("got %#q, want %#q", got, want) |
| | } |
| | } |
| |
|
| | func TestReverseProxyWebSocketCancellation(t *testing.T) { |
| | n := 5 |
| | triggerCancelCh := make(chan bool, n) |
| | nthResponse := func(i int) string { |
| | return fmt.Sprintf("backend response #%d\n", i) |
| | } |
| | terminalMsg := "final message" |
| |
|
| | cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if g, ws := upgradeType(r.Header), "websocket"; g != ws { |
| | t.Errorf("Unexpected upgrade type %q, want %q", g, ws) |
| | http.Error(w, "Unexpected request", 400) |
| | return |
| | } |
| | conn, bufrw, err := w.(http.Hijacker).Hijack() |
| | if err != nil { |
| | t.Error(err) |
| | return |
| | } |
| | defer conn.Close() |
| |
|
| | upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" |
| | if _, err := io.WriteString(conn, upgradeMsg); err != nil { |
| | t.Error(err) |
| | return |
| | } |
| | if _, _, err := bufrw.ReadLine(); err != nil { |
| | t.Errorf("Failed to read line from client: %v", err) |
| | return |
| | } |
| |
|
| | for i := 0; i < n; i++ { |
| | if _, err := bufrw.WriteString(nthResponse(i)); err != nil { |
| | select { |
| | case <-triggerCancelCh: |
| | default: |
| | t.Errorf("Writing response #%d failed: %v", i, err) |
| | } |
| | return |
| | } |
| | bufrw.Flush() |
| | time.Sleep(time.Second) |
| | } |
| | if _, err := bufrw.WriteString(terminalMsg); err != nil { |
| | select { |
| | case <-triggerCancelCh: |
| | default: |
| | t.Errorf("Failed to write terminal message: %v", err) |
| | } |
| | } |
| | bufrw.Flush() |
| | })) |
| | defer cst.Close() |
| |
|
| | backendURL, _ := url.Parse(cst.URL) |
| | rproxy := NewSingleHostReverseProxy(backendURL) |
| | rproxy.ErrorLog = log.New(io.Discard, "", 0) |
| | rproxy.ModifyResponse = func(res *http.Response) error { |
| | res.Header.Add("X-Modified", "true") |
| | return nil |
| | } |
| |
|
| | handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { |
| | rw.Header().Set("X-Header", "X-Value") |
| | ctx, cancel := context.WithCancel(req.Context()) |
| | go func() { |
| | <-triggerCancelCh |
| | cancel() |
| | }() |
| | rproxy.ServeHTTP(rw, req.WithContext(ctx)) |
| | }) |
| |
|
| | frontendProxy := httptest.NewServer(handler) |
| | defer frontendProxy.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontendProxy.URL, nil) |
| | req.Header.Set("Connection", "Upgrade") |
| | req.Header.Set("Upgrade", "websocket") |
| |
|
| | res, err := frontendProxy.Client().Do(req) |
| | if err != nil { |
| | t.Fatalf("Dialing to frontend proxy: %v", err) |
| | } |
| | defer res.Body.Close() |
| | if g, w := res.StatusCode, 101; g != w { |
| | t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) |
| | } |
| |
|
| | if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { |
| | t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| | } |
| |
|
| | if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { |
| | t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| | } |
| |
|
| | rwc, ok := res.Body.(io.ReadWriteCloser) |
| | if !ok { |
| | t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) |
| | } |
| |
|
| | if got, want := res.Header.Get("X-Modified"), "true"; got != want { |
| | t.Errorf("response X-Modified header = %q; want %q", got, want) |
| | } |
| |
|
| | if _, err := io.WriteString(rwc, "Hello\n"); err != nil { |
| | t.Fatalf("Failed to write first message: %v", err) |
| | } |
| |
|
| | |
| |
|
| | br := bufio.NewReader(rwc) |
| | for { |
| | line, err := br.ReadString('\n') |
| | switch { |
| | case line == terminalMsg: |
| | t.Fatalf("The websocket request was not canceled, unfortunately!") |
| |
|
| | case err == io.EOF: |
| | return |
| |
|
| | case err != nil: |
| | t.Fatalf("Unexpected error: %v", err) |
| |
|
| | case line == nthResponse(0): |
| | |
| | close(triggerCancelCh) |
| | } |
| | } |
| | } |
| |
|
| | func TestReverseProxyWebSocketHalfTCP(t *testing.T) { |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | switch runtime.GOOS { |
| | case "plan9": |
| | t.Skipf("not supported on %s", runtime.GOOS) |
| | } |
| |
|
| | mustRead := func(t *testing.T, conn *net.TCPConn, msg string) { |
| | b := make([]byte, len(msg)) |
| | if _, err := conn.Read(b); err != nil { |
| | t.Errorf("failed to read: %v", err) |
| | } |
| |
|
| | if got, want := string(b), msg; got != want { |
| | t.Errorf("got %#q, want %#q", got, want) |
| | } |
| | } |
| |
|
| | mustReadError := func(t *testing.T, conn *net.TCPConn, e error) { |
| | b := make([]byte, 1) |
| | if _, err := conn.Read(b); !errors.Is(err, e) { |
| | t.Errorf("failed to read error: %v", err) |
| | } |
| | } |
| |
|
| | mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) { |
| | if _, err := conn.Write([]byte(msg)); err != nil { |
| | t.Errorf("failed to write: %v", err) |
| | } |
| | } |
| |
|
| | mustCloseRead := func(t *testing.T, conn *net.TCPConn) { |
| | if err := conn.CloseRead(); err != nil { |
| | t.Errorf("failed to CloseRead: %v", err) |
| | } |
| | } |
| |
|
| | mustCloseWrite := func(t *testing.T, conn *net.TCPConn) { |
| | if err := conn.CloseWrite(); err != nil { |
| | t.Errorf("failed to CloseWrite: %v", err) |
| | } |
| | } |
| |
|
| | tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){ |
| | "server close read": func(t *testing.T, cli, srv *net.TCPConn) { |
| | mustCloseRead(t, srv) |
| | mustWrite(t, srv, "server sends") |
| | mustRead(t, cli, "server sends") |
| | }, |
| | "server close write": func(t *testing.T, cli, srv *net.TCPConn) { |
| | mustCloseWrite(t, srv) |
| | mustWrite(t, cli, "client sends") |
| | mustRead(t, srv, "client sends") |
| | mustReadError(t, cli, io.EOF) |
| | }, |
| | "client close read": func(t *testing.T, cli, srv *net.TCPConn) { |
| | mustCloseRead(t, cli) |
| | mustWrite(t, cli, "client sends") |
| | mustRead(t, srv, "client sends") |
| | }, |
| | "client close write": func(t *testing.T, cli, srv *net.TCPConn) { |
| | mustCloseWrite(t, cli) |
| | mustWrite(t, srv, "server sends") |
| | mustRead(t, cli, "server sends") |
| | mustReadError(t, srv, io.EOF) |
| | }, |
| | } |
| |
|
| | for name, test := range tests { |
| | t.Run(name, func(t *testing.T) { |
| | var srv *net.TCPConn |
| |
|
| | backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | if g, ws := upgradeType(r.Header), "websocket"; g != ws { |
| | t.Fatalf("Unexpected upgrade type %q, want %q", g, ws) |
| | } |
| |
|
| | conn, _, err := w.(http.Hijacker).Hijack() |
| | if err != nil { |
| | conn.Close() |
| | t.Fatalf("hijack failed: %v", err) |
| | } |
| |
|
| | var ok bool |
| | if srv, ok = conn.(*net.TCPConn); !ok { |
| | conn.Close() |
| | t.Fatal("conn is not a TCPConn") |
| | } |
| |
|
| | upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" |
| | if _, err := io.WriteString(srv, upgradeMsg); err != nil { |
| | srv.Close() |
| | t.Fatalf("backend upgrade failed: %v", err) |
| | } |
| | })) |
| | defer backendServer.Close() |
| |
|
| | backendURL, _ := url.Parse(backendServer.URL) |
| | rproxy := NewSingleHostReverseProxy(backendURL) |
| | rproxy.ErrorLog = log.New(io.Discard, "", 0) |
| | frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { |
| | rproxy.ServeHTTP(rw, req) |
| | })) |
| | defer frontendProxy.Close() |
| |
|
| | frontendURL, _ := url.Parse(frontendProxy.URL) |
| | addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host) |
| | if err != nil { |
| | t.Fatalf("failed to resolve TCP address: %v", err) |
| | } |
| | cli, err := net.DialTCP("tcp", nil, addr) |
| | if err != nil { |
| | t.Fatalf("failed to dial TCP address: %v", err) |
| | } |
| | defer cli.Close() |
| |
|
| | req, _ := http.NewRequest("GET", frontendProxy.URL, nil) |
| | req.Header.Set("Connection", "Upgrade") |
| | req.Header.Set("Upgrade", "websocket") |
| | if err := req.Write(cli); err != nil { |
| | t.Fatalf("failed to write request: %v", err) |
| | } |
| |
|
| | br := bufio.NewReader(cli) |
| | resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) |
| | if err != nil { |
| | t.Fatalf("failed to read response: %v", err) |
| | } |
| | if resp.StatusCode != 101 { |
| | t.Fatalf("status code not 101: %v", resp.StatusCode) |
| | } |
| | if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || |
| | strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { |
| | t.Fatalf("frontend upgrade failed") |
| | } |
| | defer srv.Close() |
| |
|
| | test(t, cli, srv) |
| | }) |
| | } |
| | } |
| |
|
| | func TestReverseProxyUpgradeNoCloseWrite(t *testing.T) { |
| | |
| | |
| | |
| | backendDone := make(chan struct{}) |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Set("Connection", "upgrade") |
| | w.Header().Set("Upgrade", "u") |
| | w.WriteHeader(101) |
| | conn, _, err := http.NewResponseController(w).Hijack() |
| | if err != nil { |
| | t.Errorf("Hijack: %v", err) |
| | } |
| | io.Copy(io.Discard, conn) |
| | close(backendDone) |
| | })) |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| |
|
| | |
| | |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ModifyResponse = func(resp *http.Response) error { |
| | type readWriteCloserOnly struct { |
| | io.ReadWriteCloser |
| | } |
| | resp.Body = readWriteCloserOnly{resp.Body.(io.ReadWriteCloser)} |
| | return nil |
| | } |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | |
| | req, _ := http.NewRequest("GET", frontend.URL, nil) |
| | req.Header.Set("Connection", "upgrade") |
| | req.Header.Set("Upgrade", "u") |
| | resp, err := frontend.Client().Do(req) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | resp.Body.Close() |
| |
|
| | |
| | <-backendDone |
| | } |
| |
|
| | func TestUnannouncedTrailer(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.WriteHeader(http.StatusOK) |
| | w.(http.Flusher).Flush() |
| | w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | res, err := frontendClient.Get(frontend.URL) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| |
|
| | io.ReadAll(res.Body) |
| | res.Body.Close() |
| | if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { |
| | t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) |
| | } |
| |
|
| | } |
| |
|
| | func TestSetURL(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte(r.Host)) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.SetURL(backendURL) |
| | }, |
| | } |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | res, err := frontendClient.Get(frontend.URL) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| |
|
| | body, err := io.ReadAll(res.Body) |
| | if err != nil { |
| | t.Fatalf("Reading body: %v", err) |
| | } |
| |
|
| | if got, want := string(body), backendURL.Host; got != want { |
| | t.Errorf("backend got Host %q, want %q", got, want) |
| | } |
| | } |
| |
|
| | func TestSingleJoinSlash(t *testing.T) { |
| | tests := []struct { |
| | slasha string |
| | slashb string |
| | expected string |
| | }{ |
| | {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, |
| | {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, |
| | {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, |
| | {"https://www.google.com", "", "https://www.google.com/"}, |
| | {"", "favicon.ico", "/favicon.ico"}, |
| | } |
| | for _, tt := range tests { |
| | if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { |
| | t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", |
| | tt.slasha, |
| | tt.slashb, |
| | tt.expected, |
| | got) |
| | } |
| | } |
| | } |
| |
|
| | func TestJoinURLPath(t *testing.T) { |
| | tests := []struct { |
| | a *url.URL |
| | b *url.URL |
| | wantPath string |
| | wantRaw string |
| | }{ |
| | {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, |
| | {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, |
| | {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, |
| | {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, |
| | {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, |
| | {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, |
| | } |
| |
|
| | for _, tt := range tests { |
| | p, rp := joinURLPath(tt.a, tt.b) |
| | if p != tt.wantPath || rp != tt.wantRaw { |
| | t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", |
| | tt.a.Path, tt.a.RawPath, |
| | tt.b.Path, tt.b.RawPath, |
| | tt.wantPath, tt.wantRaw, |
| | p, rp) |
| | } |
| | } |
| | } |
| |
|
| | func TestReverseProxyRewriteReplacesOut(t *testing.T) { |
| | const content = "response_content" |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte(content)) |
| | })) |
| | defer backend.Close() |
| | proxyHandler := &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.Out, _ = http.NewRequest("GET", backend.URL, nil) |
| | }, |
| | } |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | res, err := frontend.Client().Get(frontend.URL) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | body, _ := io.ReadAll(res.Body) |
| | if got, want := string(body), content; got != want { |
| | t.Errorf("got response %q, want %q", got, want) |
| | } |
| | } |
| |
|
| | func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) { |
| | |
| | |
| | |
| | |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | for i := 0; i < 5; i++ { |
| | w.WriteHeader(103) |
| | } |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| |
|
| | rw := &testResponseWriter{} |
| | func() { |
| | |
| | |
| | ctx, cancel := context.WithCancel(context.Background()) |
| | defer cancel() |
| | ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ |
| | Got1xxResponse: func(code int, header textproto.MIMEHeader) error { |
| | cancel() |
| | return nil |
| | }, |
| | }) |
| |
|
| | req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil) |
| | proxyHandler.ServeHTTP(rw, req) |
| | }() |
| | |
| | |
| | |
| | for _ = range rw.Header() { |
| | } |
| | } |
| |
|
| | func Test1xxResponses(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | h := w.Header() |
| | h.Add("Link", "</style.css>; rel=preload; as=style") |
| | h.Add("Link", "</script.js>; rel=preload; as=script") |
| | w.WriteHeader(http.StatusEarlyHints) |
| |
|
| | h.Add("Link", "</foo.js>; rel=preload; as=script") |
| | w.WriteHeader(http.StatusProcessing) |
| |
|
| | w.Write([]byte("Hello")) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := NewSingleHostReverseProxy(backendURL) |
| | proxyHandler.ErrorLog = log.New(io.Discard, "", 0) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| | frontendClient := frontend.Client() |
| |
|
| | checkLinkHeaders := func(t *testing.T, expected, got []string) { |
| | t.Helper() |
| |
|
| | if len(expected) != len(got) { |
| | t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) |
| | } |
| |
|
| | for i := range expected { |
| | if i >= len(got) { |
| | t.Errorf("Expected %q link header; got nothing", expected[i]) |
| |
|
| | continue |
| | } |
| |
|
| | if expected[i] != got[i] { |
| | t.Errorf("Expected %q link header; got %q", expected[i], got[i]) |
| | } |
| | } |
| | } |
| |
|
| | var respCounter uint8 |
| | trace := &httptrace.ClientTrace{ |
| | Got1xxResponse: func(code int, header textproto.MIMEHeader) error { |
| | switch code { |
| | case http.StatusEarlyHints: |
| | checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) |
| | case http.StatusProcessing: |
| | checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) |
| | default: |
| | t.Error("Unexpected 1xx response") |
| | } |
| |
|
| | respCounter++ |
| |
|
| | return nil |
| | }, |
| | } |
| | req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) |
| |
|
| | res, err := frontendClient.Do(req) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| |
|
| | defer res.Body.Close() |
| |
|
| | if respCounter != 2 { |
| | t.Errorf("Expected 2 1xx responses; got %d", respCounter) |
| | } |
| | checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) |
| |
|
| | body, _ := io.ReadAll(res.Body) |
| | if string(body) != "Hello" { |
| | t.Errorf("Read body %q; want Hello", body) |
| | } |
| | } |
| |
|
// Values for the wantCleanQuery argument of
// testReverseProxyQueryParameterSmuggling. Presumably they select whether
// the backend is expected to receive the cleaned query or the client's raw
// query (see the cleanQuery/rawQuery test table there).
const (
	testWantsCleanQuery = true
	testWantsRawQuery   = !testWantsCleanQuery
)
| |
|
| | func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { |
| | testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { |
| | proxyHandler := NewSingleHostReverseProxy(u) |
| | oldDirector := proxyHandler.Director |
| | proxyHandler.Director = func(r *http.Request) { |
| | oldDirector(r) |
| | } |
| | return proxyHandler |
| | }) |
| | } |
| |
|
| | func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { |
| | testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { |
| | proxyHandler := NewSingleHostReverseProxy(u) |
| | oldDirector := proxyHandler.Director |
| | proxyHandler.Director = func(r *http.Request) { |
| | |
| | |
| | r.FormValue("a") |
| | oldDirector(r) |
| | } |
| | return proxyHandler |
| | }) |
| | } |
| |
|
| | func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { |
| | testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { |
| | return &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.SetURL(u) |
| | }, |
| | } |
| | }) |
| | } |
| |
|
| | func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { |
| | testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { |
| | return &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.SetURL(u) |
| | r.Out.URL.RawQuery = r.In.URL.RawQuery |
| | }, |
| | } |
| | }) |
| | } |
| |
|
| | func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { |
| | const content = "response_content" |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Write([]byte(r.URL.RawQuery)) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := newProxy(backendURL) |
| | frontend := httptest.NewServer(proxyHandler) |
| | defer frontend.Close() |
| |
|
| | |
| | backend.Config.ErrorLog = log.New(io.Discard, "", 0) |
| | frontend.Config.ErrorLog = log.New(io.Discard, "", 0) |
| |
|
| | for _, test := range []struct { |
| | rawQuery string |
| | cleanQuery string |
| | }{{ |
| | rawQuery: "a=1&a=2;b=3", |
| | cleanQuery: "a=1", |
| | }, { |
| | rawQuery: "a=1&a=%zz&b=3", |
| | cleanQuery: "a=1&b=3", |
| | }} { |
| | res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery) |
| | if err != nil { |
| | t.Fatalf("Get: %v", err) |
| | } |
| | defer res.Body.Close() |
| | body, _ := io.ReadAll(res.Body) |
| | wantQuery := test.rawQuery |
| | if wantCleanQuery { |
| | wantQuery = test.cleanQuery |
| | } |
| | if got, want := string(body), wantQuery; got != want { |
| | t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) |
| | } |
| | } |
| | } |
| |
|
| | |
| | |
| | func TestReverseProxyHijackCopyError(t *testing.T) { |
| | backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Set("Upgrade", "someproto") |
| | w.WriteHeader(http.StatusSwitchingProtocols) |
| | })) |
| | defer backend.Close() |
| | backendURL, err := url.Parse(backend.URL) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | proxyHandler := &ReverseProxy{ |
| | Rewrite: func(r *ProxyRequest) { |
| | r.SetURL(backendURL) |
| | }, |
| | ModifyResponse: func(resp *http.Response) error { |
| | resp.Body = &testReadWriteCloser{ |
| | read: func([]byte) (int, error) { |
| | return 0, errors.New("read error") |
| | }, |
| | } |
| | return nil |
| | }, |
| | } |
| |
|
| | hijacked := false |
| | rw := &testResponseWriter{ |
| | writeHeader: func(statusCode int) { |
| | if hijacked { |
| | t.Errorf("WriteHeader(%v) called after Hijack", statusCode) |
| | } |
| | }, |
| | hijack: func() (net.Conn, *bufio.ReadWriter, error) { |
| | hijacked = true |
| | cli, srv := net.Pipe() |
| | go io.Copy(io.Discard, cli) |
| | return srv, bufio.NewReadWriter(bufio.NewReader(srv), bufio.NewWriter(srv)), nil |
| | }, |
| | } |
| | req, _ := http.NewRequest("GET", "http://example.tld/", nil) |
| | req.Header.Set("Upgrade", "someproto") |
| | proxyHandler.ServeHTTP(rw, req) |
| | } |
| |
|
| | |
| | func TestReverseProxyInvalidUpstream100ContinueDoNotHang(t *testing.T) { |
| | proxy := ReverseProxy{ |
| | Transport: &http.Transport{DisableKeepAlives: true, ExpectContinueTimeout: time.Second * 60}, |
| | Director: func(request *http.Request) { |
| | request.URL.Scheme = "http" |
| | request.URL.Host = "doesnotexist:12345" |
| | }, |
| | } |
| | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| | proxy.ServeHTTP(w, r) |
| | }) |
| | upstreamServer := httptest.NewServer(handler) |
| | defer upstreamServer.Close() |
| |
|
| | conn, err := net.Dial("tcp", upstreamServer.Listener.Addr().String()) |
| | if err != nil { |
| | t.Fatal(err) |
| | } |
| | defer conn.Close() |
| |
|
| | requestBody := `{"test": "data"}` |
| | initialRequest := fmt.Sprintf("POST %s/test-expect HTTP/1.1\r\n"+ |
| | "Host: %s\r\n"+ |
| | "Content-Type: application/json\r\n"+ |
| | "Content-Length: %d\r\n"+ |
| | "Expect: 100-continue\r\n"+ |
| | "\r\n", upstreamServer.URL, upstreamServer.Listener.Addr().String(), len(requestBody)) |
| |
|
| | if _, err := conn.Write([]byte(initialRequest)); err != nil { |
| | log.Fatal(err) |
| | } |
| | buff := make([]byte, 1024) |
| | if _, err := conn.Read(buff); err != nil { |
| | log.Fatal(err) |
| | } |
| | } |
| |
|
// testResponseWriter is an http.ResponseWriter that also implements
// http.Hijacker. Each method invokes its optional hook when one is set and
// otherwise falls back to a simple default.
type testResponseWriter struct {
	h           http.Header
	writeHeader func(int)
	write       func([]byte) (int, error)
	hijack      func() (net.Conn, *bufio.ReadWriter, error)
}

// Header returns the writer's header map, allocating it on first use so that
// later calls return the same map.
func (w *testResponseWriter) Header() http.Header {
	if w.h != nil {
		return w.h
	}
	w.h = http.Header{}
	return w.h
}

// WriteHeader forwards statusCode to the writeHeader hook, if any. Otherwise
// it does nothing.
func (w *testResponseWriter) WriteHeader(statusCode int) {
	if fn := w.writeHeader; fn != nil {
		fn(statusCode)
	}
}

// Write delegates to the write hook. Without one, it reports all of p as
// written.
func (w *testResponseWriter) Write(p []byte) (int, error) {
	if w.write == nil {
		return len(p), nil
	}
	return w.write(p)
}

// Hijack delegates to the hijack hook. Without one, it returns
// errors.ErrUnsupported.
func (w *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if w.hijack == nil {
		return nil, nil, errors.ErrUnsupported
	}
	return w.hijack()
}
| |
|
// testReadWriteCloser is an io.ReadWriteCloser whose methods delegate to
// optional hooks. With no hooks set, Read reports io.EOF, Write accepts all of
// p, and Close succeeds.
type testReadWriteCloser struct {
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
	close func() error
}

// Read calls the read hook, or returns 0, io.EOF when none is set.
func (c *testReadWriteCloser) Read(p []byte) (int, error) {
	if c.read == nil {
		return 0, io.EOF
	}
	return c.read(p)
}

// Write calls the write hook, or reports all of p as written when none is set.
func (c *testReadWriteCloser) Write(p []byte) (int, error) {
	if fn := c.write; fn != nil {
		return fn(p)
	}
	return len(p), nil
}

// Close calls the close hook, or returns nil when none is set.
func (c *testReadWriteCloser) Close() error {
	if c.close == nil {
		return nil
	}
	return c.close()
}
| |
|