package service

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/model"
)
// CodexCredentialRefreshOptions tunes the behaviour of
// RefreshCodexChannelCredential.
type CodexCredentialRefreshOptions struct {
	// ResetCaches, when true, makes a successful refresh call
	// model.InitChannelCache and ResetProxyClientCache after the new key has
	// been written to the database.
	ResetCaches bool
}
// CodexOAuthKey is the JSON document stored in a Codex channel's key column.
//
// Every field is optional in the encoded form (omitempty). Field order is
// significant for the encoded output and must not be rearranged casually.
type CodexOAuthKey struct {
	IDToken     string `json:"id_token,omitempty"`
	AccessToken string `json:"access_token,omitempty"`
	// RefreshToken is required by RefreshCodexChannelCredential; a blank value
	// makes the refresh fail before any request is sent.
	RefreshToken string `json:"refresh_token,omitempty"`
	// AccountID is backfilled from the access token JWT on refresh when blank.
	AccountID string `json:"account_id,omitempty"`
	// LastRefresh is an RFC 3339 timestamp set on each successful refresh.
	LastRefresh string `json:"last_refresh,omitempty"`
	// Email is backfilled from the access token JWT on refresh when blank.
	Email string `json:"email,omitempty"`
	// Type is set to "codex" on refresh when blank.
	Type string `json:"type,omitempty"`
	// Expired is the RFC 3339 form of the expiry reported by the last refresh.
	Expired string `json:"expired,omitempty"`
}
| func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) { | |
| if strings.TrimSpace(raw) == "" { | |
| return nil, errors.New("codex channel: empty oauth key") | |
| } | |
| var key CodexOAuthKey | |
| if err := common.Unmarshal([]byte(raw), &key); err != nil { | |
| return nil, errors.New("codex channel: invalid oauth key json") | |
| } | |
| return &key, nil | |
| } | |
| func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) { | |
| ch, err := model.GetChannelById(channelID, true) | |
| if err != nil { | |
| return nil, nil, err | |
| } | |
| if ch == nil { | |
| return nil, nil, fmt.Errorf("channel not found") | |
| } | |
| if ch.Type != constant.ChannelTypeCodex { | |
| return nil, nil, fmt.Errorf("channel type is not Codex") | |
| } | |
| oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key)) | |
| if err != nil { | |
| return nil, nil, err | |
| } | |
| if strings.TrimSpace(oauthKey.RefreshToken) == "" { | |
| return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential") | |
| } | |
| refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second) | |
| defer cancel() | |
| res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken) | |
| if err != nil { | |
| return nil, nil, err | |
| } | |
| oauthKey.AccessToken = res.AccessToken | |
| oauthKey.RefreshToken = res.RefreshToken | |
| oauthKey.LastRefresh = time.Now().Format(time.RFC3339) | |
| oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339) | |
| if strings.TrimSpace(oauthKey.Type) == "" { | |
| oauthKey.Type = "codex" | |
| } | |
| if strings.TrimSpace(oauthKey.AccountID) == "" { | |
| if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok { | |
| oauthKey.AccountID = accountID | |
| } | |
| } | |
| if strings.TrimSpace(oauthKey.Email) == "" { | |
| if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok { | |
| oauthKey.Email = email | |
| } | |
| } | |
| encoded, err := common.Marshal(oauthKey) | |
| if err != nil { | |
| return nil, nil, err | |
| } | |
| if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil { | |
| return nil, nil, err | |
| } | |
| if opts.ResetCaches { | |
| model.InitChannelCache() | |
| ResetProxyClientCache() | |
| } | |
| return oauthKey, ch, nil | |
| } | |