kek
Fresh start: Go 1.23 + go-git/v5 compatibility
f606b10
package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// gzipBytes returns b compressed as a single gzip stream.
//
// Write and Close errors are checked instead of being discarded: Close flushes
// pending data and writes the gzip footer, so a failure there would hand the
// caller a silently truncated stream. Writing into a bytes.Buffer is not
// expected to fail, so any error is treated as broken test setup and panics.
func gzipBytes(b []byte) []byte {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	if _, err := zw.Write(b); err != nil {
		panic(err)
	}
	if err := zw.Close(); err != nil {
		panic(err)
	}
	return buf.Bytes()
}
// mkResp builds an in-memory *http.Response carrying the given status code,
// headers and body. A nil hdr is replaced by an empty header map, and
// ContentLength is set to len(body).
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = status
	resp.ContentLength = int64(len(body))
	resp.Body = io.NopCloser(bytes.NewReader(body))
	if resp.Header = hdr; resp.Header == nil {
		resp.Header = http.Header{}
	}
	return resp
}
// TestCreateReverseProxy_ValidURL checks that a well-formed upstream URL yields
// a non-nil proxy and no error.
func TestCreateReverseProxy_ValidURL(t *testing.T) {
	rp, buildErr := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
	switch {
	case buildErr != nil:
		t.Fatalf("expected no error, got: %v", buildErr)
	case rp == nil:
		t.Fatal("expected proxy to be created")
	}
}
// TestCreateReverseProxy_InvalidURL checks that an unparsable upstream URL is
// rejected with an error.
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
	if _, buildErr := createReverseProxy("://invalid", NewStaticSecretSource("key")); buildErr == nil {
		t.Fatal("expected error for invalid URL")
	}
}
// TestModifyResponse_GzipScenarios runs proxy.ModifyResponse over a table of
// synthetic responses and verifies, for each, the body bytes that can be read
// back afterwards and the resulting Content-Encoding header.
func TestModifyResponse_GzipScenarios(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	plain := []byte(`{"ok":true}`)
	compressed := gzipBytes(plain)
	// Only the first 10 bytes of an otherwise valid gzip stream.
	cutShort := compressed[:10]
	// Gzip magic bytes followed by a payload that is not gzip data.
	garbage := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)

	type scenario struct {
		label      string
		hdr        http.Header
		payload    []byte
		code       int
		expectBody []byte
		expectCE   string
	}
	scenarios := []scenario{
		{
			label:      "decompresses_valid_gzip_no_header",
			hdr:        http.Header{},
			payload:    compressed,
			code:       http.StatusOK,
			expectBody: plain,
			expectCE:   "",
		},
		{
			label:      "skips_when_ce_present",
			hdr:        http.Header{"Content-Encoding": []string{"gzip"}},
			payload:    compressed,
			code:       http.StatusOK,
			expectBody: compressed,
			expectCE:   "gzip",
		},
		{
			label:      "passes_truncated_unchanged",
			hdr:        http.Header{},
			payload:    cutShort,
			code:       http.StatusOK,
			expectBody: cutShort,
			expectCE:   "",
		},
		{
			label:      "passes_corrupted_unchanged",
			hdr:        http.Header{},
			payload:    garbage,
			code:       http.StatusOK,
			expectBody: garbage,
			expectCE:   "",
		},
		{
			label:      "non_gzip_unchanged",
			hdr:        http.Header{},
			payload:    []byte("plain"),
			code:       http.StatusOK,
			expectBody: []byte("plain"),
			expectCE:   "",
		},
		{
			label:      "empty_body",
			hdr:        http.Header{},
			payload:    []byte{},
			code:       http.StatusOK,
			expectBody: []byte{},
			expectCE:   "",
		},
		{
			label:      "single_byte_body",
			hdr:        http.Header{},
			payload:    []byte{0x1f},
			code:       http.StatusOK,
			expectBody: []byte{0x1f},
			expectCE:   "",
		},
		{
			label:      "skips_non_2xx_status",
			hdr:        http.Header{},
			payload:    compressed,
			code:       http.StatusNotFound,
			expectBody: compressed,
			expectCE:   "",
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.label, func(t *testing.T) {
			resp := mkResp(sc.code, sc.hdr, sc.payload)
			if modErr := proxy.ModifyResponse(resp); modErr != nil {
				t.Fatalf("ModifyResponse error: %v", modErr)
			}

			body, readErr := io.ReadAll(resp.Body)
			if readErr != nil {
				t.Fatalf("ReadAll error: %v", readErr)
			}
			if !bytes.Equal(body, sc.expectBody) {
				t.Fatalf("body mismatch:\nwant: %q\ngot: %q", sc.expectBody, body)
			}

			gotCE := resp.Header.Get("Content-Encoding")
			if gotCE != sc.expectCE {
				t.Fatalf("Content-Encoding: want %q, got %q", sc.expectCE, gotCE)
			}
		})
	}
}
// TestModifyResponse_UpdatesContentLengthHeader verifies that when
// ModifyResponse decompresses a gzip body, both the Content-Length header and
// the resp.ContentLength field are rewritten to the decompressed size.
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	goodJSON := []byte(`{"message":"test response"}`)
	gzipped := gzipBytes(goodJSON)

	// Simulate an upstream response with a gzip body AND a Content-Length header
	// that carries the compressed size. This guards against a stale
	// Content-Length being left in place after decompression.
	resp := mkResp(200, http.Header{
		"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
	}, gzipped)

	if err := proxy.ModifyResponse(resp); err != nil {
		t.Fatalf("ModifyResponse error: %v", err)
	}

	// Verify body is decompressed. A read failure is reported on its own so it
	// cannot be mistaken for a body mismatch.
	got, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	if !bytes.Equal(got, goodJSON) {
		t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
	}

	// Verify Content-Length header is updated to decompressed size
	wantCL := fmt.Sprintf("%d", len(goodJSON))
	gotCL := resp.Header.Get("Content-Length")
	if gotCL != wantCL {
		t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
	}

	// Verify struct field also matches
	if resp.ContentLength != int64(len(goodJSON)) {
		t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
	}
}
// TestModifyResponse_SkipsStreamingResponses verifies that a response with
// Content-Type text/event-stream is passed through ModifyResponse without its
// gzip body being decompressed.
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	goodJSON := []byte(`{"ok":true}`)
	gzipped := gzipBytes(goodJSON)

	t.Run("sse_skips_decompression", func(t *testing.T) {
		resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
		if err := proxy.ModifyResponse(resp); err != nil {
			t.Fatalf("ModifyResponse error: %v", err)
		}

		// SSE should NOT be decompressed. Check the read error explicitly so a
		// failed read is not reported as an unexpected decompression.
		got, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("ReadAll error: %v", err)
		}
		if !bytes.Equal(got, gzipped) {
			t.Fatal("SSE response should not be decompressed")
		}
	})
}
// TestModifyResponse_DecompressesChunkedJSON verifies that a gzip body on a
// response marked Transfer-Encoding: chunked (and not SSE) is still
// decompressed by ModifyResponse.
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	goodJSON := []byte(`{"ok":true}`)
	gzipped := gzipBytes(goodJSON)

	t.Run("chunked_json_decompresses", func(t *testing.T) {
		// Chunked JSON responses (like thread APIs) should be decompressed
		resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
		if err := proxy.ModifyResponse(resp); err != nil {
			t.Fatalf("ModifyResponse error: %v", err)
		}

		// Should decompress because it's not SSE. Check the read error
		// explicitly so a failed read is not reported as a body mismatch.
		got, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("ReadAll error: %v", err)
		}
		if !bytes.Equal(got, goodJSON) {
			t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
		}
	})
}
// TestReverseProxy_InjectsHeaders sends one request through the proxy and
// checks that the upstream server receives X-Api-Key and Authorization headers
// built from the configured static secret.
func TestReverseProxy_InjectsHeaders(t *testing.T) {
	seen := make(chan http.Header, 1)
	recordHeaders := func(w http.ResponseWriter, r *http.Request) {
		seen <- r.Header.Clone()
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`ok`))
	}
	upstream := httptest.NewServer(http.HandlerFunc(recordHeaders))
	defer upstream.Close()

	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
	if err != nil {
		t.Fatal(err)
	}

	front := httptest.NewServer(http.HandlerFunc(proxy.ServeHTTP))
	defer front.Close()

	res, err := http.Get(front.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	upstreamHdr := <-seen
	if apiKey := upstreamHdr.Get("X-Api-Key"); apiKey != "secret" {
		t.Fatalf("X-Api-Key missing or wrong, got: %q", apiKey)
	}
	if auth := upstreamHdr.Get("Authorization"); auth != "Bearer secret" {
		t.Fatalf("Authorization missing or wrong, got: %q", auth)
	}
}
// TestReverseProxy_EmptySecret sends one request through a proxy configured
// with an empty secret and checks that no X-Api-Key reaches the upstream. The
// Authorization header is accepted when it is either absent or the bare
// "Bearer " prefix.
func TestReverseProxy_EmptySecret(t *testing.T) {
	seen := make(chan http.Header, 1)
	recordHeaders := func(w http.ResponseWriter, r *http.Request) {
		seen <- r.Header.Clone()
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`ok`))
	}
	upstream := httptest.NewServer(http.HandlerFunc(recordHeaders))
	defer upstream.Close()

	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}

	front := httptest.NewServer(http.HandlerFunc(proxy.ServeHTTP))
	defer front.Close()

	res, err := http.Get(front.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	upstreamHdr := <-seen

	// Should NOT inject headers when secret is empty
	if apiKey := upstreamHdr.Get("X-Api-Key"); apiKey != "" {
		t.Fatalf("X-Api-Key should not be set, got: %q", apiKey)
	}
	switch auth := upstreamHdr.Get("Authorization"); auth {
	case "", "Bearer ":
		// Both forms are tolerated.
	default:
		t.Fatalf("Authorization should not be set, got: %q", auth)
	}
}
// TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery sends a request
// carrying the client's key in headers and query parameters and checks what the
// upstream server actually receives: client credentials removed or replaced by
// the upstream secret, unrelated query parameters preserved.
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
	type snapshot struct {
		hdr      http.Header
		rawQuery string
	}

	seen := make(chan snapshot, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen <- snapshot{hdr: r.Header.Clone(), rawQuery: r.URL.RawQuery}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()

	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
	if err != nil {
		t.Fatal(err)
	}

	front := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate clientAPIKeyMiddleware injection (per-request)
		withKey := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
		proxy.ServeHTTP(w, r.WithContext(withKey))
	}))
	defer front.Close()

	target := front.URL + "/test?key=client-key&key=keep&auth_token=client-key&foo=bar"
	req, err := http.NewRequest(http.MethodGet, target, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer client-key")
	for _, name := range []string{"X-Api-Key", "X-Goog-Api-Key"} {
		req.Header.Set(name, "client-key")
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	rcv := <-seen

	// These are client-provided credentials and must not reach the upstream.
	if goog := rcv.hdr.Get("X-Goog-Api-Key"); goog != "" {
		t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", goog)
	}

	// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
	if auth := rcv.hdr.Get("Authorization"); auth != "Bearer upstream" {
		t.Fatalf("Authorization should be upstream-injected, got: %q", auth)
	}
	if apiKey := rcv.hdr.Get("X-Api-Key"); apiKey != "upstream" {
		t.Fatalf("X-Api-Key should be upstream-injected, got: %q", apiKey)
	}

	// Query-based credentials should be stripped only when they match the
	// authenticated client key.
	for _, forbidden := range []string{"auth_token=client-key", "key=client-key"} {
		if strings.Contains(rcv.rawQuery, forbidden) {
			t.Fatalf("query credentials should be stripped, got raw query: %q", rcv.rawQuery)
		}
	}
	// Unrelated values and parameters should be kept.
	for _, required := range []string{"key=keep", "foo=bar"} {
		if !strings.Contains(rcv.rawQuery, required) {
			t.Fatalf("expected query to keep non-credential params, got raw query: %q", rcv.rawQuery)
		}
	}
}
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate clientAPIKeyMiddleware injection (per-request)
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "u1" {
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer u1" {
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
gotHeaders := make(chan http.Header, 1)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
w.WriteHeader(200)
w.Write([]byte(`ok`))
}))
defer upstream.Close()
defaultSource := NewStaticSecretSource("default")
mapped := NewMappedSecretSource(defaultSource)
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
{
UpstreamAPIKey: "u1",
APIKeys: []string{"k1"},
},
})
proxy, err := createReverseProxy(upstream.URL, mapped)
if err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
proxy.ServeHTTP(w, r.WithContext(ctx))
}))
defer srv.Close()
res, err := http.Get(srv.URL + "/test")
if err != nil {
t.Fatal(err)
}
res.Body.Close()
hdr := <-gotHeaders
if hdr.Get("X-Api-Key") != "default" {
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
}
if hdr.Get("Authorization") != "Bearer default" {
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
}
}
// TestReverseProxy_ErrorHandler verifies the proxy's response when the upstream
// cannot be reached: status 502, a JSON content type, and a body containing the
// "amp_upstream_proxy_error" marker.
func TestReverseProxy_ErrorHandler(t *testing.T) {
	// Point the proxy at loopback port 1, where no listener is expected, so the
	// upstream round trip fails and the proxy's error handling is exercised.
	proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()

	res, err := http.Get(srv.URL + "/any")
	if err != nil {
		t.Fatal(err)
	}
	// Deferred so the body is released on every exit path, including t.Fatalf.
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}

	if res.StatusCode != http.StatusBadGateway {
		t.Fatalf("want 502, got %d", res.StatusCode)
	}
	if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
		t.Fatalf("unexpected body: %s", body)
	}
	if ct := res.Header.Get("Content-Type"); ct != "application/json" {
		t.Fatalf("content-type: want application/json, got %s", ct)
	}
}
// TestReverseProxy_FullRoundTrip_Gzip runs a real client -> proxy -> upstream
// round trip where the upstream sends a gzip body without a Content-Encoding
// header, and checks that the client receives the decompressed JSON.
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
	// Upstream returns gzipped JSON without Content-Encoding header
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
	}))
	defer upstream.Close()

	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
	if err != nil {
		t.Fatal(err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()

	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	// Deferred so the body is released on every exit path, including t.Fatalf.
	defer res.Body.Close()

	// A read failure is reported on its own rather than as a body mismatch.
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}

	expected := []byte(`{"upstream":"ok"}`)
	if !bytes.Equal(body, expected) {
		t.Fatalf("want decompressed JSON, got: %s", body)
	}
}
// TestReverseProxy_FullRoundTrip_PlainJSON runs a real client -> proxy ->
// upstream round trip where the upstream sends uncompressed JSON, and checks
// that the client receives the body unchanged.
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
	// Upstream returns plain JSON
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(200)
		w.Write([]byte(`{"plain":"json"}`))
	}))
	defer upstream.Close()

	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
	if err != nil {
		t.Fatal(err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()

	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	// Deferred so the body is released on every exit path, including t.Fatalf.
	defer res.Body.Close()

	// A read failure is reported on its own rather than as a body mismatch.
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}

	expected := []byte(`{"plain":"json"}`)
	if !bytes.Equal(body, expected) {
		t.Fatalf("want plain JSON unchanged, got: %s", body)
	}
}
// TestIsStreamingResponse checks isStreamingResponse against responses that
// differ only in their headers: SSE is expected to count as streaming, while
// chunked transfer encoding, plain JSON and an empty header set are not.
func TestIsStreamingResponse(t *testing.T) {
	type scenario struct {
		label  string
		hdr    http.Header
		expect bool
	}
	scenarios := []scenario{
		{label: "sse", hdr: http.Header{"Content-Type": []string{"text/event-stream"}}, expect: true},
		// Chunked is transport-level, not streaming
		{label: "chunked_not_streaming", hdr: http.Header{"Transfer-Encoding": []string{"chunked"}}, expect: false},
		{label: "normal_json", hdr: http.Header{"Content-Type": []string{"application/json"}}, expect: false},
		{label: "empty", hdr: http.Header{}, expect: false},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.label, func(t *testing.T) {
			if actual := isStreamingResponse(&http.Response{Header: sc.hdr}); actual != sc.expect {
				t.Fatalf("want %v, got %v", sc.expect, actual)
			}
		})
	}
}
// TestFilterBetaFeatures checks filterBetaFeatures on comma-separated feature
// lists: removal from the start, middle and end, an absent feature, a list
// holding only the removed feature, an empty list, and entries padded with
// spaces.
func TestFilterBetaFeatures(t *testing.T) {
	// Every scenario removes the same feature.
	const drop = "context-1m-2025-08-07"

	type scenario struct {
		label string
		input string
		want  string
	}
	scenarios := []scenario{
		{
			label: "Remove context-1m from middle",
			input: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
			want:  "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
		},
		{
			label: "Remove context-1m from start",
			input: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
			want:  "fine-grained-tool-streaming-2025-05-14",
		},
		{
			label: "Remove context-1m from end",
			input: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
			want:  "fine-grained-tool-streaming-2025-05-14",
		},
		{
			label: "Feature not present",
			input: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
			want:  "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
		},
		{
			label: "Only feature to remove",
			input: "context-1m-2025-08-07",
			want:  "",
		},
		{
			label: "Empty header",
			input: "",
			want:  "",
		},
		{
			label: "Header with spaces",
			input: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
			want:  "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.label, func(t *testing.T) {
			if filtered := filterBetaFeatures(sc.input, drop); filtered != sc.want {
				t.Errorf("filterBetaFeatures() = %q, want %q", filtered, sc.want)
			}
		})
	}
}