| package api | |
| import ( | |
| "bytes" | |
| "encoding/json" | |
| "io" | |
| "net/http" | |
| "net/http/httptest" | |
| "testing" | |
| "github.com/gin-gonic/gin" | |
| "github.com/stretchr/testify/require" | |
| "github.com/looplj/axonhub/internal/contexts" | |
| "github.com/looplj/axonhub/internal/pkg/xcache" | |
| "github.com/looplj/axonhub/llm/httpclient" | |
| "github.com/looplj/axonhub/llm/transformer/openai/codex" | |
| ) | |
| type roundTripperFunc func(req *http.Request) (*http.Response, error) | |
| func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { | |
| return f(req) | |
| } | |
| func TestCodexHandlers_StartOAuth_InvalidJSON(t *testing.T) { | |
| gin.SetMode(gin.TestMode) | |
| h := NewCodexHandlers(CodexHandlersParams{ | |
| CacheConfig: xcache.Config{Mode: xcache.ModeMemory}, | |
| HttpClient: httpclient.NewHttpClient(), | |
| }) | |
| router := gin.New() | |
| router.Use(func(c *gin.Context) { | |
| c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123)) | |
| c.Next() | |
| }) | |
| router.POST("/admin/codex/oauth/start", h.StartOAuth) | |
| req := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{")) | |
| req.Header.Set("Content-Type", "application/json") | |
| w := httptest.NewRecorder() | |
| router.ServeHTTP(w, req) | |
| require.Equal(t, http.StatusBadRequest, w.Code) | |
| require.Contains(t, w.Body.String(), "invalid request format") | |
| } | |
| func TestCodexHandlers_Exchange_StateDeletedOnTokenExchangeFailure(t *testing.T) { | |
| gin.SetMode(gin.TestMode) | |
| var tokenCalls int | |
| tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| tokenCalls++ | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusBadGateway) | |
| _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) | |
| })) | |
| t.Cleanup(tokenServer.Close) | |
| transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { | |
| if req.URL.String() == codex.TokenURL { | |
| proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, req.Body) | |
| if err != nil { | |
| return nil, err | |
| } | |
| proxyReq.Header = req.Header.Clone() | |
| return http.DefaultTransport.RoundTrip(proxyReq) | |
| } | |
| return http.DefaultTransport.RoundTrip(req) | |
| }) | |
| hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport}) | |
| h := NewCodexHandlers(CodexHandlersParams{ | |
| CacheConfig: xcache.Config{Mode: xcache.ModeMemory}, | |
| HttpClient: hc, | |
| }) | |
| router := gin.New() | |
| router.Use(func(c *gin.Context) { | |
| c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123)) | |
| c.Next() | |
| }) | |
| router.POST("/admin/codex/oauth/start", h.StartOAuth) | |
| router.POST("/admin/codex/oauth/exchange", h.Exchange) | |
| startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}")) | |
| startReq.Header.Set("Content-Type", "application/json") | |
| startW := httptest.NewRecorder() | |
| router.ServeHTTP(startW, startReq) | |
| require.Equal(t, http.StatusOK, startW.Code) | |
| var startResp StartCodexOAuthResponse | |
| require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp)) | |
| require.NotEmpty(t, startResp.SessionID) | |
| exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{ | |
| SessionID: startResp.SessionID, | |
| CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID, | |
| }) | |
| require.NoError(t, err) | |
| exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody)) | |
| exchangeReq.Header.Set("Content-Type", "application/json") | |
| exchangeW := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW, exchangeReq) | |
| require.Equal(t, http.StatusBadGateway, exchangeW.Code) | |
| require.Equal(t, 1, tokenCalls) | |
| exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody)) | |
| exchangeReq2.Header.Set("Content-Type", "application/json") | |
| exchangeW2 := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW2, exchangeReq2) | |
| require.Equal(t, http.StatusBadRequest, exchangeW2.Code) | |
| require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session") | |
| } | |
| func TestCodexHandlers_Exchange_RejectsStateMismatch(t *testing.T) { | |
| gin.SetMode(gin.TestMode) | |
| tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| _, _ = w.Write([]byte(`{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"bearer"}`)) | |
| })) | |
| t.Cleanup(tokenServer.Close) | |
| transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { | |
| if req.URL.String() == codex.TokenURL { | |
| body, _ := io.ReadAll(req.Body) | |
| _ = req.Body.Close() | |
| proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, bytes.NewBuffer(body)) | |
| if err != nil { | |
| return nil, err | |
| } | |
| proxyReq.Header = req.Header.Clone() | |
| return http.DefaultTransport.RoundTrip(proxyReq) | |
| } | |
| return http.DefaultTransport.RoundTrip(req) | |
| }) | |
| hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport}) | |
| h := NewCodexHandlers(CodexHandlersParams{ | |
| CacheConfig: xcache.Config{Mode: xcache.ModeMemory}, | |
| HttpClient: hc, | |
| }) | |
| router := gin.New() | |
| router.Use(func(c *gin.Context) { | |
| c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123)) | |
| c.Next() | |
| }) | |
| router.POST("/admin/codex/oauth/start", h.StartOAuth) | |
| router.POST("/admin/codex/oauth/exchange", h.Exchange) | |
| startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}")) | |
| startReq.Header.Set("Content-Type", "application/json") | |
| startW := httptest.NewRecorder() | |
| router.ServeHTTP(startW, startReq) | |
| require.Equal(t, http.StatusOK, startW.Code) | |
| var startResp StartCodexOAuthResponse | |
| require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp)) | |
| require.NotEmpty(t, startResp.SessionID) | |
| exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{ | |
| SessionID: startResp.SessionID, | |
| CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=mismatch", | |
| }) | |
| require.NoError(t, err) | |
| exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody)) | |
| exchangeReq.Header.Set("Content-Type", "application/json") | |
| exchangeW := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW, exchangeReq) | |
| require.Equal(t, http.StatusBadRequest, exchangeW.Code) | |
| require.Contains(t, exchangeW.Body.String(), "oauth state mismatch") | |
| exchangeBody2, err := json.Marshal(ExchangeCodexOAuthRequest{ | |
| SessionID: startResp.SessionID, | |
| CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID, | |
| }) | |
| require.NoError(t, err) | |
| exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody2)) | |
| exchangeReq2.Header.Set("Content-Type", "application/json") | |
| exchangeW2 := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW2, exchangeReq2) | |
| require.Equal(t, http.StatusBadRequest, exchangeW2.Code) | |
| require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session") | |
| } | |
| func TestCodexHandlers_Exchange_DeletesStateOnSuccess(t *testing.T) { | |
| gin.SetMode(gin.TestMode) | |
| tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| _, _ = w.Write([]byte(`{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"bearer"}`)) | |
| })) | |
| t.Cleanup(tokenServer.Close) | |
| transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { | |
| if req.URL.String() == codex.TokenURL { | |
| body, _ := io.ReadAll(req.Body) | |
| _ = req.Body.Close() | |
| proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, bytes.NewBuffer(body)) | |
| if err != nil { | |
| return nil, err | |
| } | |
| proxyReq.Header = req.Header.Clone() | |
| return http.DefaultTransport.RoundTrip(proxyReq) | |
| } | |
| return http.DefaultTransport.RoundTrip(req) | |
| }) | |
| hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport}) | |
| h := NewCodexHandlers(CodexHandlersParams{ | |
| CacheConfig: xcache.Config{Mode: xcache.ModeMemory}, | |
| HttpClient: hc, | |
| }) | |
| router := gin.New() | |
| router.Use(func(c *gin.Context) { | |
| c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123)) | |
| c.Next() | |
| }) | |
| router.POST("/admin/codex/oauth/start", h.StartOAuth) | |
| router.POST("/admin/codex/oauth/exchange", h.Exchange) | |
| startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}")) | |
| startReq.Header.Set("Content-Type", "application/json") | |
| startW := httptest.NewRecorder() | |
| router.ServeHTTP(startW, startReq) | |
| require.Equal(t, http.StatusOK, startW.Code) | |
| var startResp StartCodexOAuthResponse | |
| require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp)) | |
| exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{ | |
| SessionID: startResp.SessionID, | |
| CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID, | |
| }) | |
| require.NoError(t, err) | |
| exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody)) | |
| exchangeReq.Header.Set("Content-Type", "application/json") | |
| exchangeW := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW, exchangeReq) | |
| require.Equal(t, http.StatusOK, exchangeW.Code) | |
| exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody)) | |
| exchangeReq2.Header.Set("Content-Type", "application/json") | |
| exchangeW2 := httptest.NewRecorder() | |
| router.ServeHTTP(exchangeW2, exchangeReq2) | |
| require.Equal(t, http.StatusBadRequest, exchangeW2.Code) | |
| require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session") | |
| } | |