Spaces:
Sleeping
Sleeping
| package api | |
| import ( | |
| "crypto/rand" | |
| "crypto/sha256" | |
| "encoding/base64" | |
| "errors" | |
| "fmt" | |
| "net/http" | |
| "net/url" | |
| "strings" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "go.uber.org/fx" | |
| "github.com/looplj/axonhub/internal/log" | |
| "github.com/looplj/axonhub/internal/pkg/xcache" | |
| "github.com/looplj/axonhub/llm/httpclient" | |
| "github.com/looplj/axonhub/llm/oauth" | |
| "github.com/looplj/axonhub/llm/transformer/openai/codex" | |
| ) | |
| type CodexHandlersParams struct { | |
| fx.In | |
| CacheConfig xcache.Config | |
| HttpClient *httpclient.HttpClient | |
| } | |
| type CodexHandlers struct { | |
| stateCache xcache.Cache[codexOAuthState] | |
| httpClient *httpclient.HttpClient | |
| } | |
| func NewCodexHandlers(params CodexHandlersParams) *CodexHandlers { | |
| return &CodexHandlers{ | |
| stateCache: xcache.NewFromConfig[codexOAuthState](params.CacheConfig), | |
| httpClient: params.HttpClient, | |
| } | |
| } | |
| type StartCodexOAuthRequest struct{} | |
| type StartCodexOAuthResponse struct { | |
| SessionID string `json:"session_id"` | |
| AuthURL string `json:"auth_url"` | |
| } | |
| type codexOAuthState struct { | |
| CodeVerifier string `json:"code_verifier"` | |
| CreatedAt int64 `json:"created_at"` | |
| } | |
| func generateCodexCodeVerifier() (string, error) { | |
| b := make([]byte, 64) | |
| if _, err := rand.Read(b); err != nil { | |
| return "", err | |
| } | |
| return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil | |
| } | |
| func generateCodexCodeChallenge(verifier string) string { | |
| hash := sha256.Sum256([]byte(verifier)) | |
| return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:]) | |
| } | |
| func generateCodexState() (string, error) { | |
| b := make([]byte, 32) | |
| if _, err := rand.Read(b); err != nil { | |
| return "", err | |
| } | |
| return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil | |
| } | |
| func codexOAuthCacheKey(sessionID string) string { | |
| return fmt.Sprintf("codex:oauth:%s", sessionID) | |
| } | |
| // StartOAuth creates a PKCE session and returns the authorize URL. | |
| // POST /admin/codex/oauth/start. | |
| func (h *CodexHandlers) StartOAuth(c *gin.Context) { | |
| ctx := c.Request.Context() | |
| var req StartCodexOAuthRequest | |
| if err := c.ShouldBindJSON(&req); err != nil { | |
| JSONError(c, http.StatusBadRequest, errors.New("invalid request format")) | |
| return | |
| } | |
| state, err := generateCodexState() | |
| if err != nil { | |
| JSONError(c, http.StatusInternalServerError, fmt.Errorf("failed to generate oauth state: %w", err)) | |
| return | |
| } | |
| codeVerifier, err := generateCodexCodeVerifier() | |
| if err != nil { | |
| JSONError(c, http.StatusInternalServerError, fmt.Errorf("failed to generate code verifier: %w", err)) | |
| return | |
| } | |
| codeChallenge := generateCodexCodeChallenge(codeVerifier) | |
| cacheKey := codexOAuthCacheKey(state) | |
| if err := h.stateCache.Set(ctx, cacheKey, codexOAuthState{CodeVerifier: codeVerifier, CreatedAt: time.Now().Unix()}, xcache.WithExpiration(10*time.Minute)); err != nil { | |
| JSONError(c, http.StatusInternalServerError, fmt.Errorf("failed to save oauth state: %w", err)) | |
| return | |
| } | |
| params := url.Values{} | |
| params.Set("response_type", "code") | |
| params.Set("client_id", codex.ClientID) | |
| params.Set("redirect_uri", codex.RedirectURI) | |
| params.Set("scope", codex.Scopes) | |
| params.Set("code_challenge", codeChallenge) | |
| params.Set("code_challenge_method", "S256") | |
| params.Set("state", state) | |
| params.Set("id_token_add_organizations", "true") | |
| params.Set("codex_cli_simplified_flow", "true") | |
| authURL := fmt.Sprintf("%s?%s", codex.AuthorizeURL, params.Encode()) | |
| c.JSON(http.StatusOK, StartCodexOAuthResponse{SessionID: state, AuthURL: authURL}) | |
| } | |
| type ExchangeCodexOAuthRequest struct { | |
| SessionID string `json:"session_id" binding:"required"` | |
| CallbackURL string `json:"callback_url" binding:"required"` | |
| } | |
| type ExchangeCodexOAuthResponse struct { | |
| Credentials string `json:"credentials"` | |
| } | |
| func parseCodexCallbackURL(callbackURL string) (string, string, error) { | |
| trimmed := strings.TrimSpace(callbackURL) | |
| if !strings.HasPrefix(trimmed, "http://") && !strings.HasPrefix(trimmed, "https://") { | |
| return "", "", fmt.Errorf("callback_url must be a full URL") | |
| } | |
| u, err := url.Parse(trimmed) | |
| if err != nil { | |
| return "", "", fmt.Errorf("invalid callback_url: %w", err) | |
| } | |
| q := u.Query() | |
| code := q.Get("code") | |
| if code == "" { | |
| return "", "", fmt.Errorf("code parameter not found in callback_url") | |
| } | |
| state := q.Get("state") | |
| if state == "" { | |
| return "", "", fmt.Errorf("state parameter not found in callback_url") | |
| } | |
| return code, state, nil | |
| } | |
| // Exchange exchanges callback URL for OAuth credentials JSON. | |
| // POST /admin/codex/oauth/exchange. | |
| func (h *CodexHandlers) Exchange(c *gin.Context) { | |
| ctx := c.Request.Context() | |
| var req ExchangeCodexOAuthRequest | |
| if err := c.ShouldBindJSON(&req); err != nil { | |
| JSONError(c, http.StatusBadRequest, errors.New("invalid request format")) | |
| return | |
| } | |
| if req.SessionID == "" || req.CallbackURL == "" { | |
| JSONError(c, http.StatusBadRequest, errors.New("session_id and callback_url are required")) | |
| return | |
| } | |
| cacheKey := codexOAuthCacheKey(req.SessionID) | |
| state, err := h.stateCache.Get(ctx, cacheKey) | |
| if err != nil { | |
| JSONError(c, http.StatusBadRequest, errors.New("invalid or expired oauth session")) | |
| return | |
| } | |
| if err := h.stateCache.Delete(ctx, cacheKey); err != nil { | |
| log.Warn(ctx, "failed to delete used oauth state from cache", log.String("session_id", req.SessionID), log.Cause(err)) | |
| } | |
| code, callbackState, err := parseCodexCallbackURL(req.CallbackURL) | |
| if err != nil { | |
| JSONError(c, http.StatusBadRequest, err) | |
| return | |
| } | |
| if callbackState != req.SessionID { | |
| JSONError(c, http.StatusBadRequest, errors.New("oauth state mismatch")) | |
| return | |
| } | |
| tokenProvider := codex.NewTokenProvider(codex.TokenProviderParams{ | |
| HTTPClient: h.httpClient, | |
| }) | |
| creds, err := tokenProvider.Exchange(ctx, oauth.ExchangeParams{ | |
| Code: code, | |
| CodeVerifier: state.CodeVerifier, | |
| ClientID: codex.ClientID, | |
| RedirectURI: codex.RedirectURI, | |
| }) | |
| if err != nil { | |
| JSONError(c, http.StatusBadGateway, fmt.Errorf("token exchange failed: %w", err)) | |
| return | |
| } | |
| output, err := creds.ToJSON() | |
| if err != nil { | |
| JSONError(c, http.StatusInternalServerError, fmt.Errorf("failed to encode credentials: %w", err)) | |
| return | |
| } | |
| c.JSON(http.StatusOK, ExchangeCodexOAuthResponse{Credentials: output}) | |
| } | |