| package controller | |
| import ( | |
| "context" | |
| "errors" | |
| "fmt" | |
| "net/http" | |
| "net/url" | |
| "strconv" | |
| "strings" | |
| "time" | |
| "github.com/QuantumNous/new-api/common" | |
| "github.com/QuantumNous/new-api/constant" | |
| "github.com/QuantumNous/new-api/model" | |
| "github.com/QuantumNous/new-api/relay/channel/codex" | |
| "github.com/QuantumNous/new-api/service" | |
| "github.com/gin-contrib/sessions" | |
| "github.com/gin-gonic/gin" | |
| ) | |
// codexOAuthCompleteRequest is the JSON body accepted by the handlers that
// complete a Codex OAuth flow.
type codexOAuthCompleteRequest struct {
	// Input is the raw text submitted by the caller; it is handed to
	// parseCodexAuthorizationInput to obtain the authorization code and state.
	Input string `json:"input"`
}
// codexOAuthSessionKey builds the session key under which one piece of Codex
// OAuth flow data (identified by field) is stored for the given channel ID.
func codexOAuthSessionKey(channelID int, field string) string {
	const layout = "codex_oauth_%s_%d"
	return fmt.Sprintf(layout, field, channelID)
}
// parseCodexAuthorizationInput extracts the OAuth authorization code and state
// from text supplied by the caller.
//
// Accepted forms, tried in this order:
//   - anything containing "code=": a callback URL or bare query string whose
//     "code" (and optionally "state") parameter sits in the query, in the
//     fragment, or in the text itself ("code=...&state=...", with or without a
//     leading "?");
//   - "code#state";
//   - a bare authorization code, returned with an empty state.
//
// Returned values are whitespace-trimmed. The only error is for input that is
// empty after trimming. If the input contains "code=" but no non-empty "code"
// parameter can be found, it falls through to the "code#state" and bare-code
// interpretations.
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
	v := strings.TrimSpace(input)
	if v == "" {
		return "", "", errors.New("empty input")
	}

	// URL-like input is examined before the "code#state" shorthand so that a
	// callback URL carrying a fragment is not split at its '#'.
	if strings.Contains(v, "code=") {
		if c, s, ok := codexOAuthParamsFromCallback(v); ok {
			return c, s, nil
		}
	}

	if i := strings.Index(v, "#"); i >= 0 {
		return strings.TrimSpace(v[:i]), strings.TrimSpace(v[i+1:]), nil
	}

	return v, "", nil
}

