newtest / service /codex_credential_refresh_task.go
xwwww's picture
Upload 888 files
305487b verified
package service
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/bytedance/gopkg/util/gopool"
)
// Tunables for the Codex credential auto-refresh task.
const (
	// codexCredentialRefreshTickInterval is the delay between consecutive
	// refresh passes.
	codexCredentialRefreshTickInterval = 10 * time.Minute

	// codexCredentialRefreshThreshold is the look-ahead window: a credential
	// whose expiry is no further away than this is refreshed.
	codexCredentialRefreshThreshold = 24 * time.Hour

	// codexCredentialRefreshBatchSize is the number of channel rows fetched
	// per database query while scanning.
	codexCredentialRefreshBatchSize = 200

	// codexCredentialRefreshTimeout bounds a single channel's refresh call.
	codexCredentialRefreshTimeout = 15 * time.Second
)
// codexCredentialRefreshOnce ensures the background refresh loop is set up at
// most once per process, no matter how many times the starter is called.
var codexCredentialRefreshOnce sync.Once

// codexCredentialRefreshRunning is true while a refresh pass is in progress; a
// pass that finds it already set returns instead of running concurrently.
var codexCredentialRefreshRunning atomic.Bool
// StartCodexCredentialAutoRefreshTask launches the background Codex credential
// auto-refresh loop. Only the first call has any effect; on a non-master node
// that first call does nothing.
//
// The loop runs one refresh pass immediately and then one pass per
// codexCredentialRefreshTickInterval. It has no stop mechanism.
func StartCodexCredentialAutoRefreshTask() {
	codexCredentialRefreshOnce.Do(func() {
		if !common.IsMasterNode {
			return
		}
		gopool.Go(func() {
			logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))
			ticker := time.NewTicker(codexCredentialRefreshTickInterval)
			defer ticker.Stop()
			runCodexCredentialAutoRefreshSafely()
			for range ticker.C {
				runCodexCredentialAutoRefreshSafely()
			}
		})
	})
}

// runCodexCredentialAutoRefreshSafely runs a single refresh pass and recovers
// from any panic it raises. Without this, a panic would unwind the
// goroutine that owns the ticker loop and permanently stop auto-refresh.
func runCodexCredentialAutoRefreshSafely() {
	defer func() {
		if r := recover(); r != nil {
			logger.LogError(context.Background(), fmt.Sprintf("codex credential auto-refresh: pass panicked: %v", r))
		}
	}()
	runCodexCredentialAutoRefreshOnce()
}
func runCodexCredentialAutoRefreshOnce() {
if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
return
}
defer codexCredentialRefreshRunning.Store(false)
ctx := context.Background()
now := time.Now()
var refreshed int
var scanned int
offset := 0
for {
var channels []*model.Channel
err := model.DB.
Select("id", "name", "key", "status", "channel_info").
Where("type = ? AND status = 1", constant.ChannelTypeCodex).
Order("id asc").
Limit(codexCredentialRefreshBatchSize).
Offset(offset).
Find(&channels).Error
if err != nil {
logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
return
}
if len(channels) == 0 {
break
}
offset += codexCredentialRefreshBatchSize
for _, ch := range channels {
if ch == nil {
continue
}
scanned++
if ch.ChannelInfo.IsMultiKey {
continue
}
rawKey := strings.TrimSpace(ch.Key)
if rawKey == "" {
continue
}
oauthKey, err := parseCodexOAuthKey(rawKey)
if err != nil {
continue
}
refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
if refreshToken == "" {
continue
}
expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
continue
}
refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
cancel()
if err != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
continue
}
refreshed++
logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
}
}
if refreshed > 0 {
func() {
defer func() {
if r := recover(); r != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
}
}()
model.InitChannelCache()
}()
ResetProxyClientCache()
}
if common.DebugEnabled {
logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
}
}