| |
|
|
| package response |
|
|
| import ( |
| "encoding/json" |
| "errors" |
| "net/http" |
| "net/http/httptest" |
| "testing" |
|
|
| errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| |
|
|
| |
| func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { |
| t.Helper() |
| var got Response |
| require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) |
| return got |
| } |
|
|
| |
| func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { |
| t.Helper() |
| |
| var raw struct { |
| Code int `json:"code"` |
| Message string `json:"message"` |
| Reason string `json:"reason,omitempty"` |
| Data json.RawMessage `json:"data,omitempty"` |
| } |
| require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) |
|
|
| var pd PaginatedData |
| require.NoError(t, json.Unmarshal(raw.Data, &pd)) |
|
|
| return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd |
| } |
|
|
| |
| func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) |
| return w, c |
| } |
|
|
| |
|
|
| func TestErrorWithDetails(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| statusCode int |
| message string |
| reason string |
| metadata map[string]string |
| want Response |
| }{ |
| { |
| name: "plain_error", |
| statusCode: http.StatusBadRequest, |
| message: "invalid request", |
| want: Response{ |
| Code: http.StatusBadRequest, |
| Message: "invalid request", |
| }, |
| }, |
| { |
| name: "structured_error", |
| statusCode: http.StatusForbidden, |
| message: "no access", |
| reason: "FORBIDDEN", |
| metadata: map[string]string{"k": "v"}, |
| want: Response{ |
| Code: http.StatusForbidden, |
| Message: "no access", |
| Reason: "FORBIDDEN", |
| Metadata: map[string]string{"k": "v"}, |
| }, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata) |
|
|
| require.Equal(t, tt.statusCode, w.Code) |
|
|
| var got Response |
| require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) |
| require.Equal(t, tt.want, got) |
| }) |
| } |
| } |
|
|
| func TestErrorFrom(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| err error |
| wantWritten bool |
| wantHTTPCode int |
| wantBody Response |
| }{ |
| { |
| name: "nil_error", |
| err: nil, |
| wantWritten: false, |
| }, |
| { |
| name: "application_error", |
| err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), |
| wantWritten: true, |
| wantHTTPCode: http.StatusForbidden, |
| wantBody: Response{ |
| Code: http.StatusForbidden, |
| Message: "no access", |
| Reason: "FORBIDDEN", |
| Metadata: map[string]string{"scope": "admin"}, |
| }, |
| }, |
| { |
| name: "bad_request_error", |
| err: errors2.BadRequest("INVALID_REQUEST", "invalid request"), |
| wantWritten: true, |
| wantHTTPCode: http.StatusBadRequest, |
| wantBody: Response{ |
| Code: http.StatusBadRequest, |
| Message: "invalid request", |
| Reason: "INVALID_REQUEST", |
| }, |
| }, |
| { |
| name: "unauthorized_error", |
| err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"), |
| wantWritten: true, |
| wantHTTPCode: http.StatusUnauthorized, |
| wantBody: Response{ |
| Code: http.StatusUnauthorized, |
| Message: "unauthorized", |
| Reason: "UNAUTHORIZED", |
| }, |
| }, |
| { |
| name: "not_found_error", |
| err: errors2.NotFound("NOT_FOUND", "not found"), |
| wantWritten: true, |
| wantHTTPCode: http.StatusNotFound, |
| wantBody: Response{ |
| Code: http.StatusNotFound, |
| Message: "not found", |
| Reason: "NOT_FOUND", |
| }, |
| }, |
| { |
| name: "conflict_error", |
| err: errors2.Conflict("CONFLICT", "conflict"), |
| wantWritten: true, |
| wantHTTPCode: http.StatusConflict, |
| wantBody: Response{ |
| Code: http.StatusConflict, |
| Message: "conflict", |
| Reason: "CONFLICT", |
| }, |
| }, |
| { |
| name: "unknown_error_defaults_to_500", |
| err: errors.New("boom"), |
| wantWritten: true, |
| wantHTTPCode: http.StatusInternalServerError, |
| wantBody: Response{ |
| Code: http.StatusInternalServerError, |
| Message: errors2.UnknownMessage, |
| }, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| written := ErrorFrom(c, tt.err) |
| require.Equal(t, tt.wantWritten, written) |
|
|
| if !tt.wantWritten { |
| require.Equal(t, 200, w.Code) |
| require.Empty(t, w.Body.String()) |
| return |
| } |
|
|
| require.Equal(t, tt.wantHTTPCode, w.Code) |
| var got Response |
| require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) |
| require.Equal(t, tt.wantBody, got) |
| }) |
| } |
| } |
|
|
| |
|
|
| func TestSuccess(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| data any |
| wantCode int |
| wantBody Response |
| }{ |
| { |
| name: "返回字符串数据", |
| data: "hello", |
| wantCode: http.StatusOK, |
| wantBody: Response{Code: 0, Message: "success", Data: "hello"}, |
| }, |
| { |
| name: "返回nil数据", |
| data: nil, |
| wantCode: http.StatusOK, |
| wantBody: Response{Code: 0, Message: "success"}, |
| }, |
| { |
| name: "返回map数据", |
| data: map[string]string{"key": "value"}, |
| wantCode: http.StatusOK, |
| wantBody: Response{Code: 0, Message: "success"}, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Success(c, tt.data) |
|
|
| require.Equal(t, tt.wantCode, w.Code) |
|
|
| |
| got := parseResponseBody(t, w) |
| require.Equal(t, 0, got.Code) |
| require.Equal(t, "success", got.Message) |
|
|
| if tt.data == nil { |
| require.Nil(t, got.Data) |
| } else { |
| require.NotNil(t, got.Data) |
| } |
| }) |
| } |
| } |
|
|
| func TestCreated(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| data any |
| wantCode int |
| }{ |
| { |
| name: "创建成功_返回数据", |
| data: map[string]int{"id": 42}, |
| wantCode: http.StatusCreated, |
| }, |
| { |
| name: "创建成功_nil数据", |
| data: nil, |
| wantCode: http.StatusCreated, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Created(c, tt.data) |
|
|
| require.Equal(t, tt.wantCode, w.Code) |
|
|
| got := parseResponseBody(t, w) |
| require.Equal(t, 0, got.Code) |
| require.Equal(t, "success", got.Message) |
| }) |
| } |
| } |
|
|
| func TestError(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| statusCode int |
| message string |
| }{ |
| { |
| name: "400错误", |
| statusCode: http.StatusBadRequest, |
| message: "bad request", |
| }, |
| { |
| name: "500错误", |
| statusCode: http.StatusInternalServerError, |
| message: "internal error", |
| }, |
| { |
| name: "自定义状态码", |
| statusCode: 418, |
| message: "I'm a teapot", |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Error(c, tt.statusCode, tt.message) |
|
|
| require.Equal(t, tt.statusCode, w.Code) |
|
|
| got := parseResponseBody(t, w) |
| require.Equal(t, tt.statusCode, got.Code) |
| require.Equal(t, tt.message, got.Message) |
| require.Empty(t, got.Reason) |
| require.Nil(t, got.Metadata) |
| require.Nil(t, got.Data) |
| }) |
| } |
| } |
|
|
| func TestBadRequest(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| BadRequest(c, "参数无效") |
|
|
| require.Equal(t, http.StatusBadRequest, w.Code) |
| got := parseResponseBody(t, w) |
| require.Equal(t, http.StatusBadRequest, got.Code) |
| require.Equal(t, "参数无效", got.Message) |
| } |
|
|
| func TestUnauthorized(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Unauthorized(c, "未登录") |
|
|
| require.Equal(t, http.StatusUnauthorized, w.Code) |
| got := parseResponseBody(t, w) |
| require.Equal(t, http.StatusUnauthorized, got.Code) |
| require.Equal(t, "未登录", got.Message) |
| } |
|
|
| func TestForbidden(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Forbidden(c, "无权限") |
|
|
| require.Equal(t, http.StatusForbidden, w.Code) |
| got := parseResponseBody(t, w) |
| require.Equal(t, http.StatusForbidden, got.Code) |
| require.Equal(t, "无权限", got.Message) |
| } |
|
|
| func TestNotFound(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| NotFound(c, "资源不存在") |
|
|
| require.Equal(t, http.StatusNotFound, w.Code) |
| got := parseResponseBody(t, w) |
| require.Equal(t, http.StatusNotFound, got.Code) |
| require.Equal(t, "资源不存在", got.Message) |
| } |
|
|
| func TestInternalError(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| InternalError(c, "服务器内部错误") |
|
|
| require.Equal(t, http.StatusInternalServerError, w.Code) |
| got := parseResponseBody(t, w) |
| require.Equal(t, http.StatusInternalServerError, got.Code) |
| require.Equal(t, "服务器内部错误", got.Message) |
| } |
|
|
| func TestPaginated(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| items any |
| total int64 |
| page int |
| pageSize int |
| wantPages int |
| wantTotal int64 |
| wantPage int |
| wantPageSize int |
| }{ |
| { |
| name: "标准分页_多页", |
| items: []string{"a", "b"}, |
| total: 25, |
| page: 1, |
| pageSize: 10, |
| wantPages: 3, |
| wantTotal: 25, |
| wantPage: 1, |
| wantPageSize: 10, |
| }, |
| { |
| name: "总数刚好整除", |
| items: []string{"a"}, |
| total: 20, |
| page: 2, |
| pageSize: 10, |
| wantPages: 2, |
| wantTotal: 20, |
| wantPage: 2, |
| wantPageSize: 10, |
| }, |
| { |
| name: "总数为0_pages至少为1", |
| items: []string{}, |
| total: 0, |
| page: 1, |
| pageSize: 10, |
| wantPages: 1, |
| wantTotal: 0, |
| wantPage: 1, |
| wantPageSize: 10, |
| }, |
| { |
| name: "单页数据", |
| items: []int{1, 2, 3}, |
| total: 3, |
| page: 1, |
| pageSize: 20, |
| wantPages: 1, |
| wantTotal: 3, |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "总数为1", |
| items: []string{"only"}, |
| total: 1, |
| page: 1, |
| pageSize: 10, |
| wantPages: 1, |
| wantTotal: 1, |
| wantPage: 1, |
| wantPageSize: 10, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) |
|
|
| require.Equal(t, http.StatusOK, w.Code) |
|
|
| resp, pd := parsePaginatedBody(t, w) |
| require.Equal(t, 0, resp.Code) |
| require.Equal(t, "success", resp.Message) |
| require.Equal(t, tt.wantTotal, pd.Total) |
| require.Equal(t, tt.wantPage, pd.Page) |
| require.Equal(t, tt.wantPageSize, pd.PageSize) |
| require.Equal(t, tt.wantPages, pd.Pages) |
| }) |
| } |
| } |
|
|
| func TestPaginatedWithResult(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| items any |
| pagination *PaginationResult |
| wantTotal int64 |
| wantPage int |
| wantPageSize int |
| wantPages int |
| }{ |
| { |
| name: "正常分页结果", |
| items: []string{"a", "b"}, |
| pagination: &PaginationResult{ |
| Total: 50, |
| Page: 3, |
| PageSize: 10, |
| Pages: 5, |
| }, |
| wantTotal: 50, |
| wantPage: 3, |
| wantPageSize: 10, |
| wantPages: 5, |
| }, |
| { |
| name: "pagination为nil_使用默认值", |
| items: []string{}, |
| pagination: nil, |
| wantTotal: 0, |
| wantPage: 1, |
| wantPageSize: 20, |
| wantPages: 1, |
| }, |
| { |
| name: "单页结果", |
| items: []int{1}, |
| pagination: &PaginationResult{ |
| Total: 1, |
| Page: 1, |
| PageSize: 20, |
| Pages: 1, |
| }, |
| wantTotal: 1, |
| wantPage: 1, |
| wantPageSize: 20, |
| wantPages: 1, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| PaginatedWithResult(c, tt.items, tt.pagination) |
|
|
| require.Equal(t, http.StatusOK, w.Code) |
|
|
| resp, pd := parsePaginatedBody(t, w) |
| require.Equal(t, 0, resp.Code) |
| require.Equal(t, "success", resp.Message) |
| require.Equal(t, tt.wantTotal, pd.Total) |
| require.Equal(t, tt.wantPage, pd.Page) |
| require.Equal(t, tt.wantPageSize, pd.PageSize) |
| require.Equal(t, tt.wantPages, pd.Pages) |
| }) |
| } |
| } |
|
|
| func TestParsePagination(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| tests := []struct { |
| name string |
| query string |
| wantPage int |
| wantPageSize int |
| }{ |
| { |
| name: "无参数_使用默认值", |
| query: "", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "仅指定page", |
| query: "page=3", |
| wantPage: 3, |
| wantPageSize: 20, |
| }, |
| { |
| name: "仅指定page_size", |
| query: "page_size=50", |
| wantPage: 1, |
| wantPageSize: 50, |
| }, |
| { |
| name: "同时指定page和page_size", |
| query: "page=2&page_size=30", |
| wantPage: 2, |
| wantPageSize: 30, |
| }, |
| { |
| name: "使用limit代替page_size", |
| query: "limit=15", |
| wantPage: 1, |
| wantPageSize: 15, |
| }, |
| { |
| name: "page_size优先于limit", |
| query: "page_size=25&limit=50", |
| wantPage: 1, |
| wantPageSize: 25, |
| }, |
| { |
| name: "page为0_使用默认值", |
| query: "page=0", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "page_size超过1000_使用默认值", |
| query: "page_size=1001", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "page_size恰好1000_有效", |
| query: "page_size=1000", |
| wantPage: 1, |
| wantPageSize: 1000, |
| }, |
| { |
| name: "page为非数字_使用默认值", |
| query: "page=abc", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "page_size为非数字_使用默认值", |
| query: "page_size=xyz", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "limit为非数字_使用默认值", |
| query: "limit=abc", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "page_size为0_使用默认值", |
| query: "page_size=0", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "limit为0_使用默认值", |
| query: "limit=0", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "大页码", |
| query: "page=999&page_size=100", |
| wantPage: 999, |
| wantPageSize: 100, |
| }, |
| { |
| name: "page_size为1_最小有效值", |
| query: "page_size=1", |
| wantPage: 1, |
| wantPageSize: 1, |
| }, |
| { |
| name: "混合数字和字母的page", |
| query: "page=12a", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| { |
| name: "limit超过1000_使用默认值", |
| query: "limit=2000", |
| wantPage: 1, |
| wantPageSize: 20, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| _, c := newContextWithQuery(tt.query) |
|
|
| page, pageSize := ParsePagination(c) |
|
|
| require.Equal(t, tt.wantPage, page, "page 不符合预期") |
| require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") |
| }) |
| } |
| } |
|
|
| func Test_parseInt(t *testing.T) { |
| tests := []struct { |
| name string |
| input string |
| wantVal int |
| wantErr bool |
| }{ |
| { |
| name: "正常数字", |
| input: "123", |
| wantVal: 123, |
| wantErr: false, |
| }, |
| { |
| name: "零", |
| input: "0", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "单个数字", |
| input: "5", |
| wantVal: 5, |
| wantErr: false, |
| }, |
| { |
| name: "大数字", |
| input: "99999", |
| wantVal: 99999, |
| wantErr: false, |
| }, |
| { |
| name: "包含字母_返回0", |
| input: "abc", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "数字开头接字母_返回0", |
| input: "12a", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "包含负号_返回0", |
| input: "-1", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "包含小数点_返回0", |
| input: "1.5", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "包含空格_返回0", |
| input: "1 2", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| { |
| name: "空字符串", |
| input: "", |
| wantVal: 0, |
| wantErr: false, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| val, err := parseInt(tt.input) |
| if tt.wantErr { |
| require.Error(t, err) |
| } else { |
| require.NoError(t, err) |
| } |
| require.Equal(t, tt.wantVal, val) |
| }) |
| } |
| } |
|
|