// codexOAuthParamsFromCallback looks for a non-empty "code" parameter in the
// URL query of raw, then in its fragment, then in raw treated as a plain query
// string. ok is false when none of these yields a code.
func codexOAuthParamsFromCallback(raw string) (code string, state string, ok bool) {
	var candidates []url.Values

	if u, err := url.Parse(raw); err == nil {
		// Query silently discards malformed pairs, so a partially valid query
		// can still provide the code.
		candidates = append(candidates, u.Query())
		if frag, err := url.ParseQuery(u.EscapedFragment()); err == nil {
			candidates = append(candidates, frag)
		}
	}

	// url.Parse accepts "code=...&state=..." as a path with an empty query, so
	// the raw text (minus any fragment and leading '?') is also tried directly.
	bare := raw
	if i := strings.Index(bare, "#"); i >= 0 {
		bare = bare[:i]
	}
	if q, err := url.ParseQuery(strings.TrimPrefix(bare, "?")); err == nil {
		candidates = append(candidates, q)
	}

	for _, q := range candidates {
		if c := strings.TrimSpace(q.Get("code")); c != "" {
			return c, strings.TrimSpace(q.Get("state")), true
		}
	}
	return "", "", false
}
| func StartCodexOAuth(c *gin.Context) { | |
| startCodexOAuthWithChannelID(c, 0) | |
| } | |
| func StartCodexOAuthForChannel(c *gin.Context) { | |
| channelID, err := strconv.Atoi(c.Param("id")) | |
| if err != nil { | |
| common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) | |
| return | |
| } | |
| startCodexOAuthWithChannelID(c, channelID) | |
| } | |
| func startCodexOAuthWithChannelID(c *gin.Context, channelID int) { | |
| if channelID > 0 { | |
| ch, err := model.GetChannelById(channelID, false) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if ch == nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"}) | |
| return | |
| } | |
| if ch.Type != constant.ChannelTypeCodex { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) | |
| return | |
| } | |
| } | |
| flow, err := service.CreateCodexOAuthAuthorizationFlow() | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| session := sessions.Default(c) | |
| session.Set(codexOAuthSessionKey(channelID, "state"), flow.State) | |
| session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier) | |
| session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix()) | |
| _ = session.Save() | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "", | |
| "data": gin.H{ | |
| "authorize_url": flow.AuthorizeURL, | |
| }, | |
| }) | |
| } | |
| func CompleteCodexOAuth(c *gin.Context) { | |
| completeCodexOAuthWithChannelID(c, 0) | |
| } | |
| func CompleteCodexOAuthForChannel(c *gin.Context) { | |
| channelID, err := strconv.Atoi(c.Param("id")) | |
| if err != nil { | |
| common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) | |
| return | |
| } | |
| completeCodexOAuthWithChannelID(c, channelID) | |
| } | |
| func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { | |
| req := codexOAuthCompleteRequest{} | |
| if err := c.ShouldBindJSON(&req); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| code, state, err := parseCodexAuthorizationInput(req.Input) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) | |
| return | |
| } | |
| if strings.TrimSpace(code) == "" { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"}) | |
| return | |
| } | |
| if strings.TrimSpace(state) == "" { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"}) | |
| return | |
| } | |
| if channelID > 0 { | |
| ch, err := model.GetChannelById(channelID, false) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if ch == nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"}) | |
| return | |
| } | |
| if ch.Type != constant.ChannelTypeCodex { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) | |
| return | |
| } | |
| } | |
| session := sessions.Default(c) | |
| expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string) | |
| verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string) | |
| if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"}) | |
| return | |
| } | |
| if state != expectedState { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"}) | |
| return | |
| } | |
| ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) | |
| defer cancel() | |
| tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier) | |
| if err != nil { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) | |
| return | |
| } | |
| accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken) | |
| if !ok { | |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"}) | |
| return | |
| } | |
| email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken) | |
| key := codex.OAuthKey{ | |
| AccessToken: tokenRes.AccessToken, | |
| RefreshToken: tokenRes.RefreshToken, | |
| AccountID: accountID, | |
| LastRefresh: time.Now().Format(time.RFC3339), | |
| Expired: tokenRes.ExpiresAt.Format(time.RFC3339), | |
| Email: email, | |
| Type: "codex", | |
| } | |
| encoded, err := common.Marshal(key) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| session.Delete(codexOAuthSessionKey(channelID, "state")) | |
| session.Delete(codexOAuthSessionKey(channelID, "verifier")) | |
| session.Delete(codexOAuthSessionKey(channelID, "created_at")) | |
| _ = session.Save() | |
| if channelID > 0 { | |
| if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| model.InitChannelCache() | |
| service.ResetProxyClientCache() | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "saved", | |
| "data": gin.H{ | |
| "channel_id": channelID, | |
| "account_id": accountID, | |
| "email": email, | |
| "expires_at": key.Expired, | |
| "last_refresh": key.LastRefresh, | |
| }, | |
| }) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{ | |
| "success": true, | |
| "message": "generated", | |
| "data": gin.H{ | |
| "key": string(encoded), | |
| "account_id": accountID, | |
| "email": email, | |
| "expires_at": key.Expired, | |
| "last_refresh": key.LastRefresh, | |
| }, | |
| }) | |
| } | |