|
|
package amp |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"compress/gzip" |
|
|
"context" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"net/http/httptest" |
|
|
"strings" |
|
|
"testing" |
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
) |
|
|
|
|
|
|
|
|
// gzipBytes returns b compressed as a single gzip member, for use as a test
// fixture. Errors from the gzip writer are not expected when writing into a
// bytes.Buffer, but if one ever occurs the helper panics rather than returning
// a silently broken stream that a test could mistake for a deliberately
// malformed body.
func gzipBytes(b []byte) []byte {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	if _, err := zw.Write(b); err != nil {
		panic(fmt.Sprintf("gzipBytes: write: %v", err))
	}
	// Close flushes the deflate stream and appends the gzip trailer
	// (CRC-32 and input size); without it the output would be truncated.
	if err := zw.Close(); err != nil {
		panic(fmt.Sprintf("gzipBytes: close: %v", err))
	}
	return buf.Bytes()
}
|
|
|
|
|
|
|
|
// mkResp builds an in-memory *http.Response carrying the given status code,
// headers and body, ready to be passed straight to ReverseProxy.ModifyResponse.
// A nil hdr is replaced with an empty, writable http.Header. ContentLength is
// set to len(body); the body is wrapped in a no-op closer.
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
	header := hdr
	if header == nil {
		header = make(http.Header)
	}

	resp := &http.Response{StatusCode: status, Header: header}
	resp.ContentLength = int64(len(body))
	resp.Body = io.NopCloser(bytes.NewReader(body))
	return resp
}
|
|
|
|
|
func TestCreateReverseProxy_ValidURL(t *testing.T) { |
|
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key")) |
|
|
if err != nil { |
|
|
t.Fatalf("expected no error, got: %v", err) |
|
|
} |
|
|
if proxy == nil { |
|
|
t.Fatal("expected proxy to be created") |
|
|
} |
|
|
} |
|
|
|
|
|
func TestCreateReverseProxy_InvalidURL(t *testing.T) { |
|
|
_, err := createReverseProxy("://invalid", NewStaticSecretSource("key")) |
|
|
if err == nil { |
|
|
t.Fatal("expected error for invalid URL") |
|
|
} |
|
|
} |
|
|
|
|
|
func TestModifyResponse_GzipScenarios(t *testing.T) { |
|
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
goodJSON := []byte(`{"ok":true}`) |
|
|
good := gzipBytes(goodJSON) |
|
|
truncated := good[:10] |
|
|
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...) |
|
|
|
|
|
cases := []struct { |
|
|
name string |
|
|
header http.Header |
|
|
body []byte |
|
|
status int |
|
|
wantBody []byte |
|
|
wantCE string |
|
|
}{ |
|
|
{ |
|
|
name: "decompresses_valid_gzip_no_header", |
|
|
header: http.Header{}, |
|
|
body: good, |
|
|
status: 200, |
|
|
wantBody: goodJSON, |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "skips_when_ce_present", |
|
|
header: http.Header{"Content-Encoding": []string{"gzip"}}, |
|
|
body: good, |
|
|
status: 200, |
|
|
wantBody: good, |
|
|
wantCE: "gzip", |
|
|
}, |
|
|
{ |
|
|
name: "passes_truncated_unchanged", |
|
|
header: http.Header{}, |
|
|
body: truncated, |
|
|
status: 200, |
|
|
wantBody: truncated, |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "passes_corrupted_unchanged", |
|
|
header: http.Header{}, |
|
|
body: corrupted, |
|
|
status: 200, |
|
|
wantBody: corrupted, |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "non_gzip_unchanged", |
|
|
header: http.Header{}, |
|
|
body: []byte("plain"), |
|
|
status: 200, |
|
|
wantBody: []byte("plain"), |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "empty_body", |
|
|
header: http.Header{}, |
|
|
body: []byte{}, |
|
|
status: 200, |
|
|
wantBody: []byte{}, |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "single_byte_body", |
|
|
header: http.Header{}, |
|
|
body: []byte{0x1f}, |
|
|
status: 200, |
|
|
wantBody: []byte{0x1f}, |
|
|
wantCE: "", |
|
|
}, |
|
|
{ |
|
|
name: "skips_non_2xx_status", |
|
|
header: http.Header{}, |
|
|
body: good, |
|
|
status: 404, |
|
|
wantBody: good, |
|
|
wantCE: "", |
|
|
}, |
|
|
} |
|
|
|
|
|
for _, tc := range cases { |
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
resp := mkResp(tc.status, tc.header, tc.body) |
|
|
if err := proxy.ModifyResponse(resp); err != nil { |
|
|
t.Fatalf("ModifyResponse error: %v", err) |
|
|
} |
|
|
got, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
t.Fatalf("ReadAll error: %v", err) |
|
|
} |
|
|
if !bytes.Equal(got, tc.wantBody) { |
|
|
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got) |
|
|
} |
|
|
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE { |
|
|
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce) |
|
|
} |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) { |
|
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
goodJSON := []byte(`{"message":"test response"}`) |
|
|
gzipped := gzipBytes(goodJSON) |
|
|
|
|
|
|
|
|
|
|
|
resp := mkResp(200, http.Header{ |
|
|
"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, |
|
|
}, gzipped) |
|
|
|
|
|
if err := proxy.ModifyResponse(resp); err != nil { |
|
|
t.Fatalf("ModifyResponse error: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
got, _ := io.ReadAll(resp.Body) |
|
|
if !bytes.Equal(got, goodJSON) { |
|
|
t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON) |
|
|
} |
|
|
|
|
|
|
|
|
wantCL := fmt.Sprintf("%d", len(goodJSON)) |
|
|
gotCL := resp.Header.Get("Content-Length") |
|
|
if gotCL != wantCL { |
|
|
t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL) |
|
|
} |
|
|
|
|
|
|
|
|
if resp.ContentLength != int64(len(goodJSON)) { |
|
|
t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) { |
|
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
goodJSON := []byte(`{"ok":true}`) |
|
|
gzipped := gzipBytes(goodJSON) |
|
|
|
|
|
t.Run("sse_skips_decompression", func(t *testing.T) { |
|
|
resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped) |
|
|
if err := proxy.ModifyResponse(resp); err != nil { |
|
|
t.Fatalf("ModifyResponse error: %v", err) |
|
|
} |
|
|
|
|
|
got, _ := io.ReadAll(resp.Body) |
|
|
if !bytes.Equal(got, gzipped) { |
|
|
t.Fatal("SSE response should not be decompressed") |
|
|
} |
|
|
}) |
|
|
} |
|
|
|
|
|
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) { |
|
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
goodJSON := []byte(`{"ok":true}`) |
|
|
gzipped := gzipBytes(goodJSON) |
|
|
|
|
|
t.Run("chunked_json_decompresses", func(t *testing.T) { |
|
|
|
|
|
resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped) |
|
|
if err := proxy.ModifyResponse(resp); err != nil { |
|
|
t.Fatalf("ModifyResponse error: %v", err) |
|
|
} |
|
|
|
|
|
got, _ := io.ReadAll(resp.Body) |
|
|
if !bytes.Equal(got, goodJSON) { |
|
|
t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON) |
|
|
} |
|
|
}) |
|
|
} |
|
|
|
|
|
func TestReverseProxy_InjectsHeaders(t *testing.T) { |
|
|
gotHeaders := make(chan http.Header, 1) |
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
gotHeaders <- r.Header.Clone() |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`ok`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
proxy.ServeHTTP(w, r) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
res.Body.Close() |
|
|
|
|
|
hdr := <-gotHeaders |
|
|
if hdr.Get("X-Api-Key") != "secret" { |
|
|
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) |
|
|
} |
|
|
if hdr.Get("Authorization") != "Bearer secret" { |
|
|
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_EmptySecret(t *testing.T) { |
|
|
gotHeaders := make(chan http.Header, 1) |
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
gotHeaders <- r.Header.Clone() |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`ok`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
proxy.ServeHTTP(w, r) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
res.Body.Close() |
|
|
|
|
|
hdr := <-gotHeaders |
|
|
|
|
|
if hdr.Get("X-Api-Key") != "" { |
|
|
t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key")) |
|
|
} |
|
|
if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " { |
|
|
t.Fatalf("Authorization should not be set, got: %q", authVal) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) { |
|
|
type captured struct { |
|
|
headers http.Header |
|
|
query string |
|
|
} |
|
|
got := make(chan captured, 1) |
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery} |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`ok`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key") |
|
|
proxy.ServeHTTP(w, r.WithContext(ctx)) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
req.Header.Set("Authorization", "Bearer client-key") |
|
|
req.Header.Set("X-Api-Key", "client-key") |
|
|
req.Header.Set("X-Goog-Api-Key", "client-key") |
|
|
|
|
|
res, err := http.DefaultClient.Do(req) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
res.Body.Close() |
|
|
|
|
|
c := <-got |
|
|
|
|
|
|
|
|
if v := c.headers.Get("X-Goog-Api-Key"); v != "" { |
|
|
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v) |
|
|
} |
|
|
|
|
|
|
|
|
if v := c.headers.Get("Authorization"); v != "Bearer upstream" { |
|
|
t.Fatalf("Authorization should be upstream-injected, got: %q", v) |
|
|
} |
|
|
if v := c.headers.Get("X-Api-Key"); v != "upstream" { |
|
|
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") { |
|
|
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query) |
|
|
} |
|
|
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") { |
|
|
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) { |
|
|
gotHeaders := make(chan http.Header, 1) |
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
gotHeaders <- r.Header.Clone() |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`ok`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
defaultSource := NewStaticSecretSource("default") |
|
|
mapped := NewMappedSecretSource(defaultSource) |
|
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
|
|
{ |
|
|
UpstreamAPIKey: "u1", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
}) |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, mapped) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1") |
|
|
proxy.ServeHTTP(w, r.WithContext(ctx)) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
res.Body.Close() |
|
|
|
|
|
hdr := <-gotHeaders |
|
|
if hdr.Get("X-Api-Key") != "u1" { |
|
|
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key")) |
|
|
} |
|
|
if hdr.Get("Authorization") != "Bearer u1" { |
|
|
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization")) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) { |
|
|
gotHeaders := make(chan http.Header, 1) |
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
gotHeaders <- r.Header.Clone() |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`ok`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
defaultSource := NewStaticSecretSource("default") |
|
|
mapped := NewMappedSecretSource(defaultSource) |
|
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
|
|
{ |
|
|
UpstreamAPIKey: "u1", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
}) |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, mapped) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2") |
|
|
proxy.ServeHTTP(w, r.WithContext(ctx)) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
res.Body.Close() |
|
|
|
|
|
hdr := <-gotHeaders |
|
|
if hdr.Get("X-Api-Key") != "default" { |
|
|
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key")) |
|
|
} |
|
|
if hdr.Get("Authorization") != "Bearer default" { |
|
|
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization")) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_ErrorHandler(t *testing.T) { |
|
|
|
|
|
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource("")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
proxy.ServeHTTP(w, r) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/any") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
body, _ := io.ReadAll(res.Body) |
|
|
res.Body.Close() |
|
|
|
|
|
if res.StatusCode != http.StatusBadGateway { |
|
|
t.Fatalf("want 502, got %d", res.StatusCode) |
|
|
} |
|
|
if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) { |
|
|
t.Fatalf("unexpected body: %s", body) |
|
|
} |
|
|
if ct := res.Header.Get("Content-Type"); ct != "application/json" { |
|
|
t.Fatalf("content-type: want application/json, got %s", ct) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) { |
|
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
w.WriteHeader(200) |
|
|
w.Write(gzipBytes([]byte(`{"upstream":"ok"}`))) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
proxy.ServeHTTP(w, r) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
body, _ := io.ReadAll(res.Body) |
|
|
res.Body.Close() |
|
|
|
|
|
expected := []byte(`{"upstream":"ok"}`) |
|
|
if !bytes.Equal(body, expected) { |
|
|
t.Fatalf("want decompressed JSON, got: %s", body) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) { |
|
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
w.WriteHeader(200) |
|
|
w.Write([]byte(`{"plain":"json"}`)) |
|
|
})) |
|
|
defer upstream.Close() |
|
|
|
|
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key")) |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
proxy.ServeHTTP(w, r) |
|
|
})) |
|
|
defer srv.Close() |
|
|
|
|
|
res, err := http.Get(srv.URL + "/test") |
|
|
if err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
body, _ := io.ReadAll(res.Body) |
|
|
res.Body.Close() |
|
|
|
|
|
expected := []byte(`{"plain":"json"}`) |
|
|
if !bytes.Equal(body, expected) { |
|
|
t.Fatalf("want plain JSON unchanged, got: %s", body) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestIsStreamingResponse(t *testing.T) { |
|
|
cases := []struct { |
|
|
name string |
|
|
header http.Header |
|
|
want bool |
|
|
}{ |
|
|
{ |
|
|
name: "sse", |
|
|
header: http.Header{"Content-Type": []string{"text/event-stream"}}, |
|
|
want: true, |
|
|
}, |
|
|
{ |
|
|
name: "chunked_not_streaming", |
|
|
header: http.Header{"Transfer-Encoding": []string{"chunked"}}, |
|
|
want: false, |
|
|
}, |
|
|
{ |
|
|
name: "normal_json", |
|
|
header: http.Header{"Content-Type": []string{"application/json"}}, |
|
|
want: false, |
|
|
}, |
|
|
{ |
|
|
name: "empty", |
|
|
header: http.Header{}, |
|
|
want: false, |
|
|
}, |
|
|
} |
|
|
|
|
|
for _, tc := range cases { |
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
resp := &http.Response{Header: tc.header} |
|
|
got := isStreamingResponse(resp) |
|
|
if got != tc.want { |
|
|
t.Fatalf("want %v, got %v", tc.want, got) |
|
|
} |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestFilterBetaFeatures(t *testing.T) { |
|
|
tests := []struct { |
|
|
name string |
|
|
header string |
|
|
featureToRemove string |
|
|
expected string |
|
|
}{ |
|
|
{ |
|
|
name: "Remove context-1m from middle", |
|
|
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", |
|
|
}, |
|
|
{ |
|
|
name: "Remove context-1m from start", |
|
|
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "fine-grained-tool-streaming-2025-05-14", |
|
|
}, |
|
|
{ |
|
|
name: "Remove context-1m from end", |
|
|
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "fine-grained-tool-streaming-2025-05-14", |
|
|
}, |
|
|
{ |
|
|
name: "Feature not present", |
|
|
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", |
|
|
}, |
|
|
{ |
|
|
name: "Only feature to remove", |
|
|
header: "context-1m-2025-08-07", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "", |
|
|
}, |
|
|
{ |
|
|
name: "Empty header", |
|
|
header: "", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "", |
|
|
}, |
|
|
{ |
|
|
name: "Header with spaces", |
|
|
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20", |
|
|
featureToRemove: "context-1m-2025-08-07", |
|
|
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", |
|
|
}, |
|
|
} |
|
|
|
|
|
for _, tt := range tests { |
|
|
t.Run(tt.name, func(t *testing.T) { |
|
|
result := filterBetaFeatures(tt.header, tt.featureToRemove) |
|
|
if result != tt.expected { |
|
|
t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected) |
|
|
} |
|
|
}) |
|
|
} |
|
|
} |
|
|
|