File size: 3,657 Bytes
305487b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | package service
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/bytedance/gopkg/util/gopool"
)
// Tuning knobs for the Codex credential auto-refresh background task.
const (
	// codexCredentialRefreshTickInterval is the period of the ticker that
	// drives successive refresh passes.
	codexCredentialRefreshTickInterval = time.Minute * 10

	// codexCredentialRefreshThreshold: a credential whose parsed expiry lies
	// further in the future than this is skipped by a refresh pass.
	codexCredentialRefreshThreshold = time.Hour * 24

	// codexCredentialRefreshBatchSize is the page size used when reading
	// Codex channels from the database.
	codexCredentialRefreshBatchSize = 200

	// codexCredentialRefreshTimeout bounds a single channel's refresh call.
	codexCredentialRefreshTimeout = time.Second * 15
)
// codexCredentialRefreshOnce makes StartCodexCredentialAutoRefreshTask launch
// its background loop at most once per process.
var codexCredentialRefreshOnce sync.Once

// codexCredentialRefreshRunning is flipped to true (via CompareAndSwap) while a
// refresh pass is in progress, so an overlapping pass returns immediately
// instead of running concurrently.
var codexCredentialRefreshRunning atomic.Bool
// StartCodexCredentialAutoRefreshTask launches the background loop that
// periodically refreshes Codex OAuth channel credentials.
//
// It may be called any number of times: the loop is started at most once per
// process (codexCredentialRefreshOnce), and only when common.IsMasterNode is
// true. The first pass runs immediately; later passes run on every tick of
// codexCredentialRefreshTickInterval. The loop has no stop mechanism and lives
// for the lifetime of the process.
//
// Each pass runs behind a panic guard. Without it, a panic inside a pass would
// unwind the pool goroutine and end the ticker loop, silently disabling all
// future refreshes until restart. With it, the panic is logged and the next
// tick runs normally. (The pass resets its own "running" flag via defer, so
// the guard does not leave the task wedged.)
func StartCodexCredentialAutoRefreshTask() {
	codexCredentialRefreshOnce.Do(func() {
		if !common.IsMasterNode {
			return
		}
		gopool.Go(func() {
			ctx := context.Background()
			logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))

			// runPass isolates a single refresh pass so a panic in it cannot
			// terminate the surrounding ticker loop.
			runPass := func() {
				defer func() {
					if r := recover(); r != nil {
						logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: pass panicked: %v", r))
					}
				}()
				runCodexCredentialAutoRefreshOnce()
			}

			ticker := time.NewTicker(codexCredentialRefreshTickInterval)
			defer ticker.Stop()
			runPass()
			for range ticker.C {
				runPass()
			}
		})
	})
}
func runCodexCredentialAutoRefreshOnce() {
if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
return
}
defer codexCredentialRefreshRunning.Store(false)
ctx := context.Background()
now := time.Now()
var refreshed int
var scanned int
offset := 0
for {
var channels []*model.Channel
err := model.DB.
Select("id", "name", "key", "status", "channel_info").
Where("type = ? AND status = 1", constant.ChannelTypeCodex).
Order("id asc").
Limit(codexCredentialRefreshBatchSize).
Offset(offset).
Find(&channels).Error
if err != nil {
logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
return
}
if len(channels) == 0 {
break
}
offset += codexCredentialRefreshBatchSize
for _, ch := range channels {
if ch == nil {
continue
}
scanned++
if ch.ChannelInfo.IsMultiKey {
continue
}
rawKey := strings.TrimSpace(ch.Key)
if rawKey == "" {
continue
}
oauthKey, err := parseCodexOAuthKey(rawKey)
if err != nil {
continue
}
refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
if refreshToken == "" {
continue
}
expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
continue
}
refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
cancel()
if err != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
continue
}
refreshed++
logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
}
}
if refreshed > 0 {
func() {
defer func() {
if r := recover(); r != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
}
}()
model.InitChannelCache()
}()
ResetProxyClientCache()
}
if common.DebugEnabled {
logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
}
}
|