// File: axonhub/internal/server/api/codex_test.go
// Uploaded by llzai ("Upload 1793 files"), revision 9853396 (verified).

package api
import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"

	"github.com/looplj/axonhub/internal/contexts"
	"github.com/looplj/axonhub/internal/pkg/xcache"
	"github.com/looplj/axonhub/llm/httpclient"
	"github.com/looplj/axonhub/llm/transformer/openai/codex"
)
// roundTripperFunc lets a plain function act as an http.RoundTripper, so a
// test can install it as an http.Client's Transport and intercept every
// outbound request.
type roundTripperFunc func(req *http.Request) (*http.Response, error)

// Compile-time check that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by invoking the wrapped function with
// the request unchanged and returning whatever it returns.
func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
// TestCodexHandlers_StartOAuth_InvalidJSON verifies that StartOAuth answers a
// malformed JSON body with HTTP 400 and an "invalid request format" message.
func TestCodexHandlers_StartOAuth_InvalidJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)

	handlers := NewCodexHandlers(CodexHandlersParams{
		CacheConfig: xcache.Config{Mode: xcache.ModeMemory},
		HttpClient:  httpclient.NewHttpClient(),
	})

	// Inject a fixed project ID into the request context before the handler runs.
	withProject := func(c *gin.Context) {
		ctx := contexts.WithProjectID(c.Request.Context(), 123)
		c.Request = c.Request.WithContext(ctx)
		c.Next()
	}

	engine := gin.New()
	engine.Use(withProject)
	engine.POST("/admin/codex/oauth/start", handlers.StartOAuth)

	// "{" is truncated JSON, so the handler should reject the body.
	payload := bytes.NewBufferString("{")
	request := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", payload)
	request.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)

	require.Equal(t, http.StatusBadRequest, recorder.Code)
	require.Contains(t, recorder.Body.String(), "invalid request format")
}
// TestCodexHandlers_Exchange_StateDeletedOnTokenExchangeFailure checks that when
// the upstream token endpoint fails, Exchange responds with 502, hits the token
// endpoint exactly once, and discards the OAuth session so that retrying with the
// same session ID is rejected as invalid or expired.
func TestCodexHandlers_Exchange_StateDeletedOnTokenExchangeFailure(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// The stub handler runs on the test server's goroutine while the assertion
	// below reads the count from the test goroutine, so access it atomically to
	// keep the test race-free under `go test -race`.
	var tokenCalls int32

	tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&tokenCalls, 1)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadGateway)
		_, _ = w.Write([]byte(`{"error":"bad_gateway"}`))
	}))
	t.Cleanup(tokenServer.Close)

	// Redirect requests for the real Codex token URL to the local stub. The body
	// is buffered (as in the sibling tests) so the forwarded request carries a
	// known Content-Length instead of being streamed with an unknown length.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if req.URL.String() == codex.TokenURL {
			body, err := io.ReadAll(req.Body)
			_ = req.Body.Close()
			if err != nil {
				return nil, err
			}
			proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, bytes.NewBuffer(body))
			if err != nil {
				return nil, err
			}
			proxyReq.Header = req.Header.Clone()
			return http.DefaultTransport.RoundTrip(proxyReq)
		}
		return http.DefaultTransport.RoundTrip(req)
	})

	hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport})
	h := NewCodexHandlers(CodexHandlersParams{
		CacheConfig: xcache.Config{Mode: xcache.ModeMemory},
		HttpClient:  hc,
	})

	router := gin.New()
	router.Use(func(c *gin.Context) {
		c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123))
		c.Next()
	})
	router.POST("/admin/codex/oauth/start", h.StartOAuth)
	router.POST("/admin/codex/oauth/exchange", h.Exchange)

	// Begin an OAuth flow to obtain a session ID.
	startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}"))
	startReq.Header.Set("Content-Type", "application/json")
	startW := httptest.NewRecorder()
	router.ServeHTTP(startW, startReq)
	require.Equal(t, http.StatusOK, startW.Code)

	var startResp StartCodexOAuthResponse
	require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp))
	require.NotEmpty(t, startResp.SessionID)

	// The callback's state matches the session ID, so the handler proceeds to
	// the (failing) token exchange.
	exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{
		SessionID:   startResp.SessionID,
		CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID,
	})
	require.NoError(t, err)

	exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody))
	exchangeReq.Header.Set("Content-Type", "application/json")
	exchangeW := httptest.NewRecorder()
	router.ServeHTTP(exchangeW, exchangeReq)
	require.Equal(t, http.StatusBadGateway, exchangeW.Code)
	require.Equal(t, int32(1), atomic.LoadInt32(&tokenCalls))

	// Retrying with the same session must fail: the session was discarded.
	exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody))
	exchangeReq2.Header.Set("Content-Type", "application/json")
	exchangeW2 := httptest.NewRecorder()
	router.ServeHTTP(exchangeW2, exchangeReq2)
	require.Equal(t, http.StatusBadRequest, exchangeW2.Code)
	require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session")
}
// TestCodexHandlers_Exchange_RejectsStateMismatch checks that Exchange rejects a
// callback whose state parameter does not match the session ID with 400
// "oauth state mismatch", and that the session cannot be reused afterwards, even
// with a callback carrying the correct state.
func TestCodexHandlers_Exchange_RejectsStateMismatch(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Stub token endpoint that always returns a successful token payload, so the
	// rejections asserted below are not caused by an upstream failure.
	tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"bearer"}`))
	}))
	t.Cleanup(tokenServer.Close)

	// Redirect requests for the real Codex token URL to the stub. A read failure
	// on the outgoing body is surfaced instead of being silently forwarded as an
	// empty request.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if req.URL.String() == codex.TokenURL {
			body, err := io.ReadAll(req.Body)
			_ = req.Body.Close()
			if err != nil {
				return nil, err
			}
			proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, bytes.NewBuffer(body))
			if err != nil {
				return nil, err
			}
			proxyReq.Header = req.Header.Clone()
			return http.DefaultTransport.RoundTrip(proxyReq)
		}
		return http.DefaultTransport.RoundTrip(req)
	})

	hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport})
	h := NewCodexHandlers(CodexHandlersParams{
		CacheConfig: xcache.Config{Mode: xcache.ModeMemory},
		HttpClient:  hc,
	})

	router := gin.New()
	router.Use(func(c *gin.Context) {
		c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123))
		c.Next()
	})
	router.POST("/admin/codex/oauth/start", h.StartOAuth)
	router.POST("/admin/codex/oauth/exchange", h.Exchange)

	// Begin an OAuth flow to obtain a session ID.
	startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}"))
	startReq.Header.Set("Content-Type", "application/json")
	startW := httptest.NewRecorder()
	router.ServeHTTP(startW, startReq)
	require.Equal(t, http.StatusOK, startW.Code)

	var startResp StartCodexOAuthResponse
	require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp))
	require.NotEmpty(t, startResp.SessionID)

	// First attempt: the callback's state deliberately differs from the session ID.
	exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{
		SessionID:   startResp.SessionID,
		CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=mismatch",
	})
	require.NoError(t, err)

	exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody))
	exchangeReq.Header.Set("Content-Type", "application/json")
	exchangeW := httptest.NewRecorder()
	router.ServeHTTP(exchangeW, exchangeReq)
	require.Equal(t, http.StatusBadRequest, exchangeW.Code)
	require.Contains(t, exchangeW.Body.String(), "oauth state mismatch")

	// Second attempt with the correct state must still fail: the mismatch
	// consumed the session.
	exchangeBody2, err := json.Marshal(ExchangeCodexOAuthRequest{
		SessionID:   startResp.SessionID,
		CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID,
	})
	require.NoError(t, err)

	exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody2))
	exchangeReq2.Header.Set("Content-Type", "application/json")
	exchangeW2 := httptest.NewRecorder()
	router.ServeHTTP(exchangeW2, exchangeReq2)
	require.Equal(t, http.StatusBadRequest, exchangeW2.Code)
	require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session")
}
// TestCodexHandlers_Exchange_DeletesStateOnSuccess checks that a successful
// Exchange (200) consumes the OAuth session, so replaying the identical exchange
// request is rejected with 400 "invalid or expired oauth session".
func TestCodexHandlers_Exchange_DeletesStateOnSuccess(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Stub token endpoint that always returns a successful token payload.
	tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"bearer"}`))
	}))
	t.Cleanup(tokenServer.Close)

	// Redirect requests for the real Codex token URL to the stub. A read failure
	// on the outgoing body is surfaced instead of being silently forwarded as an
	// empty request.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if req.URL.String() == codex.TokenURL {
			body, err := io.ReadAll(req.Body)
			_ = req.Body.Close()
			if err != nil {
				return nil, err
			}
			proxyReq, err := http.NewRequestWithContext(req.Context(), req.Method, tokenServer.URL, bytes.NewBuffer(body))
			if err != nil {
				return nil, err
			}
			proxyReq.Header = req.Header.Clone()
			return http.DefaultTransport.RoundTrip(proxyReq)
		}
		return http.DefaultTransport.RoundTrip(req)
	})

	hc := httpclient.NewHttpClientWithClient(&http.Client{Transport: transport})
	h := NewCodexHandlers(CodexHandlersParams{
		CacheConfig: xcache.Config{Mode: xcache.ModeMemory},
		HttpClient:  hc,
	})

	router := gin.New()
	router.Use(func(c *gin.Context) {
		c.Request = c.Request.WithContext(contexts.WithProjectID(c.Request.Context(), 123))
		c.Next()
	})
	router.POST("/admin/codex/oauth/start", h.StartOAuth)
	router.POST("/admin/codex/oauth/exchange", h.Exchange)

	// Begin an OAuth flow to obtain a session ID.
	startReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/start", bytes.NewBufferString("{}"))
	startReq.Header.Set("Content-Type", "application/json")
	startW := httptest.NewRecorder()
	router.ServeHTTP(startW, startReq)
	require.Equal(t, http.StatusOK, startW.Code)

	var startResp StartCodexOAuthResponse
	require.NoError(t, json.Unmarshal(startW.Body.Bytes(), &startResp))
	// Guard against an empty session ID, which would make the replay check
	// below pass for the wrong reason.
	require.NotEmpty(t, startResp.SessionID)

	exchangeBody, err := json.Marshal(ExchangeCodexOAuthRequest{
		SessionID:   startResp.SessionID,
		CallbackURL: "http://localhost:1455/auth/callback?code=test-code&state=" + startResp.SessionID,
	})
	require.NoError(t, err)

	exchangeReq := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody))
	exchangeReq.Header.Set("Content-Type", "application/json")
	exchangeW := httptest.NewRecorder()
	router.ServeHTTP(exchangeW, exchangeReq)
	require.Equal(t, http.StatusOK, exchangeW.Code)

	// Replaying the same exchange must fail because the session was consumed.
	exchangeReq2 := httptest.NewRequest(http.MethodPost, "/admin/codex/oauth/exchange", bytes.NewBuffer(exchangeBody))
	exchangeReq2.Header.Set("Content-Type", "application/json")
	exchangeW2 := httptest.NewRecorder()
	router.ServeHTTP(exchangeW2, exchangeReq2)
	require.Equal(t, http.StatusBadRequest, exchangeW2.Code)
	require.Contains(t, exchangeW2.Body.String(), "invalid or expired oauth session")
}