| package service | |
| import ( | |
| "context" | |
| "fmt" | |
| "strings" | |
| "sync" | |
| "sync/atomic" | |
| "time" | |
| "github.com/QuantumNous/new-api/common" | |
| "github.com/QuantumNous/new-api/constant" | |
| "github.com/QuantumNous/new-api/logger" | |
| "github.com/QuantumNous/new-api/model" | |
| "github.com/bytedance/gopkg/util/gopool" | |
| ) | |
// Tuning knobs for the Codex credential auto-refresh task.
const (
	// codexCredentialRefreshBatchSize is the number of channel rows loaded per
	// database query while scanning.
	codexCredentialRefreshBatchSize = 200

	// codexCredentialRefreshTickInterval is the delay between two scheduled
	// refresh passes.
	codexCredentialRefreshTickInterval time.Duration = 10 * time.Minute

	// codexCredentialRefreshThreshold: a credential whose recorded expiry lies
	// no more than this far in the future is refreshed.
	codexCredentialRefreshThreshold time.Duration = 24 * time.Hour

	// codexCredentialRefreshTimeout bounds a single per-channel refresh call.
	codexCredentialRefreshTimeout time.Duration = 15 * time.Second
)
// Package-level state for the Codex credential auto-refresh task.
var (
	// codexCredentialRefreshRunning is set while a refresh pass is in
	// progress; a pass that finds it already set returns immediately.
	codexCredentialRefreshRunning atomic.Bool

	// codexCredentialRefreshOnce makes StartCodexCredentialAutoRefreshTask
	// effective only on its first call.
	codexCredentialRefreshOnce sync.Once
)
| func StartCodexCredentialAutoRefreshTask() { | |
| codexCredentialRefreshOnce.Do(func() { | |
| if !common.IsMasterNode { | |
| return | |
| } | |
| gopool.Go(func() { | |
| logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold)) | |
| ticker := time.NewTicker(codexCredentialRefreshTickInterval) | |
| defer ticker.Stop() | |
| runCodexCredentialAutoRefreshOnce() | |
| for range ticker.C { | |
| runCodexCredentialAutoRefreshOnce() | |
| } | |
| }) | |
| }) | |
| } | |
| func runCodexCredentialAutoRefreshOnce() { | |
| if !codexCredentialRefreshRunning.CompareAndSwap(false, true) { | |
| return | |
| } | |
| defer codexCredentialRefreshRunning.Store(false) | |
| ctx := context.Background() | |
| now := time.Now() | |
| var refreshed int | |
| var scanned int | |
| offset := 0 | |
| for { | |
| var channels []*model.Channel | |
| err := model.DB. | |
| Select("id", "name", "key", "status", "channel_info"). | |
| Where("type = ? AND status = 1", constant.ChannelTypeCodex). | |
| Order("id asc"). | |
| Limit(codexCredentialRefreshBatchSize). | |
| Offset(offset). | |
| Find(&channels).Error | |
| if err != nil { | |
| logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err)) | |
| return | |
| } | |
| if len(channels) == 0 { | |
| break | |
| } | |
| offset += codexCredentialRefreshBatchSize | |
| for _, ch := range channels { | |
| if ch == nil { | |
| continue | |
| } | |
| scanned++ | |
| if ch.ChannelInfo.IsMultiKey { | |
| continue | |
| } | |
| rawKey := strings.TrimSpace(ch.Key) | |
| if rawKey == "" { | |
| continue | |
| } | |
| oauthKey, err := parseCodexOAuthKey(rawKey) | |
| if err != nil { | |
| continue | |
| } | |
| refreshToken := strings.TrimSpace(oauthKey.RefreshToken) | |
| if refreshToken == "" { | |
| continue | |
| } | |
| expiredAtRaw := strings.TrimSpace(oauthKey.Expired) | |
| expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw) | |
| if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold { | |
| continue | |
| } | |
| refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout) | |
| newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false}) | |
| cancel() | |
| if err != nil { | |
| logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err)) | |
| continue | |
| } | |
| refreshed++ | |
| logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired)) | |
| } | |
| } | |
| if refreshed > 0 { | |
| func() { | |
| defer func() { | |
| if r := recover(); r != nil { | |
| logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r)) | |
| } | |
| }() | |
| model.InitChannelCache() | |
| }() | |
| ResetProxyClientCache() | |
| } | |
| if common.DebugEnabled { | |
| logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed) | |
| } | |
| } | |