| package middleware |
|
|
| import ( |
| "net/http" |
| "net/http/httptest" |
| "testing" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/config" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/assert" |
| ) |
|
|
| func init() { |
| |
| gin.SetMode(gin.TestMode) |
| } |
|
|
| |
|
|
| func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| tests := []struct { |
| name string |
| method string |
| origin string |
| }{ |
| { |
| name: "preflight_disallowed_origin", |
| method: http.MethodOptions, |
| origin: "https://evil.example.com", |
| }, |
| { |
| name: "get_disallowed_origin", |
| method: http.MethodGet, |
| origin: "https://evil.example.com", |
| }, |
| { |
| name: "post_disallowed_origin", |
| method: http.MethodPost, |
| origin: "https://attacker.example.com", |
| }, |
| { |
| name: "preflight_no_origin", |
| method: http.MethodOptions, |
| origin: "", |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(tt.method, "/", nil) |
| if tt.origin != "" { |
| c.Request.Header.Set("Origin", tt.origin) |
| } |
|
|
| middleware(c) |
|
|
| |
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"), |
| "不允许的 origin 不应收到 Allow-Headers") |
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"), |
| "不允许的 origin 不应收到 Allow-Methods") |
| assert.Empty(t, w.Header().Get("Access-Control-Max-Age"), |
| "不允许的 origin 不应收到 Max-Age") |
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), |
| "不允许的 origin 不应收到 Allow-Origin") |
| }) |
| } |
| } |
|
|
| func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| tests := []struct { |
| name string |
| method string |
| }{ |
| {name: "preflight_OPTIONS", method: http.MethodOptions}, |
| {name: "normal_GET", method: http.MethodGet}, |
| {name: "normal_POST", method: http.MethodPost}, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(tt.method, "/", nil) |
| c.Request.Header.Set("Origin", "https://allowed.example.com") |
|
|
| middleware(c) |
|
|
| |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), |
| "允许的 origin 应收到 Allow-Headers") |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), |
| "允许的 origin 应收到 Allow-Methods") |
| assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"), |
| "允许的 origin 应收到 Max-Age=86400") |
| assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"), |
| "允许的 origin 应收到 Allow-Origin") |
| }) |
| } |
| } |
|
|
| func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) |
| c.Request.Header.Set("Origin", "https://evil.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, http.StatusForbidden, w.Code, |
| "不允许的 origin 的 preflight 请求应返回 403") |
| } |
|
|
| func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) |
| c.Request.Header.Set("Origin", "https://allowed.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, http.StatusNoContent, w.Code, |
| "允许的 origin 的 preflight 请求应返回 204") |
| } |
|
|
| func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"*"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://any-origin.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"), |
| "通配符配置应返回 *") |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), |
| "通配符 origin 应设置 Allow-Headers") |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), |
| "通配符 origin 应设置 Allow-Methods") |
| } |
|
|
| func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: true, |
| } |
| middleware := CORS(cfg) |
|
|
| t.Run("allowed_origin_gets_credentials", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://allowed.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"), |
| "允许的 origin 且开启 credentials 应设置 Allow-Credentials") |
| }) |
|
|
| t.Run("disallowed_origin_no_credentials", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://evil.example.com") |
|
|
| middleware(c) |
|
|
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), |
| "不允许的 origin 不应收到 Allow-Credentials") |
| }) |
| } |
|
|
| func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"*"}, |
| AllowCredentials: true, |
| } |
| middleware := CORS(cfg) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://any.example.com") |
|
|
| middleware(c) |
|
|
| |
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), |
| "通配符 origin 应禁用 Allow-Credentials") |
| } |
|
|
| func TestCORS_MultipleAllowedOrigins(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{ |
| "https://app1.example.com", |
| "https://app2.example.com", |
| }, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| t.Run("first_origin_allowed", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://app1.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin")) |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) |
| }) |
|
|
| t.Run("second_origin_allowed", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://app2.example.com") |
|
|
| middleware(c) |
|
|
| assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin")) |
| assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) |
| }) |
|
|
| t.Run("unlisted_origin_rejected", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://app3.example.com") |
|
|
| middleware(c) |
|
|
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) |
| assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) |
| }) |
| } |
|
|
| func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) { |
| cfg := config.CORSConfig{ |
| AllowedOrigins: []string{"https://allowed.example.com"}, |
| AllowCredentials: false, |
| } |
| middleware := CORS(cfg) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.Request.Header.Set("Origin", "https://allowed.example.com") |
|
|
| middleware(c) |
|
|
| assert.Contains(t, w.Header().Values("Vary"), "Origin", |
| "非通配符允许的 origin 应设置 Vary: Origin") |
| } |
|
|
| func TestNormalizeOrigins(t *testing.T) { |
| tests := []struct { |
| name string |
| input []string |
| expect []string |
| }{ |
| {name: "nil_input", input: nil, expect: nil}, |
| {name: "empty_input", input: []string{}, expect: nil}, |
| {name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}}, |
| {name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}}, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| result := normalizeOrigins(tt.input) |
| assert.Equal(t, tt.expect, result) |
| }) |
| } |
| } |
|
|