newtest / service /codex_credential_refresh.go
xwwww's picture
Upload 888 files
305487b verified
package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
)
// CodexCredentialRefreshOptions carries the optional switches accepted by
// RefreshCodexChannelCredential.
type CodexCredentialRefreshOptions struct {
// ResetCaches, when true, makes a successful refresh call
// model.InitChannelCache and ResetProxyClientCache once the new key is stored.
ResetCaches bool
}
// CodexOAuthKey is the JSON document kept in a Codex channel's key field.
// Every field is tagged omitempty, so empty values are left out when the
// credential is serialized again.
type CodexOAuthKey struct {
IDToken string `json:"id_token,omitempty"`
// AccessToken is replaced on every refresh by RefreshCodexChannelCredential.
AccessToken string `json:"access_token,omitempty"`
// RefreshToken must be non-blank for RefreshCodexChannelCredential to run.
RefreshToken string `json:"refresh_token,omitempty"`
// AccountID is filled from the access token JWT during a refresh when blank.
AccountID string `json:"account_id,omitempty"`
// LastRefresh is written as an RFC 3339 timestamp at refresh time.
LastRefresh string `json:"last_refresh,omitempty"`
// Email is filled from the access token JWT during a refresh when blank.
Email string `json:"email,omitempty"`
// Type is set to "codex" during a refresh when blank.
Type string `json:"type,omitempty"`
// Expired is written as an RFC 3339 timestamp taken from the refresh
// result's ExpiresAt value.
Expired string `json:"expired,omitempty"`
}
// parseCodexOAuthKey decodes the JSON credential stored on a Codex channel.
//
// It returns an error when raw is empty or whitespace-only, or when it cannot
// be decoded into a CodexOAuthKey. The decoder's own error is intentionally not
// surfaced; callers only ever see the two fixed messages below.
func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
	if len(strings.TrimSpace(raw)) == 0 {
		return nil, errors.New("codex channel: empty oauth key")
	}

	parsed := new(CodexOAuthKey)
	decodeErr := common.Unmarshal([]byte(raw), parsed)
	if decodeErr != nil {
		return nil, errors.New("codex channel: invalid oauth key json")
	}
	return parsed, nil
}
// RefreshCodexChannelCredential refreshes the OAuth credential stored on a
// Codex channel and writes the updated credential JSON back to the channel's
// "key" column.
//
// The channel is loaded by channelID and must be of type
// constant.ChannelTypeCodex; its key must decode to a CodexOAuthKey holding a
// non-blank refresh token. The token exchange is delegated to
// RefreshCodexOAuthToken under a 10 second timeout derived from ctx.
//
// On success it returns the updated credential and the loaded channel, whose
// Key field is set to the JSON that was just persisted. When opts.ResetCaches
// is true the channel cache and proxy client cache are reset afterwards.
// On any failure both returned pointers are nil and nothing is persisted.
func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
	ch, err := model.GetChannelById(channelID, true)
	if err != nil {
		return nil, nil, err
	}
	if ch == nil {
		return nil, nil, fmt.Errorf("channel not found")
	}
	if ch.Type != constant.ChannelTypeCodex {
		return nil, nil, fmt.Errorf("channel type is not Codex")
	}

	oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
	if err != nil {
		return nil, nil, err
	}
	if strings.TrimSpace(oauthKey.RefreshToken) == "" {
		return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
	}

	// Bound the token exchange so a stalled request cannot hang the caller.
	refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
	if err != nil {
		return nil, nil, err
	}

	oauthKey.AccessToken = res.AccessToken
	// An OAuth server may answer a refresh without rotating the refresh token.
	// Keep the stored one in that case; blanking it would make every later
	// refresh of this channel fail.
	if strings.TrimSpace(res.RefreshToken) != "" {
		oauthKey.RefreshToken = res.RefreshToken
	}
	oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
	oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)

	// Back-fill descriptive fields only when the stored credential lacks them,
	// so values already present on the channel are never overwritten.
	if strings.TrimSpace(oauthKey.Type) == "" {
		oauthKey.Type = "codex"
	}
	if strings.TrimSpace(oauthKey.AccountID) == "" {
		if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
			oauthKey.AccountID = accountID
		}
	}
	if strings.TrimSpace(oauthKey.Email) == "" {
		if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
			oauthKey.Email = email
		}
	}

	encoded, err := common.Marshal(oauthKey)
	if err != nil {
		return nil, nil, err
	}
	newKey := string(encoded)

	if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", newKey).Error; err != nil {
		return nil, nil, err
	}
	// Keep the returned channel in step with the row just written; otherwise
	// callers that read or re-save ch.Key would see the superseded credential.
	ch.Key = newKey

	if opts.ResetCaches {
		model.InitChannelCache()
		ResetProxyClientCache()
	}
	return oauthKey, ch, nil
}