|
|
package amp |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"encoding/json" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"sync" |
|
|
"testing" |
|
|
"time" |
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
log "github.com/sirupsen/logrus" |
|
|
"github.com/sirupsen/logrus/hooks/test" |
|
|
) |
|
|
|
|
|
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { |
|
|
ctx := context.Background() |
|
|
|
|
|
cases := []struct { |
|
|
name string |
|
|
configKey string |
|
|
envKey string |
|
|
fileJSON string |
|
|
want string |
|
|
}{ |
|
|
{"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"}, |
|
|
{"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, |
|
|
{"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, |
|
|
{"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, |
|
|
{"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, |
|
|
{"missing_file_returns_empty", "", "", "", ""}, |
|
|
{"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""}, |
|
|
} |
|
|
|
|
|
for _, tc := range cases { |
|
|
tc := tc |
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
tmpDir := t.TempDir() |
|
|
secretsPath := filepath.Join(tmpDir, "secrets.json") |
|
|
|
|
|
if tc.fileJSON != "" { |
|
|
if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
} |
|
|
|
|
|
t.Setenv("AMP_API_KEY", tc.envKey) |
|
|
|
|
|
s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond) |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != tc.want { |
|
|
t.Fatalf("want %q, got %q", tc.want, got) |
|
|
} |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestMultiSourceSecret_CacheBehavior(t *testing.T) { |
|
|
ctx := context.Background() |
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "secrets.json") |
|
|
|
|
|
|
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond) |
|
|
|
|
|
|
|
|
got1, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("Get failed: %v", err) |
|
|
} |
|
|
if got1 != "v1" { |
|
|
t.Fatalf("expected v1, got %s", got1) |
|
|
} |
|
|
|
|
|
|
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
got2, _ := s.Get(ctx) |
|
|
if got2 != "v1" { |
|
|
t.Fatalf("cache hit expected v1, got %s", got2) |
|
|
} |
|
|
|
|
|
|
|
|
time.Sleep(60 * time.Millisecond) |
|
|
got3, _ := s.Get(ctx) |
|
|
if got3 != "v2" { |
|
|
t.Fatalf("cache miss expected v2, got %s", got3) |
|
|
} |
|
|
|
|
|
|
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
s.InvalidateCache() |
|
|
got4, _ := s.Get(ctx) |
|
|
if got4 != "v3" { |
|
|
t.Fatalf("invalidate expected v3, got %s", got4) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestMultiSourceSecret_FileHandling(t *testing.T) { |
|
|
ctx := context.Background() |
|
|
|
|
|
t.Run("missing_file_no_error", func(t *testing.T) { |
|
|
s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond) |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("expected no error for missing file, got: %v", err) |
|
|
} |
|
|
if got != "" { |
|
|
t.Fatalf("expected empty string, got %q", got) |
|
|
} |
|
|
}) |
|
|
|
|
|
t.Run("invalid_json", func(t *testing.T) { |
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "secrets.json") |
|
|
if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
|
|
_, err := s.Get(ctx) |
|
|
if err == nil { |
|
|
t.Fatal("expected error for invalid JSON") |
|
|
} |
|
|
}) |
|
|
|
|
|
t.Run("missing_key_in_json", func(t *testing.T) { |
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "secrets.json") |
|
|
if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "" { |
|
|
t.Fatalf("expected empty string for missing key, got %q", got) |
|
|
} |
|
|
}) |
|
|
|
|
|
t.Run("empty_key_value", func(t *testing.T) { |
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "secrets.json") |
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
|
|
got, _ := s.Get(ctx) |
|
|
if got != "" { |
|
|
t.Fatalf("expected empty after trim, got %q", got) |
|
|
} |
|
|
}) |
|
|
} |
|
|
|
|
|
func TestMultiSourceSecret_Concurrency(t *testing.T) { |
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "secrets.json") |
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 5*time.Second) |
|
|
ctx := context.Background() |
|
|
|
|
|
|
|
|
const goroutines = 50 |
|
|
const iterations = 100 |
|
|
|
|
|
var wg sync.WaitGroup |
|
|
errors := make(chan error, goroutines) |
|
|
|
|
|
for i := 0; i < goroutines; i++ { |
|
|
wg.Add(1) |
|
|
go func() { |
|
|
defer wg.Done() |
|
|
for j := 0; j < iterations; j++ { |
|
|
val, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
errors <- err |
|
|
return |
|
|
} |
|
|
if val != "concurrent" { |
|
|
errors <- err |
|
|
return |
|
|
} |
|
|
} |
|
|
}() |
|
|
} |
|
|
|
|
|
wg.Wait() |
|
|
close(errors) |
|
|
|
|
|
for err := range errors { |
|
|
t.Errorf("concurrency error: %v", err) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestStaticSecretSource(t *testing.T) { |
|
|
ctx := context.Background() |
|
|
|
|
|
t.Run("returns_provided_key", func(t *testing.T) { |
|
|
s := NewStaticSecretSource("test-key-123") |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "test-key-123" { |
|
|
t.Fatalf("want test-key-123, got %q", got) |
|
|
} |
|
|
}) |
|
|
|
|
|
t.Run("trims_whitespace", func(t *testing.T) { |
|
|
s := NewStaticSecretSource(" test-key ") |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "test-key" { |
|
|
t.Fatalf("want test-key, got %q", got) |
|
|
} |
|
|
}) |
|
|
|
|
|
t.Run("empty_string", func(t *testing.T) { |
|
|
s := NewStaticSecretSource("") |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "" { |
|
|
t.Fatalf("want empty string, got %q", got) |
|
|
} |
|
|
}) |
|
|
} |
|
|
|
|
|
func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) { |
|
|
|
|
|
tmpDir := t.TempDir() |
|
|
p := filepath.Join(tmpDir, "nonexistent.json") |
|
|
|
|
|
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
|
|
ctx := context.Background() |
|
|
|
|
|
|
|
|
got1, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("expected no error for missing file, got: %v", err) |
|
|
} |
|
|
if got1 != "" { |
|
|
t.Fatalf("expected empty string, got %q", got1) |
|
|
} |
|
|
|
|
|
|
|
|
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil { |
|
|
t.Fatal(err) |
|
|
} |
|
|
|
|
|
|
|
|
got2, _ := s.Get(ctx) |
|
|
if got2 != "" { |
|
|
t.Fatalf("cache should return empty, got %q", got2) |
|
|
} |
|
|
|
|
|
|
|
|
time.Sleep(110 * time.Millisecond) |
|
|
got3, _ := s.Get(ctx) |
|
|
if got3 != "new-value" { |
|
|
t.Fatalf("after cache expiry, expected new-value, got %q", got3) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) { |
|
|
defaultSource := NewStaticSecretSource("default") |
|
|
s := NewMappedSecretSource(defaultSource) |
|
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
|
|
{ |
|
|
UpstreamAPIKey: "u1", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
}) |
|
|
|
|
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "u1" { |
|
|
t.Fatalf("want u1, got %q", got) |
|
|
} |
|
|
|
|
|
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2") |
|
|
got, err = s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "default" { |
|
|
t.Fatalf("want default fallback, got %q", got) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) { |
|
|
defaultSource := NewStaticSecretSource("default") |
|
|
s := NewMappedSecretSource(defaultSource) |
|
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
|
|
{ |
|
|
UpstreamAPIKey: "u1", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
{ |
|
|
UpstreamAPIKey: "u2", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
}) |
|
|
|
|
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") |
|
|
got, err := s.Get(ctx) |
|
|
if err != nil { |
|
|
t.Fatalf("unexpected error: %v", err) |
|
|
} |
|
|
if got != "u1" { |
|
|
t.Fatalf("want u1 (first wins), got %q", got) |
|
|
} |
|
|
} |
|
|
|
|
|
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) { |
|
|
hook := test.NewLocal(log.StandardLogger()) |
|
|
defer hook.Reset() |
|
|
|
|
|
defaultSource := NewStaticSecretSource("default") |
|
|
s := NewMappedSecretSource(defaultSource) |
|
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
|
|
{ |
|
|
UpstreamAPIKey: "u1", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
{ |
|
|
UpstreamAPIKey: "u2", |
|
|
APIKeys: []string{"k1"}, |
|
|
}, |
|
|
}) |
|
|
|
|
|
foundWarning := false |
|
|
for _, entry := range hook.AllEntries() { |
|
|
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." { |
|
|
foundWarning = true |
|
|
break |
|
|
} |
|
|
} |
|
|
if !foundWarning { |
|
|
t.Fatal("expected warning log for duplicate client key, but none was found") |
|
|
} |
|
|
} |
|
|
|