| package auth |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "math" |
| "net/http" |
| "sort" |
| "strconv" |
| "sync" |
| "time" |
|
|
| cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" |
| ) |
|
|
| |
// RoundRobinSelector hands out credentials in rotation. It keeps one
// rotation counter per "provider:model" key. The zero value is ready to use;
// the counter map is created on first Pick.
type RoundRobinSelector struct {
	// mu guards cursors.
	mu sync.Mutex
	// cursors holds the next rotation index for each "provider:model" key.
	cursors map[string]int
}
|
|
| |
| |
| |
// FillFirstSelector always chooses the first available credential, so one
// credential keeps being used until it becomes blocked. It carries no state.
type FillFirstSelector struct{}
|
|
// blockReason classifies why a credential cannot currently serve a request.
type blockReason int

const (
	// blockReasonNone means the credential is not blocked.
	blockReasonNone blockReason = iota
	// blockReasonCooldown means the credential is waiting out an exceeded quota.
	blockReasonCooldown
	// blockReasonDisabled means the credential (or its model state) is disabled.
	blockReasonDisabled
	// blockReasonOther covers every remaining cause, such as a nil credential
	// or a temporary unavailability that is not quota related.
	blockReasonOther
)
|
|
// modelCooldownError reports that every credential for a model is cooling
// down. Its Error text is a JSON document, and StatusCode/Headers supply the
// HTTP status and headers that go with it.
type modelCooldownError struct {
	model    string        // requested model name; may be empty
	resetIn  time.Duration // remaining cooldown before a retry can succeed
	provider string        // provider name; may be empty
}

// newModelCooldownError builds a modelCooldownError for the given model and
// provider. A negative resetIn is clamped to zero.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	if resetIn < 0 {
		resetIn = 0
	}
	return &modelCooldownError{
		model:    model,
		provider: provider,
		resetIn:  resetIn,
	}
}

// retryAfterSeconds returns resetIn rounded up to whole seconds, never
// negative. Error and Headers both use it so that the JSON body and the
// Retry-After header always agree.
func (e *modelCooldownError) retryAfterSeconds() int {
	seconds := int(math.Ceil(e.resetIn.Seconds()))
	if seconds < 0 {
		return 0
	}
	return seconds
}

// Error renders the cooldown as a JSON object of the form
// {"error": {"code": "model_cooldown", "message": ..., "model": ...,
// "reset_time": ..., "reset_seconds": ...}}, adding "provider" when one is set.
func (e *modelCooldownError) Error() string {
	modelName := e.model
	if modelName == "" {
		modelName = "requested model"
	}
	message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
	if e.provider != "" {
		message = fmt.Sprintf("%s via provider %s", message, e.provider)
	}

	resetSeconds := e.retryAfterSeconds()
	// Build the human-readable duration from the same rounded-up second
	// count. Rounding to the nearest second here would let reset_time
	// under-report the wait (1.4s -> "1s") while reset_seconds and
	// Retry-After say 2, and would show a negative duration for a negative
	// resetIn.
	resetTime := time.Duration(resetSeconds) * time.Second

	errorBody := map[string]any{
		"code":          "model_cooldown",
		"message":       message,
		"model":         e.model,
		"reset_time":    resetTime.String(),
		"reset_seconds": resetSeconds,
	}
	if e.provider != "" {
		errorBody["provider"] = e.provider
	}
	data, err := json.Marshal(map[string]any{"error": errorBody})
	if err != nil {
		// Marshalling plain strings and ints is not expected to fail; this
		// is a last-resort body so Error never returns an empty string.
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, message)
	}
	return string(data)
}

// StatusCode returns the HTTP status for a cooldown: 429 Too Many Requests.
func (e *modelCooldownError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Headers returns the response headers for a cooldown: a JSON Content-Type
// and a Retry-After expressed in whole seconds, rounded up.
func (e *modelCooldownError) Headers() http.Header {
	headers := make(http.Header)
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After", strconv.Itoa(e.retryAfterSeconds()))
	return headers
}
|
|
| func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) { |
| available = make([]*Auth, 0, len(auths)) |
| for i := 0; i < len(auths); i++ { |
| candidate := auths[i] |
| blocked, reason, next := isAuthBlockedForModel(candidate, model, now) |
| if !blocked { |
| available = append(available, candidate) |
| continue |
| } |
| if reason == blockReasonCooldown { |
| cooldownCount++ |
| if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) { |
| earliest = next |
| } |
| } |
| } |
| if len(available) > 1 { |
| sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) |
| } |
| return available, cooldownCount, earliest |
| } |
|
|
| func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) { |
| if len(auths) == 0 { |
| return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"} |
| } |
|
|
| available, cooldownCount, earliest := collectAvailable(auths, model, now) |
| if len(available) == 0 { |
| if cooldownCount == len(auths) && !earliest.IsZero() { |
| resetIn := earliest.Sub(now) |
| if resetIn < 0 { |
| resetIn = 0 |
| } |
| return nil, newModelCooldownError(model, provider, resetIn) |
| } |
| return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} |
| } |
|
|
| return available, nil |
| } |
|
|
| |
| func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { |
| _ = ctx |
| _ = opts |
| now := time.Now() |
| available, err := getAvailableAuths(auths, provider, model, now) |
| if err != nil { |
| return nil, err |
| } |
| key := provider + ":" + model |
| s.mu.Lock() |
| if s.cursors == nil { |
| s.cursors = make(map[string]int) |
| } |
| index := s.cursors[key] |
|
|
| if index >= 2_147_483_640 { |
| index = 0 |
| } |
|
|
| s.cursors[key] = index + 1 |
| s.mu.Unlock() |
| |
| return available[index%len(available)], nil |
| } |
|
|
| |
| func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { |
| _ = ctx |
| _ = opts |
| now := time.Now() |
| available, err := getAvailableAuths(auths, provider, model, now) |
| if err != nil { |
| return nil, err |
| } |
| return available[0], nil |
| } |
|
|
| func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) { |
| if auth == nil { |
| return true, blockReasonOther, time.Time{} |
| } |
| if auth.Disabled || auth.Status == StatusDisabled { |
| return true, blockReasonDisabled, time.Time{} |
| } |
| if model != "" { |
| if len(auth.ModelStates) > 0 { |
| if state, ok := auth.ModelStates[model]; ok && state != nil { |
| if state.Status == StatusDisabled { |
| return true, blockReasonDisabled, time.Time{} |
| } |
| if state.Unavailable { |
| if state.NextRetryAfter.IsZero() { |
| return false, blockReasonNone, time.Time{} |
| } |
| if state.NextRetryAfter.After(now) { |
| next := state.NextRetryAfter |
| if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) { |
| next = state.Quota.NextRecoverAt |
| } |
| if next.Before(now) { |
| next = now |
| } |
| if state.Quota.Exceeded { |
| return true, blockReasonCooldown, next |
| } |
| return true, blockReasonOther, next |
| } |
| } |
| return false, blockReasonNone, time.Time{} |
| } |
| } |
| return false, blockReasonNone, time.Time{} |
| } |
| if auth.Unavailable && auth.NextRetryAfter.After(now) { |
| next := auth.NextRetryAfter |
| if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) { |
| next = auth.Quota.NextRecoverAt |
| } |
| if next.Before(now) { |
| next = now |
| } |
| if auth.Quota.Exceeded { |
| return true, blockReasonCooldown, next |
| } |
| return true, blockReasonOther, next |
| } |
| return false, blockReasonNone, time.Time{} |
| } |
|
|