| package auth |
|
|
| import ( |
| "context" |
| "crypto/sha256" |
| "encoding/hex" |
| "errors" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| "ds2api/internal/account" |
| "ds2api/internal/config" |
| ) |
|
|
// ctxKey is the private key type used for values this package stores in a
// context.Context, so its keys cannot collide with plain string keys.
type ctxKey string

// authCtxKey is the context key under which WithAuth stores the *RequestAuth.
const authCtxKey = ctxKey("auth_context")
|
|
// Sentinel errors returned while resolving request credentials.
var (
	// ErrUnauthorized is returned when the request carries no caller token.
	ErrUnauthorized = errors.New("unauthorized: missing auth token")

	// ErrNoAccount is returned when no managed account could be acquired and
	// no more specific login error was recorded.
	ErrNoAccount = errors.New("no accounts configured or all accounts are busy")
)
|
|
| type RequestAuth struct { |
| UseConfigToken bool |
| DeepSeekToken string |
| CallerID string |
| AccountID string |
| Account config.Account |
| TriedAccounts map[string]bool |
| resolver *Resolver |
| } |
|
|
| type LoginFunc func(ctx context.Context, acc config.Account) (string, error) |
|
|
| type Resolver struct { |
| Store *config.Store |
| Pool *account.Pool |
| Login LoginFunc |
|
|
| mu sync.Mutex |
| tokenRefreshedAt map[string]time.Time |
| } |
|
|
| func NewResolver(store *config.Store, pool *account.Pool, login LoginFunc) *Resolver { |
| return &Resolver{ |
| Store: store, |
| Pool: pool, |
| Login: login, |
| tokenRefreshedAt: map[string]time.Time{}, |
| } |
| } |
|
|
| func (r *Resolver) Determine(req *http.Request) (*RequestAuth, error) { |
| callerKey := extractCallerToken(req) |
| if callerKey == "" { |
| return nil, ErrUnauthorized |
| } |
| callerID := callerTokenID(callerKey) |
| ctx := req.Context() |
| if !r.Store.HasAPIKey(callerKey) { |
| return &RequestAuth{ |
| UseConfigToken: false, |
| DeepSeekToken: callerKey, |
| CallerID: callerID, |
| resolver: r, |
| TriedAccounts: map[string]bool{}, |
| }, nil |
| } |
| target := strings.TrimSpace(req.Header.Get("X-Ds2-Target-Account")) |
| a, err := r.acquireManagedRequestAuth(ctx, callerID, target) |
| if err != nil { |
| return nil, err |
| } |
| return a, nil |
| } |
|
|
| func (r *Resolver) acquireManagedRequestAuth(ctx context.Context, callerID, target string) (*RequestAuth, error) { |
| tried := map[string]bool{} |
| var lastEnsureErr error |
| for { |
| if target == "" && len(tried) >= len(r.Store.Accounts()) { |
| if lastEnsureErr != nil { |
| return nil, lastEnsureErr |
| } |
| return nil, ErrNoAccount |
| } |
| acc, ok := r.Pool.AcquireWait(ctx, target, tried) |
| if !ok { |
| if lastEnsureErr != nil { |
| return nil, lastEnsureErr |
| } |
| return nil, ErrNoAccount |
| } |
|
|
| a := &RequestAuth{ |
| UseConfigToken: true, |
| CallerID: callerID, |
| AccountID: acc.Identifier(), |
| Account: acc, |
| TriedAccounts: tried, |
| resolver: r, |
| } |
|
|
| if err := r.ensureManagedToken(ctx, a); err != nil { |
| lastEnsureErr = err |
| tried[a.AccountID] = true |
| r.Pool.Release(a.AccountID) |
| if target != "" { |
| return nil, err |
| } |
| continue |
| } |
| return a, nil |
| } |
| } |
|
|
| |
| |
| func (r *Resolver) DetermineCaller(req *http.Request) (*RequestAuth, error) { |
| callerKey := extractCallerToken(req) |
| if callerKey == "" { |
| return nil, ErrUnauthorized |
| } |
| callerID := callerTokenID(callerKey) |
| a := &RequestAuth{ |
| UseConfigToken: false, |
| CallerID: callerID, |
| resolver: r, |
| TriedAccounts: map[string]bool{}, |
| } |
| if r == nil || r.Store == nil || !r.Store.HasAPIKey(callerKey) { |
| a.DeepSeekToken = callerKey |
| } |
| return a, nil |
| } |
|
|
| func WithAuth(ctx context.Context, a *RequestAuth) context.Context { |
| return context.WithValue(ctx, authCtxKey, a) |
| } |
|
|
| func FromContext(ctx context.Context) (*RequestAuth, bool) { |
| v := ctx.Value(authCtxKey) |
| a, ok := v.(*RequestAuth) |
| return a, ok |
| } |
|
|
| func (r *Resolver) loginAndPersist(ctx context.Context, a *RequestAuth) error { |
| token, err := r.Login(ctx, a.Account) |
| if err != nil { |
| return err |
| } |
| a.Account.Token = token |
| a.DeepSeekToken = token |
| r.markTokenRefreshedNow(a.AccountID) |
| return r.Store.UpdateAccountToken(a.AccountID, token) |
| } |
|
|
| func (r *Resolver) RefreshToken(ctx context.Context, a *RequestAuth) bool { |
| if !a.UseConfigToken || a.AccountID == "" { |
| return false |
| } |
| _ = r.Store.UpdateAccountToken(a.AccountID, "") |
| a.Account.Token = "" |
| if err := r.loginAndPersist(ctx, a); err != nil { |
| config.Logger.Error("[refresh_token] failed", "account", a.AccountID, "error", err) |
| return false |
| } |
| return true |
| } |
|
|
| func (r *Resolver) MarkTokenInvalid(a *RequestAuth) { |
| if !a.UseConfigToken || a.AccountID == "" { |
| return |
| } |
| a.Account.Token = "" |
| a.DeepSeekToken = "" |
| r.clearTokenRefreshMark(a.AccountID) |
| _ = r.Store.UpdateAccountToken(a.AccountID, "") |
| } |
|
|
| func (r *Resolver) SwitchAccount(ctx context.Context, a *RequestAuth) bool { |
| if !a.UseConfigToken { |
| return false |
| } |
| if a.TriedAccounts == nil { |
| a.TriedAccounts = map[string]bool{} |
| } |
| if a.AccountID != "" { |
| a.TriedAccounts[a.AccountID] = true |
| r.Pool.Release(a.AccountID) |
| } |
| for { |
| acc, ok := r.Pool.Acquire("", a.TriedAccounts) |
| if !ok { |
| return false |
| } |
| a.Account = acc |
| a.AccountID = acc.Identifier() |
| if err := r.ensureManagedToken(ctx, a); err != nil { |
| a.TriedAccounts[a.AccountID] = true |
| r.Pool.Release(a.AccountID) |
| continue |
| } |
| return true |
| } |
| } |
|
|
| func (r *Resolver) Release(a *RequestAuth) { |
| if a == nil || !a.UseConfigToken || a.AccountID == "" { |
| return |
| } |
| r.Pool.Release(a.AccountID) |
| } |
|
|
// extractCallerToken returns the caller's token from req, or "" when none is
// present. Sources are checked in this order, and the first non-empty value
// (after trimming surrounding whitespace) wins:
//
//  1. the Authorization header with a case-insensitive "Bearer " scheme
//  2. the x-api-key header
//  3. the x-goog-api-key header
//  4. the "key" query parameter
//  5. the "api_key" query parameter
//
// A nil request, or a request without a URL, is handled without panicking.
func extractCallerToken(req *http.Request) string {
	if req == nil {
		return ""
	}

	const bearerPrefix = "bearer "
	authHeader := strings.TrimSpace(req.Header.Get("Authorization"))
	// Compare only the scheme prefix instead of lowercasing the whole header.
	if len(authHeader) >= len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		if token := strings.TrimSpace(authHeader[len(bearerPrefix):]); token != "" {
			return token
		}
	}

	for _, name := range [...]string{"x-api-key", "x-goog-api-key"} {
		if key := strings.TrimSpace(req.Header.Get(name)); key != "" {
			return key
		}
	}

	if req.URL == nil {
		return ""
	}
	// Query() re-parses the raw query on every call, so parse it once.
	query := req.URL.Query()
	if key := strings.TrimSpace(query.Get("key")); key != "" {
		return key
	}
	return strings.TrimSpace(query.Get("api_key"))
}
|
|
// callerTokenID derives a short identifier from a caller token: the prefix
// "caller:" followed by the hex encoding of the first 8 bytes of the SHA-256
// digest of the whitespace-trimmed token. A blank token yields "".
func callerTokenID(token string) string {
	trimmed := strings.TrimSpace(token)
	if len(trimmed) == 0 {
		return ""
	}

	const digestPrefixLen = 8
	digest := sha256.Sum256([]byte(trimmed))
	shortHex := hex.EncodeToString(digest[:digestPrefixLen])
	return "caller:" + shortHex
}
|
|
| func (r *Resolver) ensureManagedToken(ctx context.Context, a *RequestAuth) error { |
| if strings.TrimSpace(a.Account.Token) == "" { |
| return r.loginAndPersist(ctx, a) |
| } |
| if r.shouldForceRefresh(a.AccountID) { |
| if err := r.loginAndPersist(ctx, a); err != nil { |
| return err |
| } |
| return nil |
| } |
| a.DeepSeekToken = a.Account.Token |
| return nil |
| } |
|
|
| func (r *Resolver) shouldForceRefresh(accountID string) bool { |
| if r == nil || r.Store == nil { |
| return false |
| } |
| if strings.TrimSpace(accountID) == "" { |
| return false |
| } |
| intervalHours := r.Store.RuntimeTokenRefreshIntervalHours() |
| if intervalHours <= 0 { |
| return false |
| } |
| now := time.Now() |
| r.mu.Lock() |
| defer r.mu.Unlock() |
| last, ok := r.tokenRefreshedAt[accountID] |
| if !ok || last.IsZero() { |
| r.tokenRefreshedAt[accountID] = now |
| return false |
| } |
| return now.Sub(last) >= time.Duration(intervalHours)*time.Hour |
| } |
|
|
| func (r *Resolver) markTokenRefreshedNow(accountID string) { |
| if strings.TrimSpace(accountID) == "" { |
| return |
| } |
| r.mu.Lock() |
| defer r.mu.Unlock() |
| r.tokenRefreshedAt[accountID] = time.Now() |
| } |
|
|
| func (r *Resolver) clearTokenRefreshMark(accountID string) { |
| if strings.TrimSpace(accountID) == "" { |
| return |
| } |
| r.mu.Lock() |
| defer r.mu.Unlock() |
| delete(r.tokenRefreshedAt, accountID) |
| } |
|
|