File size: 3,237 Bytes
8059bf0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | //go:build unit
package middleware
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestClientRequestID_GeneratesWhenMissing verifies that the ClientRequestID
// middleware places a non-empty string under ctxkey.ClientRequestID in the
// request context when the incoming request does not already carry one.
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(ClientRequestID())

	// httptest.ResponseRecorder starts with Code == 200, so the final status
	// check alone cannot prove the handler (and its assertions) actually ran;
	// a middleware that aborted the chain would otherwise pass silently.
	handled := false
	r.GET("/t", func(c *gin.Context) {
		handled = true
		v := c.Request.Context().Value(ctxkey.ClientRequestID)
		require.NotNil(t, v)
		id, ok := v.(string)
		require.True(t, ok)
		require.NotEmpty(t, id)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/t", nil)
	r.ServeHTTP(w, req)

	require.True(t, handled, "handler was not reached; middleware may have aborted the chain")
	require.Equal(t, http.StatusOK, w.Code)
}
// TestClientRequestID_PreservesExisting verifies that the ClientRequestID
// middleware leaves an already-present ctxkey.ClientRequestID value ("keep")
// unchanged in the request context.
func TestClientRequestID_PreservesExisting(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(ClientRequestID())

	// The recorder defaults to 200, so track handler execution explicitly to
	// make sure the in-handler assertions were evaluated.
	handled := false
	r.GET("/t", func(c *gin.Context) {
		handled = true
		id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
		require.True(t, ok)
		require.Equal(t, "keep", id)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/t", nil)
	req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
	r.ServeHTTP(w, req)

	require.True(t, handled, "handler was not reached; middleware may have aborted the chain")
	require.Equal(t, http.StatusOK, w.Code)
}
// TestRequestBodyLimit_LimitsBody verifies that with RequestBodyLimit(4)
// installed, reading a 5-byte request body in the handler returns an error.
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(RequestBodyLimit(4))

	// Without this flag the test would pass even if the handler never ran,
	// because the recorder's Code already defaults to 200.
	handled := false
	r.POST("/t", func(c *gin.Context) {
		handled = true
		_, err := io.ReadAll(c.Request.Body)
		require.Error(t, err)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
	r.ServeHTTP(w, req)

	require.True(t, handled, "handler was not reached; the oversized-body read was never checked")
	require.Equal(t, http.StatusOK, w.Code)
}
// TestForcePlatform_SetsContextAndGinValue verifies that ForcePlatform("anthropic")
// makes the platform visible via HasForcePlatform/GetForcePlatformFromContext
// and also under ctxkey.ForcePlatform in the request context.
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	r.Use(ForcePlatform("anthropic"))

	// Guard against a vacuous pass: the recorder reports 200 even when no
	// handler writes a status.
	handled := false
	r.GET("/t", func(c *gin.Context) {
		handled = true
		require.True(t, HasForcePlatform(c))
		v, ok := GetForcePlatformFromContext(c)
		require.True(t, ok)
		require.Equal(t, "anthropic", v)
		ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
		require.Equal(t, "anthropic", ctxV)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/t", nil)
	r.ServeHTTP(w, req)

	require.True(t, handled, "handler was not reached; middleware may have aborted the chain")
	require.Equal(t, http.StatusOK, w.Code)
}
// TestAuthSubjectHelpers_RoundTrip stores an AuthSubject and a role string on
// a bare gin.Context under ContextKeyUser / ContextKeyUserRole, then checks
// that GetAuthSubjectFromContext and GetUserRoleFromContext read them back.
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
	gctx := &gin.Context{}
	gctx.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
	gctx.Set(string(ContextKeyUserRole), "admin")

	subject, found := GetAuthSubjectFromContext(gctx)
	require.True(t, found)
	require.Equal(t, int64(1), subject.UserID)
	require.Equal(t, 2, subject.Concurrency)

	userRole, found := GetUserRoleFromContext(gctx)
	require.True(t, found)
	require.Equal(t, "admin", userRole)
}
// TestAPIKeyAndSubscriptionFromContext stores a *service.APIKey and a
// *service.UserSubscription on a bare gin.Context and checks that
// GetAPIKeyFromContext and GetSubscriptionFromContext report them as found
// with the expected IDs.
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
	gctx := &gin.Context{}

	gctx.Set(string(ContextKeyAPIKey), &service.APIKey{ID: 1})
	apiKey, found := GetAPIKeyFromContext(gctx)
	require.True(t, found)
	require.Equal(t, int64(1), apiKey.ID)

	gctx.Set(string(ContextKeySubscription), &service.UserSubscription{ID: 2})
	subscription, found := GetSubscriptionFromContext(gctx)
	require.True(t, found)
	require.Equal(t, int64(2), subscription.ID)
}
|