| package service |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
| "time" |
|
|
| dbent "github.com/Wei-Shaw/sub2api/ent" |
| infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/logger" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" |
| "github.com/Wei-Shaw/sub2api/internal/util/soraerror" |
| ) |
|
|
| |
| type AdminService interface { |
| |
| ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) |
| GetUser(ctx context.Context, id int64) (*User, error) |
| CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) |
| UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) |
| DeleteUser(ctx context.Context, id int64) error |
| UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) |
| GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) |
| GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) |
| |
| |
| |
| GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) |
|
|
| |
| ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) |
| GetAllGroups(ctx context.Context) ([]Group, error) |
| GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) |
| GetGroup(ctx context.Context, id int64) (*Group, error) |
| CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) |
| UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) |
| DeleteGroup(ctx context.Context, id int64) error |
| GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) |
| GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) |
| ClearGroupRateMultipliers(ctx context.Context, groupID int64) error |
| BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error |
| UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error |
|
|
| |
| AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) |
|
|
| |
| ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) |
|
|
| |
| ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) |
| GetAccount(ctx context.Context, id int64) (*Account, error) |
| GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) |
| CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) |
| UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) |
| DeleteAccount(ctx context.Context, id int64) error |
| RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) |
| ClearAccountError(ctx context.Context, id int64) (*Account, error) |
| SetAccountError(ctx context.Context, id int64, errorMsg string) error |
| |
| EnsureOpenAIPrivacy(ctx context.Context, account *Account) string |
| SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) |
| BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) |
| CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error |
|
|
| |
| ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) |
| ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) |
| GetAllProxies(ctx context.Context) ([]Proxy, error) |
| GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) |
| GetProxy(ctx context.Context, id int64) (*Proxy, error) |
| GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) |
| CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) |
| UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) |
| DeleteProxy(ctx context.Context, id int64) error |
| BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) |
| GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) |
| CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) |
| TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) |
| CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) |
|
|
| |
| ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) |
| GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) |
| GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) |
| DeleteRedeemCode(ctx context.Context, id int64) error |
| BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) |
| ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) |
| ResetAccountQuota(ctx context.Context, id int64) error |
| } |
|
|
| |
// CreateUserInput carries the values CreateUser copies onto a new User.
// Password is handed to User.SetPassword instead of being assigned directly;
// every other field is copied as-is.
type CreateUserInput struct {
	Email, Password, Username, Notes string
	Balance                          float64
	Concurrency                      int
	AllowedGroups                    []int64
	SoraStorageQuotaBytes            int64
}
|
|
// UpdateUserInput describes a partial update applied by UpdateUser.
//
// Email, Password and Status are skipped when empty; Username, Notes,
// Concurrency, AllowedGroups and SoraStorageQuotaBytes are skipped when nil.
// NOTE(review): UpdateUser as written in this file never reads Balance;
// balance changes appear to go through UpdateUserBalance — confirm.
type UpdateUserInput struct {
	Email, Password string
	Username, Notes *string
	Balance         *float64
	Concurrency     *int
	Status          string
	AllowedGroups   *[]int64

	// GroupRates is forwarded to UserGroupRateRepository.SyncUserGroupRates
	// when the map itself is non-nil; a nil map skips the sync entirely.
	GroupRates            map[int64]*float64
	SoraStorageQuotaBytes *int64
}
|
|
// CreateGroupInput carries the parameters consumed by CreateGroup.
//
// As implemented in CreateGroup: an empty Platform falls back to
// PlatformAnthropic, an empty SubscriptionType falls back to
// SubscriptionTypeStandard, and every limit/price pointer is passed through
// normalizeLimit / normalizePrice (nil or negative becomes nil).
type CreateGroupInput struct {
	Name, Description, Platform string
	RateMultiplier              float64
	IsExclusive                 bool
	SubscriptionType            string

	// Spending limits in USD.
	DailyLimitUSD, WeeklyLimitUSD, MonthlyLimitUSD *float64

	// Image pricing.
	ImagePrice1K, ImagePrice2K, ImagePrice4K *float64

	// Sora pricing.
	SoraImagePrice360, SoraImagePrice540                 *float64
	SoraVideoPricePerRequest, SoraVideoPricePerRequestHD *float64

	ClaudeCodeOnly bool
	// FallbackGroupID, when set, is checked by validateFallbackGroup.
	FallbackGroupID *int64
	// FallbackGroupIDOnInvalidRequest is treated as unset when it points at
	// a value <= 0; otherwise it is checked by
	// validateFallbackGroupOnInvalidRequest.
	FallbackGroupIDOnInvalidRequest *int64

	ModelRouting        map[string][]int64
	ModelRoutingEnabled bool
	// MCPXMLInject defaults to true in CreateGroup when nil.
	MCPXMLInject *bool

	SupportedModelScopes []string

	SoraStorageQuotaBytes int64

	AllowMessagesDispatch bool
	DefaultMappedModel    string

	// CopyAccountsFromGroupIDs lists source groups whose accounts are bound
	// to the new group. CreateGroup de-duplicates the IDs and requires each
	// source group to be on the same platform as the new group.
	CopyAccountsFromGroupIDs []int64
}
|
|
// UpdateGroupInput describes the changes UpdateGroup applies to a group.
//
// From the part of UpdateGroup visible in this file: Name, Description,
// Platform, Status and SubscriptionType are skipped when empty;
// RateMultiplier and IsExclusive are skipped when nil; the three USD limits
// are always overwritten with normalizeLimit(input value), so a nil limit
// clears the stored one. Handling of the remaining fields lies outside this
// excerpt.
type UpdateGroupInput struct {
	Name, Description, Platform string
	RateMultiplier              *float64
	IsExclusive                 *bool
	Status, SubscriptionType    string

	// Spending limits in USD.
	DailyLimitUSD, WeeklyLimitUSD, MonthlyLimitUSD *float64

	// Image pricing.
	ImagePrice1K, ImagePrice2K, ImagePrice4K *float64

	// Sora pricing.
	SoraImagePrice360, SoraImagePrice540                 *float64
	SoraVideoPricePerRequest, SoraVideoPricePerRequestHD *float64

	ClaudeCodeOnly  *bool
	FallbackGroupID *int64

	FallbackGroupIDOnInvalidRequest *int64

	ModelRouting        map[string][]int64
	ModelRoutingEnabled *bool
	MCPXMLInject        *bool

	SupportedModelScopes *[]string

	SoraStorageQuotaBytes *int64

	AllowMessagesDispatch *bool
	DefaultMappedModel    *string

	CopyAccountsFromGroupIDs []int64
}
|
|
// CreateAccountInput carries the parameters for AdminService.CreateAccount.
// The consuming implementation is outside this excerpt, so field semantics
// beyond their names and types are not documented here.
type CreateAccountInput struct {
	Name               string
	Notes              *string
	Platform, Type     string
	Credentials, Extra map[string]any
	ProxyID            *int64
	Concurrency        int
	Priority           int
	RateMultiplier     *float64
	LoadFactor         *int
	GroupIDs           []int64
	// NOTE(review): ExpiresAt is presumably a Unix timestamp — confirm
	// against the CreateAccount implementation.
	ExpiresAt          *int64
	AutoPauseOnExpired *bool

	SkipDefaultGroupBind bool

	SkipMixedChannelCheck bool
}
|
|
// UpdateAccountInput carries the parameters for AdminService.UpdateAccount.
// The consuming implementation is outside this excerpt, so which zero values
// mean "leave unchanged" is not documented here.
type UpdateAccountInput struct {
	Name                  string
	Notes                 *string
	Type                  string
	Credentials, Extra    map[string]any
	ProxyID               *int64
	Concurrency, Priority *int
	RateMultiplier        *float64
	LoadFactor            *int
	Status                string
	GroupIDs              *[]int64
	ExpiresAt             *int64
	AutoPauseOnExpired    *bool
	SkipMixedChannelCheck bool
}
|
|
| |
// BulkUpdateAccountsInput carries the parameters for
// AdminService.BulkUpdateAccounts: the target AccountIDs plus the values to
// apply. The consuming implementation is outside this excerpt.
type BulkUpdateAccountsInput struct {
	AccountIDs            []int64
	Name                  string
	ProxyID               *int64
	Concurrency, Priority *int
	RateMultiplier        *float64
	LoadFactor            *int
	Status                string
	Schedulable           *bool
	GroupIDs              *[]int64
	Credentials, Extra    map[string]any

	SkipMixedChannelCheck bool
}
|
|
| |
// BulkUpdateAccountResult is the per-account entry carried in
// BulkUpdateAccountsResult.Results. Error is omitted from JSON when empty.
type BulkUpdateAccountResult struct {
	AccountID int64  `json:"account_id"`
	Success   bool   `json:"success"`
	Error     string `json:"error,omitempty"`
}
|
|
| |
| type AdminUpdateAPIKeyGroupIDResult struct { |
| APIKey *APIKey |
| AutoGrantedGroupAccess bool |
| GrantedGroupID *int64 |
| GrantedGroupName string |
| } |
|
|
| |
// ReplaceUserGroupResult is the return value of
// AdminService.ReplaceUserGroup.
type ReplaceUserGroupResult struct {
	// MigratedKeys is a count reported by ReplaceUserGroup.
	// NOTE(review): presumably the number of API keys moved from the old
	// group to the new one — confirm against the implementation.
	MigratedKeys int64
}
|
|
| |
| type BulkUpdateAccountsResult struct { |
| Success int `json:"success"` |
| Failed int `json:"failed"` |
| SuccessIDs []int64 `json:"success_ids"` |
| FailedIDs []int64 `json:"failed_ids"` |
| Results []BulkUpdateAccountResult `json:"results"` |
| } |
|
|
// CreateProxyInput carries the parameters for AdminService.CreateProxy.
type CreateProxyInput struct {
	Name, Protocol, Host string
	Port                 int
	Username, Password   string
}
|
|
// UpdateProxyInput carries the parameters for AdminService.UpdateProxy.
// The consuming implementation is outside this excerpt, so how zero values
// are interpreted is not documented here.
type UpdateProxyInput struct {
	Name, Protocol, Host       string
	Port                       int
	Username, Password, Status string
}
|
|
// GenerateRedeemCodesInput carries the parameters for
// AdminService.GenerateRedeemCodes.
type GenerateRedeemCodesInput struct {
	Count        int
	Type         string
	Value        float64
	GroupID      *int64
	ValidityDays int
}
|
|
| type ProxyBatchDeleteResult struct { |
| DeletedIDs []int64 `json:"deleted_ids"` |
| Skipped []ProxyBatchDeleteSkipped `json:"skipped"` |
| } |
|
|
// ProxyBatchDeleteSkipped records one proxy ID that a batch delete did not
// remove, together with a textual reason.
type ProxyBatchDeleteSkipped struct {
	ID     int64  `json:"id"`
	Reason string `json:"reason"`
}
|
|
| |
// ProxyTestResult is the return value of AdminService.TestProxy. Every field
// except Success and Message is omitted from JSON when it holds its zero
// value.
type ProxyTestResult struct {
	Success     bool   `json:"success"`
	Message     string `json:"message"`
	LatencyMs   int64  `json:"latency_ms,omitempty"`
	IPAddress   string `json:"ip_address,omitempty"`
	City        string `json:"city,omitempty"`
	Region      string `json:"region,omitempty"`
	Country     string `json:"country,omitempty"`
	CountryCode string `json:"country_code,omitempty"`
}
|
|
| type ProxyQualityCheckResult struct { |
| ProxyID int64 `json:"proxy_id"` |
| Score int `json:"score"` |
| Grade string `json:"grade"` |
| Summary string `json:"summary"` |
| ExitIP string `json:"exit_ip,omitempty"` |
| Country string `json:"country,omitempty"` |
| CountryCode string `json:"country_code,omitempty"` |
| BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` |
| PassedCount int `json:"passed_count"` |
| WarnCount int `json:"warn_count"` |
| FailedCount int `json:"failed_count"` |
| ChallengeCount int `json:"challenge_count"` |
| CheckedAt int64 `json:"checked_at"` |
| Items []ProxyQualityCheckItem `json:"items"` |
| } |
|
|
// ProxyQualityCheckItem is the outcome of probing a single target during a
// proxy quality check. All fields other than Target and Status are omitted
// from JSON when zero.
type ProxyQualityCheckItem struct {
	Target     string `json:"target"`
	Status     string `json:"status"`
	HTTPStatus int    `json:"http_status,omitempty"`
	LatencyMs  int64  `json:"latency_ms,omitempty"`
	Message    string `json:"message,omitempty"`
	CFRay      string `json:"cf_ray,omitempty"`
}
|
|
| |
// ProxyExitInfo holds the exit IP and location details returned by a
// ProxyExitInfoProber.
type ProxyExitInfo struct {
	IP                                 string
	City, Region, Country, CountryCode string
}
|
|
| |
| type ProxyExitInfoProber interface { |
| ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) |
| } |
|
|
// groupExistenceBatchReader is an optional capability interface: a value
// that can report, for a batch of group IDs, a map from ID to a boolean.
type groupExistenceBatchReader interface {
	ExistsByIDs(ctx context.Context, ids []int64) (existsByID map[int64]bool, err error)
}
|
|
// proxyQualityTarget is one row of the proxyQualityTargets table: a label,
// the URL and HTTP method to request, and a set of HTTP status codes.
// NOTE(review): AllowedStatuses is presumably the set of responses that
// count as a healthy probe — confirm against the checker, which is outside
// this excerpt.
type proxyQualityTarget struct {
	Target, URL, Method string
	AllowedStatuses     map[int]struct{}
}
|
|
| var proxyQualityTargets = []proxyQualityTarget{ |
| { |
| Target: "openai", |
| URL: "https://api.openai.com/v1/models", |
| Method: http.MethodGet, |
| AllowedStatuses: map[int]struct{}{ |
| http.StatusUnauthorized: {}, |
| }, |
| }, |
| { |
| Target: "anthropic", |
| URL: "https://api.anthropic.com/v1/messages", |
| Method: http.MethodGet, |
| AllowedStatuses: map[int]struct{}{ |
| http.StatusUnauthorized: {}, |
| http.StatusMethodNotAllowed: {}, |
| http.StatusNotFound: {}, |
| http.StatusBadRequest: {}, |
| }, |
| }, |
| { |
| Target: "gemini", |
| URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", |
| Method: http.MethodGet, |
| AllowedStatuses: map[int]struct{}{ |
| http.StatusOK: {}, |
| }, |
| }, |
| { |
| Target: "sora", |
| URL: "https://sora.chatgpt.com/backend/me", |
| Method: http.MethodGet, |
| AllowedStatuses: map[int]struct{}{ |
| http.StatusUnauthorized: {}, |
| }, |
| }, |
| } |
|
|
// Tunables for the proxy quality check. Their consumers are outside this
// excerpt; the names indicate an overall request timeout, a response-header
// timeout, a cap on bytes read from a response body, and the User-Agent
// string sent with probes.
const (
	proxyQualityRequestTimeout        = 15 * time.Second
	proxyQualityResponseHeaderTimeout = 10 * time.Second

	// 8 KiB, typed int64.
	proxyQualityMaxBodyBytes int64 = 8 << 10

	proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
|
|
| |
| type adminServiceImpl struct { |
| userRepo UserRepository |
| groupRepo GroupRepository |
| accountRepo AccountRepository |
| soraAccountRepo SoraAccountRepository |
| proxyRepo ProxyRepository |
| apiKeyRepo APIKeyRepository |
| redeemCodeRepo RedeemCodeRepository |
| userGroupRateRepo UserGroupRateRepository |
| billingCacheService *BillingCacheService |
| proxyProber ProxyExitInfoProber |
| proxyLatencyCache ProxyLatencyCache |
| authCacheInvalidator APIKeyAuthCacheInvalidator |
| entClient *dbent.Client |
| settingService *SettingService |
| defaultSubAssigner DefaultSubscriptionAssigner |
| userSubRepo UserSubscriptionRepository |
| privacyClientFactory PrivacyClientFactory |
| } |
|
|
// userGroupRateBatchReader is an optional capability of the user-group-rate
// repository. ListUsers type-asserts for it so rates for a whole page of
// users can be loaded in a single call; the result is keyed by user ID and
// then by a second int64 key (a group ID, judging by User.GroupRates).
type userGroupRateBatchReader interface {
	GetByUserIDs(ctx context.Context, userIDs []int64) (ratesByUser map[int64]map[int64]float64, err error)
}
|
|
| |
| func NewAdminService( |
| userRepo UserRepository, |
| groupRepo GroupRepository, |
| accountRepo AccountRepository, |
| soraAccountRepo SoraAccountRepository, |
| proxyRepo ProxyRepository, |
| apiKeyRepo APIKeyRepository, |
| redeemCodeRepo RedeemCodeRepository, |
| userGroupRateRepo UserGroupRateRepository, |
| billingCacheService *BillingCacheService, |
| proxyProber ProxyExitInfoProber, |
| proxyLatencyCache ProxyLatencyCache, |
| authCacheInvalidator APIKeyAuthCacheInvalidator, |
| entClient *dbent.Client, |
| settingService *SettingService, |
| defaultSubAssigner DefaultSubscriptionAssigner, |
| userSubRepo UserSubscriptionRepository, |
| privacyClientFactory PrivacyClientFactory, |
| ) AdminService { |
| return &adminServiceImpl{ |
| userRepo: userRepo, |
| groupRepo: groupRepo, |
| accountRepo: accountRepo, |
| soraAccountRepo: soraAccountRepo, |
| proxyRepo: proxyRepo, |
| apiKeyRepo: apiKeyRepo, |
| redeemCodeRepo: redeemCodeRepo, |
| userGroupRateRepo: userGroupRateRepo, |
| billingCacheService: billingCacheService, |
| proxyProber: proxyProber, |
| proxyLatencyCache: proxyLatencyCache, |
| authCacheInvalidator: authCacheInvalidator, |
| entClient: entClient, |
| settingService: settingService, |
| defaultSubAssigner: defaultSubAssigner, |
| userSubRepo: userSubRepo, |
| privacyClientFactory: privacyClientFactory, |
| } |
| } |
|
|
| |
| func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) |
| if err != nil { |
| return nil, 0, err |
| } |
| |
| if s.userGroupRateRepo != nil && len(users) > 0 { |
| if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { |
| userIDs := make([]int64, 0, len(users)) |
| for i := range users { |
| userIDs = append(userIDs, users[i].ID) |
| } |
| ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) |
| s.loadUserGroupRatesOneByOne(ctx, users) |
| } else { |
| for i := range users { |
| if rates, ok := ratesByUser[users[i].ID]; ok { |
| users[i].GroupRates = rates |
| } |
| } |
| } |
| } else { |
| s.loadUserGroupRatesOneByOne(ctx, users) |
| } |
| } |
| return users, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { |
| if s.userGroupRateRepo == nil { |
| return |
| } |
| for i := range users { |
| rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) |
| continue |
| } |
| users[i].GroupRates = rates |
| } |
| } |
|
|
| func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { |
| user, err := s.userRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| |
| if s.userGroupRateRepo != nil { |
| rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) |
| } else { |
| user.GroupRates = rates |
| } |
| } |
| return user, nil |
| } |
|
|
| func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { |
| user := &User{ |
| Email: input.Email, |
| Username: input.Username, |
| Notes: input.Notes, |
| Role: RoleUser, |
| Balance: input.Balance, |
| Concurrency: input.Concurrency, |
| Status: StatusActive, |
| AllowedGroups: input.AllowedGroups, |
| SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, |
| } |
| if err := user.SetPassword(input.Password); err != nil { |
| return nil, err |
| } |
| if err := s.userRepo.Create(ctx, user); err != nil { |
| return nil, err |
| } |
| s.assignDefaultSubscriptions(ctx, user.ID) |
| return user, nil |
| } |
|
|
| func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { |
| if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { |
| return |
| } |
| items := s.settingService.GetDefaultSubscriptions(ctx) |
| for _, item := range items { |
| if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ |
| UserID: userID, |
| GroupID: item.GroupID, |
| ValidityDays: item.ValidityDays, |
| Notes: "auto assigned by default user subscriptions setting", |
| }); err != nil { |
| logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) |
| } |
| } |
| } |
|
|
| func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { |
| user, err := s.userRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
|
|
| |
| if user.Role == "admin" && input.Status == "disabled" { |
| return nil, errors.New("cannot disable admin user") |
| } |
|
|
| oldConcurrency := user.Concurrency |
| oldStatus := user.Status |
| oldRole := user.Role |
|
|
| if input.Email != "" { |
| user.Email = input.Email |
| } |
| if input.Password != "" { |
| if err := user.SetPassword(input.Password); err != nil { |
| return nil, err |
| } |
| } |
|
|
| if input.Username != nil { |
| user.Username = *input.Username |
| } |
| if input.Notes != nil { |
| user.Notes = *input.Notes |
| } |
|
|
| if input.Status != "" { |
| user.Status = input.Status |
| } |
|
|
| if input.Concurrency != nil { |
| user.Concurrency = *input.Concurrency |
| } |
|
|
| if input.AllowedGroups != nil { |
| user.AllowedGroups = *input.AllowedGroups |
| } |
|
|
| if input.SoraStorageQuotaBytes != nil { |
| user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes |
| } |
|
|
| if err := s.userRepo.Update(ctx, user); err != nil { |
| return nil, err |
| } |
|
|
| |
| if input.GroupRates != nil && s.userGroupRateRepo != nil { |
| if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { |
| logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) |
| } |
| } |
|
|
| if s.authCacheInvalidator != nil { |
| if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { |
| s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) |
| } |
| } |
|
|
| concurrencyDiff := user.Concurrency - oldConcurrency |
| if concurrencyDiff != 0 { |
| code, err := GenerateRedeemCode() |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) |
| return user, nil |
| } |
| adjustmentRecord := &RedeemCode{ |
| Code: code, |
| Type: AdjustmentTypeAdminConcurrency, |
| Value: float64(concurrencyDiff), |
| Status: StatusUsed, |
| UsedBy: &user.ID, |
| } |
| now := time.Now() |
| adjustmentRecord.UsedAt = &now |
| if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { |
| logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) |
| } |
| } |
|
|
| return user, nil |
| } |
|
|
| func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { |
| |
| user, err := s.userRepo.GetByID(ctx, id) |
| if err != nil { |
| return err |
| } |
| if user.Role == "admin" { |
| return errors.New("cannot delete admin user") |
| } |
| if err := s.userRepo.Delete(ctx, id); err != nil { |
| logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) |
| return err |
| } |
| if s.authCacheInvalidator != nil { |
| s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) |
| } |
| return nil |
| } |
|
|
| func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { |
| user, err := s.userRepo.GetByID(ctx, userID) |
| if err != nil { |
| return nil, err |
| } |
|
|
| oldBalance := user.Balance |
|
|
| switch operation { |
| case "set": |
| user.Balance = balance |
| case "add": |
| user.Balance += balance |
| case "subtract": |
| user.Balance -= balance |
| } |
|
|
| if user.Balance < 0 { |
| return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) |
| } |
|
|
| if err := s.userRepo.Update(ctx, user); err != nil { |
| return nil, err |
| } |
| balanceDiff := user.Balance - oldBalance |
| if s.authCacheInvalidator != nil && balanceDiff != 0 { |
| s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) |
| } |
|
|
| if s.billingCacheService != nil { |
| go func() { |
| cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
| defer cancel() |
| if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { |
| logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err) |
| } |
| }() |
| } |
|
|
| if balanceDiff != 0 { |
| code, err := GenerateRedeemCode() |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) |
| return user, nil |
| } |
|
|
| adjustmentRecord := &RedeemCode{ |
| Code: code, |
| Type: AdjustmentTypeAdminBalance, |
| Value: balanceDiff, |
| Status: StatusUsed, |
| UsedBy: &user.ID, |
| Notes: notes, |
| } |
| now := time.Now() |
| adjustmentRecord.UsedAt = &now |
|
|
| if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { |
| logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err) |
| } |
| } |
|
|
| return user, nil |
| } |
|
|
| func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) |
| if err != nil { |
| return nil, 0, err |
| } |
| return keys, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { |
| |
| return map[string]any{ |
| "period": period, |
| "total_requests": 0, |
| "total_cost": 0.0, |
| "total_tokens": 0, |
| "avg_duration_ms": 0, |
| }, nil |
| } |
|
|
| |
| func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) |
| if err != nil { |
| return nil, 0, 0, err |
| } |
| |
| totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) |
| if err != nil { |
| return nil, 0, 0, err |
| } |
| return codes, result.Total, totalRecharged, nil |
| } |
|
|
| |
| func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) |
| if err != nil { |
| return nil, 0, err |
| } |
| return groups, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { |
| return s.groupRepo.ListActive(ctx) |
| } |
|
|
| func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { |
| return s.groupRepo.ListActiveByPlatform(ctx, platform) |
| } |
|
|
| func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { |
| return s.groupRepo.GetByID(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { |
| platform := input.Platform |
| if platform == "" { |
| platform = PlatformAnthropic |
| } |
|
|
| subscriptionType := input.SubscriptionType |
| if subscriptionType == "" { |
| subscriptionType = SubscriptionTypeStandard |
| } |
|
|
| |
| dailyLimit := normalizeLimit(input.DailyLimitUSD) |
| weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) |
| monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) |
|
|
| |
| imagePrice1K := normalizePrice(input.ImagePrice1K) |
| imagePrice2K := normalizePrice(input.ImagePrice2K) |
| imagePrice4K := normalizePrice(input.ImagePrice4K) |
| soraImagePrice360 := normalizePrice(input.SoraImagePrice360) |
| soraImagePrice540 := normalizePrice(input.SoraImagePrice540) |
| soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) |
| soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) |
|
|
| |
| if input.FallbackGroupID != nil { |
| if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { |
| return nil, err |
| } |
| } |
| fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest |
| if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { |
| fallbackOnInvalidRequest = nil |
| } |
| |
| if fallbackOnInvalidRequest != nil { |
| if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { |
| return nil, err |
| } |
| } |
|
|
| |
| mcpXMLInject := true |
| if input.MCPXMLInject != nil { |
| mcpXMLInject = *input.MCPXMLInject |
| } |
|
|
| |
| var accountIDsToCopy []int64 |
| if len(input.CopyAccountsFromGroupIDs) > 0 { |
| |
| seen := make(map[int64]struct{}) |
| uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) |
| for _, srcGroupID := range input.CopyAccountsFromGroupIDs { |
| if _, exists := seen[srcGroupID]; !exists { |
| seen[srcGroupID] = struct{}{} |
| uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) |
| } |
| } |
|
|
| |
| for _, srcGroupID := range uniqueSourceGroupIDs { |
| srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) |
| if err != nil { |
| return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) |
| } |
| if srcGroup.Platform != platform { |
| return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform) |
| } |
| } |
|
|
| |
| var err error |
| accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) |
| if err != nil { |
| return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) |
| } |
| } |
|
|
| group := &Group{ |
| Name: input.Name, |
| Description: input.Description, |
| Platform: platform, |
| RateMultiplier: input.RateMultiplier, |
| IsExclusive: input.IsExclusive, |
| Status: StatusActive, |
| SubscriptionType: subscriptionType, |
| DailyLimitUSD: dailyLimit, |
| WeeklyLimitUSD: weeklyLimit, |
| MonthlyLimitUSD: monthlyLimit, |
| ImagePrice1K: imagePrice1K, |
| ImagePrice2K: imagePrice2K, |
| ImagePrice4K: imagePrice4K, |
| SoraImagePrice360: soraImagePrice360, |
| SoraImagePrice540: soraImagePrice540, |
| SoraVideoPricePerRequest: soraVideoPrice, |
| SoraVideoPricePerRequestHD: soraVideoPriceHD, |
| ClaudeCodeOnly: input.ClaudeCodeOnly, |
| FallbackGroupID: input.FallbackGroupID, |
| FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, |
| ModelRouting: input.ModelRouting, |
| MCPXMLInject: mcpXMLInject, |
| SupportedModelScopes: input.SupportedModelScopes, |
| SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, |
| AllowMessagesDispatch: input.AllowMessagesDispatch, |
| DefaultMappedModel: input.DefaultMappedModel, |
| } |
| if err := s.groupRepo.Create(ctx, group); err != nil { |
| return nil, err |
| } |
|
|
| |
| if len(accountIDsToCopy) > 0 { |
| if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { |
| return nil, fmt.Errorf("failed to bind accounts to new group: %w", err) |
| } |
| group.AccountCount = int64(len(accountIDsToCopy)) |
| } |
|
|
| return group, nil |
| } |
|
|
| |
// normalizeLimit returns nil when limit is nil or points at a negative
// value; otherwise it returns limit itself (the same pointer, not a copy).
// Zero is kept as a valid limit.
func normalizeLimit(limit *float64) *float64 {
	switch {
	case limit == nil:
		return nil
	case *limit < 0:
		return nil
	default:
		return limit
	}
}
|
|
| |
// normalizePrice returns nil when price is nil or points at a negative
// value; otherwise it returns price itself (the same pointer, not a copy).
// Zero is kept as a valid price.
func normalizePrice(price *float64) *float64 {
	switch {
	case price == nil:
		return nil
	case *price < 0:
		return nil
	default:
		return price
	}
}
|
|
| |
| |
| |
| func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error { |
| |
| if currentGroupID > 0 && currentGroupID == fallbackGroupID { |
| return fmt.Errorf("cannot set self as fallback group") |
| } |
|
|
| visited := map[int64]struct{}{} |
| nextID := fallbackGroupID |
| for { |
| if _, seen := visited[nextID]; seen { |
| return fmt.Errorf("fallback group cycle detected") |
| } |
| visited[nextID] = struct{}{} |
| if currentGroupID > 0 && nextID == currentGroupID { |
| return fmt.Errorf("fallback group cycle detected") |
| } |
|
|
| |
| fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID) |
| if err != nil { |
| return fmt.Errorf("fallback group not found: %w", err) |
| } |
|
|
| |
| if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly { |
| return fmt.Errorf("fallback group cannot have claude_code_only enabled") |
| } |
|
|
| if fallbackGroup.FallbackGroupID == nil { |
| return nil |
| } |
| nextID = *fallbackGroup.FallbackGroupID |
| } |
| } |
|
|
| |
| |
| |
| |
| func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { |
| if platform != PlatformAnthropic && platform != PlatformAntigravity { |
| return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") |
| } |
| if subscriptionType == SubscriptionTypeSubscription { |
| return fmt.Errorf("subscription groups cannot set invalid request fallback") |
| } |
| if currentGroupID > 0 && currentGroupID == fallbackGroupID { |
| return fmt.Errorf("cannot set self as invalid request fallback group") |
| } |
|
|
| fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) |
| if err != nil { |
| return fmt.Errorf("fallback group not found: %w", err) |
| } |
| if fallbackGroup.Platform != PlatformAnthropic { |
| return fmt.Errorf("fallback group must be anthropic platform") |
| } |
| if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { |
| return fmt.Errorf("fallback group cannot be subscription type") |
| } |
| if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { |
| return fmt.Errorf("fallback group cannot have invalid request fallback configured") |
| } |
| return nil |
| } |
|
|
| func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { |
| group, err := s.groupRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
|
|
| if input.Name != "" { |
| group.Name = input.Name |
| } |
| if input.Description != "" { |
| group.Description = input.Description |
| } |
| if input.Platform != "" { |
| group.Platform = input.Platform |
| } |
| if input.RateMultiplier != nil { |
| group.RateMultiplier = *input.RateMultiplier |
| } |
| if input.IsExclusive != nil { |
| group.IsExclusive = *input.IsExclusive |
| } |
| if input.Status != "" { |
| group.Status = input.Status |
| } |
|
|
| |
| if input.SubscriptionType != "" { |
| group.SubscriptionType = input.SubscriptionType |
| } |
| |
| |
| group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) |
| group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) |
| group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) |
| |
| if input.ImagePrice1K != nil { |
| group.ImagePrice1K = normalizePrice(input.ImagePrice1K) |
| } |
| if input.ImagePrice2K != nil { |
| group.ImagePrice2K = normalizePrice(input.ImagePrice2K) |
| } |
| if input.ImagePrice4K != nil { |
| group.ImagePrice4K = normalizePrice(input.ImagePrice4K) |
| } |
| if input.SoraImagePrice360 != nil { |
| group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) |
| } |
| if input.SoraImagePrice540 != nil { |
| group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) |
| } |
| if input.SoraVideoPricePerRequest != nil { |
| group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) |
| } |
| if input.SoraVideoPricePerRequestHD != nil { |
| group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) |
| } |
| if input.SoraStorageQuotaBytes != nil { |
| group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes |
| } |
|
|
| |
| if input.ClaudeCodeOnly != nil { |
| group.ClaudeCodeOnly = *input.ClaudeCodeOnly |
| } |
| if input.FallbackGroupID != nil { |
| |
| if *input.FallbackGroupID > 0 { |
| if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { |
| return nil, err |
| } |
| group.FallbackGroupID = input.FallbackGroupID |
| } else { |
| |
| group.FallbackGroupID = nil |
| } |
| } |
| fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest |
| if input.FallbackGroupIDOnInvalidRequest != nil { |
| if *input.FallbackGroupIDOnInvalidRequest > 0 { |
| fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest |
| } else { |
| fallbackOnInvalidRequest = nil |
| } |
| } |
| if fallbackOnInvalidRequest != nil { |
| if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { |
| return nil, err |
| } |
| } |
| group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest |
|
|
| |
| if input.ModelRouting != nil { |
| group.ModelRouting = input.ModelRouting |
| } |
| if input.ModelRoutingEnabled != nil { |
| group.ModelRoutingEnabled = *input.ModelRoutingEnabled |
| } |
| if input.MCPXMLInject != nil { |
| group.MCPXMLInject = *input.MCPXMLInject |
| } |
|
|
| |
| if input.SupportedModelScopes != nil { |
| group.SupportedModelScopes = *input.SupportedModelScopes |
| } |
|
|
| |
| if input.AllowMessagesDispatch != nil { |
| group.AllowMessagesDispatch = *input.AllowMessagesDispatch |
| } |
| if input.DefaultMappedModel != nil { |
| group.DefaultMappedModel = *input.DefaultMappedModel |
| } |
|
|
| if err := s.groupRepo.Update(ctx, group); err != nil { |
| return nil, err |
| } |
|
|
| |
| if len(input.CopyAccountsFromGroupIDs) > 0 { |
| |
| seen := make(map[int64]struct{}) |
| uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) |
| for _, srcGroupID := range input.CopyAccountsFromGroupIDs { |
| |
| if srcGroupID == id { |
| return nil, fmt.Errorf("cannot copy accounts from self") |
| } |
| |
| if _, exists := seen[srcGroupID]; !exists { |
| seen[srcGroupID] = struct{}{} |
| uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) |
| } |
| } |
|
|
| |
| for _, srcGroupID := range uniqueSourceGroupIDs { |
| srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) |
| if err != nil { |
| return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) |
| } |
| if srcGroup.Platform != group.Platform { |
| return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform) |
| } |
| } |
|
|
| |
| accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) |
| if err != nil { |
| return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) |
| } |
|
|
| |
| if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil { |
| return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) |
| } |
|
|
| |
| if len(accountIDsToCopy) > 0 { |
| if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { |
| return nil, fmt.Errorf("failed to bind accounts to group: %w", err) |
| } |
| } |
| } |
|
|
| if s.authCacheInvalidator != nil { |
| s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) |
| } |
| return group, nil |
| } |
|
|
| func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { |
| var groupKeys []string |
| if s.authCacheInvalidator != nil { |
| keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) |
| if err == nil { |
| groupKeys = keys |
| } |
| } |
|
|
| affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) |
| if err != nil { |
| return err |
| } |
| |
|
|
| |
| if len(affectedUserIDs) > 0 && s.billingCacheService != nil { |
| groupID := id |
| go func() { |
| cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| defer cancel() |
| for _, userID := range affectedUserIDs { |
| if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { |
| logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) |
| } |
| } |
| }() |
| } |
| if s.authCacheInvalidator != nil { |
| for _, key := range groupKeys { |
| s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) |
| } |
| } |
|
|
| return nil |
| } |
|
|
| func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) |
| if err != nil { |
| return nil, 0, err |
| } |
| return keys, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { |
| if s.userGroupRateRepo == nil { |
| return nil, nil |
| } |
| return s.userGroupRateRepo.GetByGroupID(ctx, groupID) |
| } |
|
|
| func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error { |
| if s.userGroupRateRepo == nil { |
| return nil |
| } |
| return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID) |
| } |
|
|
| func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error { |
| if s.userGroupRateRepo == nil { |
| return nil |
| } |
| return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) |
| } |
|
|
| func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { |
| return s.groupRepo.UpdateSortOrders(ctx, updates) |
| } |
|
|
| |
| |
| func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { |
| apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) |
| if err != nil { |
| return nil, err |
| } |
|
|
| if groupID == nil { |
| |
| return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil |
| } |
|
|
| if *groupID < 0 { |
| return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") |
| } |
|
|
| result := &AdminUpdateAPIKeyGroupIDResult{} |
|
|
| if *groupID == 0 { |
| |
| apiKey.GroupID = nil |
| apiKey.Group = nil |
| } else { |
| |
| group, err := s.groupRepo.GetByID(ctx, *groupID) |
| if err != nil { |
| return nil, err |
| } |
| if group.Status != StatusActive { |
| return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") |
| } |
| |
| if group.IsSubscriptionType() { |
| if s.userSubRepo == nil { |
| return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") |
| } |
| if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { |
| if errors.Is(err, ErrSubscriptionNotFound) { |
| return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") |
| } |
| return nil, err |
| } |
| } |
|
|
| gid := *groupID |
| apiKey.GroupID = &gid |
| apiKey.Group = group |
|
|
| |
| if group.IsExclusive && !group.IsSubscriptionType() { |
| opCtx := ctx |
| var tx *dbent.Tx |
| if s.entClient == nil { |
| logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") |
| } else { |
| var txErr error |
| tx, txErr = s.entClient.Tx(ctx) |
| if txErr != nil { |
| return nil, fmt.Errorf("begin transaction: %w", txErr) |
| } |
| defer func() { _ = tx.Rollback() }() |
| opCtx = dbent.NewTxContext(ctx, tx) |
| } |
|
|
| if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { |
| return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) |
| } |
| if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { |
| return nil, fmt.Errorf("update api key: %w", err) |
| } |
| if tx != nil { |
| if err := tx.Commit(); err != nil { |
| return nil, fmt.Errorf("commit transaction: %w", err) |
| } |
| } |
|
|
| result.AutoGrantedGroupAccess = true |
| result.GrantedGroupID = &gid |
| result.GrantedGroupName = group.Name |
|
|
| |
| if s.authCacheInvalidator != nil { |
| s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) |
| } |
|
|
| result.APIKey = apiKey |
| return result, nil |
| } |
| } |
|
|
| |
| if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { |
| return nil, fmt.Errorf("update api key: %w", err) |
| } |
|
|
| |
| if s.authCacheInvalidator != nil { |
| s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) |
| } |
|
|
| result.APIKey = apiKey |
| return result, nil |
| } |
|
|
| |
| func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { |
| if oldGroupID == newGroupID { |
| return nil, infraerrors.BadRequest("SAME_GROUP", "old and new group must be different") |
| } |
|
|
| |
| newGroup, err := s.groupRepo.GetByID(ctx, newGroupID) |
| if err != nil { |
| return nil, err |
| } |
| if newGroup.Status != StatusActive { |
| return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") |
| } |
| if !newGroup.IsExclusive { |
| return nil, infraerrors.BadRequest("GROUP_NOT_EXCLUSIVE", "target group is not exclusive") |
| } |
| if newGroup.IsSubscriptionType() { |
| return nil, infraerrors.BadRequest("GROUP_IS_SUBSCRIPTION", "subscription groups are not supported for replacement") |
| } |
|
|
| |
| if s.entClient == nil { |
| return nil, fmt.Errorf("entClient is nil, cannot perform group replacement") |
| } |
| tx, err := s.entClient.Tx(ctx) |
| if err != nil { |
| return nil, fmt.Errorf("begin transaction: %w", err) |
| } |
| defer func() { _ = tx.Rollback() }() |
| opCtx := dbent.NewTxContext(ctx, tx) |
|
|
| |
| if err := s.userRepo.AddGroupToAllowedGroups(opCtx, userID, newGroupID); err != nil { |
| return nil, fmt.Errorf("add new group to allowed groups: %w", err) |
| } |
|
|
| |
| migrated, err := s.apiKeyRepo.UpdateGroupIDByUserAndGroup(opCtx, userID, oldGroupID, newGroupID) |
| if err != nil { |
| return nil, fmt.Errorf("migrate api keys: %w", err) |
| } |
|
|
| |
| if err := s.userRepo.RemoveGroupFromUserAllowedGroups(opCtx, userID, oldGroupID); err != nil { |
| return nil, fmt.Errorf("remove old group from allowed groups: %w", err) |
| } |
|
|
| if err := tx.Commit(); err != nil { |
| return nil, fmt.Errorf("commit transaction: %w", err) |
| } |
|
|
| |
| if s.authCacheInvalidator != nil { |
| keys, keyErr := s.apiKeyRepo.ListKeysByUserID(ctx, userID) |
| if keyErr == nil { |
| for _, k := range keys { |
| s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, k) |
| } |
| } |
| } |
|
|
| return &ReplaceUserGroupResult{MigratedKeys: migrated}, nil |
| } |
|
|
| |
| func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) |
| if err != nil { |
| return nil, 0, err |
| } |
| now := time.Now() |
| for i := range accounts { |
| syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) |
| } |
| return accounts, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { |
| return s.accountRepo.GetByID(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { |
| if len(ids) == 0 { |
| return []*Account{}, nil |
| } |
|
|
| accounts, err := s.accountRepo.GetByIDs(ctx, ids) |
| if err != nil { |
| return nil, fmt.Errorf("failed to get accounts by IDs: %w", err) |
| } |
|
|
| return accounts, nil |
| } |
|
|
| func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { |
| |
| groupIDs := input.GroupIDs |
| |
| if len(groupIDs) == 0 && !input.SkipDefaultGroupBind { |
| defaultGroupName := input.Platform + "-default" |
| groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) |
| if err == nil { |
| for _, g := range groups { |
| if g.Name == defaultGroupName { |
| groupIDs = []int64{g.ID} |
| break |
| } |
| } |
| } |
| } |
|
|
| |
| if len(groupIDs) > 0 && !input.SkipMixedChannelCheck { |
| if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil { |
| return nil, err |
| } |
| } |
|
|
| |
| if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { |
| baseURL, _ := input.Credentials["base_url"].(string) |
| baseURL = strings.TrimSpace(baseURL) |
| if baseURL == "" { |
| return nil, errors.New("sora apikey 账号必须设置 base_url") |
| } |
| if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { |
| return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") |
| } |
| } |
|
|
| account := &Account{ |
| Name: input.Name, |
| Notes: normalizeAccountNotes(input.Notes), |
| Platform: input.Platform, |
| Type: input.Type, |
| Credentials: input.Credentials, |
| Extra: input.Extra, |
| ProxyID: input.ProxyID, |
| Concurrency: input.Concurrency, |
| Priority: input.Priority, |
| Status: StatusActive, |
| Schedulable: true, |
| } |
| |
| if account.Extra != nil { |
| if err := ValidateQuotaResetConfig(account.Extra); err != nil { |
| return nil, err |
| } |
| ComputeQuotaResetAt(account.Extra) |
| } |
| if input.ExpiresAt != nil && *input.ExpiresAt > 0 { |
| expiresAt := time.Unix(*input.ExpiresAt, 0) |
| account.ExpiresAt = &expiresAt |
| } |
| if input.AutoPauseOnExpired != nil { |
| account.AutoPauseOnExpired = *input.AutoPauseOnExpired |
| } else { |
| account.AutoPauseOnExpired = true |
| } |
| if input.RateMultiplier != nil { |
| if *input.RateMultiplier < 0 { |
| return nil, errors.New("rate_multiplier must be >= 0") |
| } |
| account.RateMultiplier = input.RateMultiplier |
| } |
| if input.LoadFactor != nil && *input.LoadFactor > 0 { |
| if *input.LoadFactor > 10000 { |
| return nil, errors.New("load_factor must be <= 10000") |
| } |
| account.LoadFactor = input.LoadFactor |
| } |
| if err := s.accountRepo.Create(ctx, account); err != nil { |
| return nil, err |
| } |
|
|
| |
| if account.Platform == PlatformSora && s.soraAccountRepo != nil { |
| soraUpdates := map[string]any{ |
| "access_token": account.GetCredential("access_token"), |
| "refresh_token": account.GetCredential("refresh_token"), |
| } |
| if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { |
| |
| logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) |
| } |
| } |
|
|
| |
| if len(groupIDs) > 0 { |
| if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { |
| return nil, err |
| } |
| } |
|
|
| return account, nil |
| } |
|
|
| func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { |
| account, err := s.accountRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| wasOveragesEnabled := account.IsOveragesEnabled() |
|
|
| if input.Name != "" { |
| account.Name = input.Name |
| } |
| if input.Type != "" { |
| account.Type = input.Type |
| } |
| if input.Notes != nil { |
| account.Notes = normalizeAccountNotes(input.Notes) |
| } |
| if len(input.Credentials) > 0 { |
| account.Credentials = input.Credentials |
| } |
| |
| |
| if input.Extra != nil { |
| |
| for _, key := range []string{"quota_used", "quota_daily_used", "quota_daily_start", "quota_weekly_used", "quota_weekly_start"} { |
| if v, ok := account.Extra[key]; ok { |
| input.Extra[key] = v |
| } |
| } |
| account.Extra = input.Extra |
| if account.Platform == PlatformAntigravity && wasOveragesEnabled && !account.IsOveragesEnabled() { |
| delete(account.Extra, "antigravity_credits_overages") |
| |
| if rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any); ok { |
| delete(rawLimits, creditsExhaustedKey) |
| } |
| } |
| if account.Platform == PlatformAntigravity && !wasOveragesEnabled && account.IsOveragesEnabled() { |
| delete(account.Extra, modelRateLimitsKey) |
| delete(account.Extra, "antigravity_credits_overages") |
| } |
| |
| if err := ValidateQuotaResetConfig(account.Extra); err != nil { |
| return nil, err |
| } |
| ComputeQuotaResetAt(account.Extra) |
| } |
| if input.ProxyID != nil { |
| |
| if *input.ProxyID == 0 { |
| account.ProxyID = nil |
| } else { |
| account.ProxyID = input.ProxyID |
| } |
| account.Proxy = nil |
| } |
| |
| if input.Concurrency != nil { |
| account.Concurrency = *input.Concurrency |
| } |
| |
| if input.Priority != nil { |
| account.Priority = *input.Priority |
| } |
| if input.RateMultiplier != nil { |
| if *input.RateMultiplier < 0 { |
| return nil, errors.New("rate_multiplier must be >= 0") |
| } |
| account.RateMultiplier = input.RateMultiplier |
| } |
| if input.LoadFactor != nil { |
| if *input.LoadFactor <= 0 { |
| account.LoadFactor = nil |
| } else if *input.LoadFactor > 10000 { |
| return nil, errors.New("load_factor must be <= 10000") |
| } else { |
| account.LoadFactor = input.LoadFactor |
| } |
| } |
| if input.Status != "" { |
| account.Status = input.Status |
| } |
| if input.ExpiresAt != nil { |
| if *input.ExpiresAt <= 0 { |
| account.ExpiresAt = nil |
| } else { |
| expiresAt := time.Unix(*input.ExpiresAt, 0) |
| account.ExpiresAt = &expiresAt |
| } |
| } |
| if input.AutoPauseOnExpired != nil { |
| account.AutoPauseOnExpired = *input.AutoPauseOnExpired |
| } |
|
|
| |
| if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { |
| baseURL, _ := account.Credentials["base_url"].(string) |
| baseURL = strings.TrimSpace(baseURL) |
| if baseURL == "" { |
| return nil, errors.New("sora apikey 账号必须设置 base_url") |
| } |
| if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { |
| return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") |
| } |
| } |
|
|
| |
| if input.GroupIDs != nil { |
| if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { |
| return nil, err |
| } |
|
|
| |
| if !input.SkipMixedChannelCheck { |
| if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil { |
| return nil, err |
| } |
| } |
| } |
|
|
| if err := s.accountRepo.Update(ctx, account); err != nil { |
| return nil, err |
| } |
|
|
| |
| if input.GroupIDs != nil { |
| if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { |
| return nil, err |
| } |
| } |
|
|
| |
| updated, err := s.accountRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| return updated, nil |
| } |
|
|
| |
| |
| func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { |
| result := &BulkUpdateAccountsResult{ |
| SuccessIDs: make([]int64, 0, len(input.AccountIDs)), |
| FailedIDs: make([]int64, 0, len(input.AccountIDs)), |
| Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), |
| } |
|
|
| if len(input.AccountIDs) == 0 { |
| return result, nil |
| } |
| if input.GroupIDs != nil { |
| if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { |
| return nil, err |
| } |
| } |
|
|
| needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck |
|
|
| |
| platformByID := map[int64]string{} |
| if needMixedChannelCheck { |
| accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) |
| if err != nil { |
| return nil, err |
| } |
| for _, account := range accounts { |
| if account != nil { |
| platformByID[account.ID] = account.Platform |
| } |
| } |
| } |
|
|
| |
| if needMixedChannelCheck { |
| for _, accountID := range input.AccountIDs { |
| platform := platformByID[accountID] |
| if platform == "" { |
| continue |
| } |
| if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { |
| return nil, err |
| } |
| } |
| } |
|
|
| if input.RateMultiplier != nil { |
| if *input.RateMultiplier < 0 { |
| return nil, errors.New("rate_multiplier must be >= 0") |
| } |
| } |
|
|
| |
| repoUpdates := AccountBulkUpdate{ |
| Credentials: input.Credentials, |
| Extra: input.Extra, |
| } |
| if input.Name != "" { |
| repoUpdates.Name = &input.Name |
| } |
| if input.ProxyID != nil { |
| repoUpdates.ProxyID = input.ProxyID |
| } |
| if input.Concurrency != nil { |
| repoUpdates.Concurrency = input.Concurrency |
| } |
| if input.Priority != nil { |
| repoUpdates.Priority = input.Priority |
| } |
| if input.RateMultiplier != nil { |
| repoUpdates.RateMultiplier = input.RateMultiplier |
| } |
| if input.LoadFactor != nil { |
| if *input.LoadFactor <= 0 { |
| repoUpdates.LoadFactor = nil |
| } else if *input.LoadFactor > 10000 { |
| return nil, errors.New("load_factor must be <= 10000") |
| } else { |
| repoUpdates.LoadFactor = input.LoadFactor |
| } |
| } |
| if input.Status != "" { |
| repoUpdates.Status = &input.Status |
| } |
| if input.Schedulable != nil { |
| repoUpdates.Schedulable = input.Schedulable |
| } |
|
|
| |
| if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { |
| return nil, err |
| } |
|
|
| |
| for _, accountID := range input.AccountIDs { |
| entry := BulkUpdateAccountResult{AccountID: accountID} |
|
|
| if input.GroupIDs != nil { |
| if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { |
| entry.Success = false |
| entry.Error = err.Error() |
| result.Failed++ |
| result.FailedIDs = append(result.FailedIDs, accountID) |
| result.Results = append(result.Results, entry) |
| continue |
| } |
| } |
|
|
| entry.Success = true |
| result.Success++ |
| result.SuccessIDs = append(result.SuccessIDs, accountID) |
| result.Results = append(result.Results, entry) |
| } |
|
|
| return result, nil |
| } |
|
|
| func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { |
| if err := s.accountRepo.Delete(ctx, id); err != nil { |
| return err |
| } |
| return nil |
| } |
|
|
| func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { |
| account, err := s.accountRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| |
| return account, nil |
| } |
|
|
| func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { |
| if err := s.accountRepo.ClearError(ctx, id); err != nil { |
| return nil, err |
| } |
| return s.accountRepo.GetByID(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error { |
| return s.accountRepo.SetError(ctx, id, errorMsg) |
| } |
|
|
| func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { |
| if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { |
| return nil, err |
| } |
| updated, err := s.accountRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| return updated, nil |
| } |
|
|
| |
| func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) |
| if err != nil { |
| return nil, 0, err |
| } |
| return proxies, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) |
| if err != nil { |
| return nil, 0, err |
| } |
| s.attachProxyLatency(ctx, proxies) |
| return proxies, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { |
| return s.proxyRepo.ListActive(ctx) |
| } |
|
|
| func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { |
| proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx) |
| if err != nil { |
| return nil, err |
| } |
| s.attachProxyLatency(ctx, proxies) |
| return proxies, nil |
| } |
|
|
| func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { |
| return s.proxyRepo.GetByID(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) GetProxiesByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { |
| return s.proxyRepo.ListByIDs(ctx, ids) |
| } |
|
|
| func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { |
| proxy := &Proxy{ |
| Name: input.Name, |
| Protocol: input.Protocol, |
| Host: input.Host, |
| Port: input.Port, |
| Username: input.Username, |
| Password: input.Password, |
| Status: StatusActive, |
| } |
| if err := s.proxyRepo.Create(ctx, proxy); err != nil { |
| return nil, err |
| } |
| |
| go s.probeProxyLatency(context.Background(), proxy) |
| return proxy, nil |
| } |
|
|
| func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { |
| proxy, err := s.proxyRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
|
|
| if input.Name != "" { |
| proxy.Name = input.Name |
| } |
| if input.Protocol != "" { |
| proxy.Protocol = input.Protocol |
| } |
| if input.Host != "" { |
| proxy.Host = input.Host |
| } |
| if input.Port != 0 { |
| proxy.Port = input.Port |
| } |
| if input.Username != "" { |
| proxy.Username = input.Username |
| } |
| if input.Password != "" { |
| proxy.Password = input.Password |
| } |
| if input.Status != "" { |
| proxy.Status = input.Status |
| } |
|
|
| if err := s.proxyRepo.Update(ctx, proxy); err != nil { |
| return nil, err |
| } |
| return proxy, nil |
| } |
|
|
| func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { |
| count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) |
| if err != nil { |
| return err |
| } |
| if count > 0 { |
| return ErrProxyInUse |
| } |
| return s.proxyRepo.Delete(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) { |
| result := &ProxyBatchDeleteResult{} |
| if len(ids) == 0 { |
| return result, nil |
| } |
|
|
| for _, id := range ids { |
| count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id) |
| if err != nil { |
| result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ |
| ID: id, |
| Reason: err.Error(), |
| }) |
| continue |
| } |
| if count > 0 { |
| result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ |
| ID: id, |
| Reason: ErrProxyInUse.Error(), |
| }) |
| continue |
| } |
| if err := s.proxyRepo.Delete(ctx, id); err != nil { |
| result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{ |
| ID: id, |
| Reason: err.Error(), |
| }) |
| continue |
| } |
| result.DeletedIDs = append(result.DeletedIDs, id) |
| } |
|
|
| return result, nil |
| } |
|
|
| func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { |
| return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID) |
| } |
|
|
| func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { |
| return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) |
| } |
|
|
| |
| func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { |
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) |
| if err != nil { |
| return nil, 0, err |
| } |
| return codes, result.Total, nil |
| } |
|
|
| func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { |
| return s.redeemCodeRepo.GetByID(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { |
| |
| if input.Type == RedeemTypeSubscription { |
| if input.GroupID == nil { |
| return nil, errors.New("group_id is required for subscription type") |
| } |
| |
| group, err := s.groupRepo.GetByID(ctx, *input.GroupID) |
| if err != nil { |
| return nil, fmt.Errorf("group not found: %w", err) |
| } |
| if !group.IsSubscriptionType() { |
| return nil, errors.New("group must be subscription type") |
| } |
| } |
|
|
| codes := make([]RedeemCode, 0, input.Count) |
| for i := 0; i < input.Count; i++ { |
| codeValue, err := GenerateRedeemCode() |
| if err != nil { |
| return nil, err |
| } |
| code := RedeemCode{ |
| Code: codeValue, |
| Type: input.Type, |
| Value: input.Value, |
| Status: StatusUnused, |
| } |
| |
| if input.Type == RedeemTypeSubscription { |
| code.GroupID = input.GroupID |
| code.ValidityDays = input.ValidityDays |
| if code.ValidityDays <= 0 { |
| code.ValidityDays = 30 |
| } |
| } |
| if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { |
| return nil, err |
| } |
| codes = append(codes, code) |
| } |
| return codes, nil |
| } |
|
|
| func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { |
| return s.redeemCodeRepo.Delete(ctx, id) |
| } |
|
|
| func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { |
| var deleted int64 |
| for _, id := range ids { |
| if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { |
| deleted++ |
| } |
| } |
| return deleted, nil |
| } |
|
|
| func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { |
| code, err := s.redeemCodeRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
| code.Status = StatusExpired |
| if err := s.redeemCodeRepo.Update(ctx, code); err != nil { |
| return nil, err |
| } |
| return code, nil |
| } |
|
|
| func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { |
| proxy, err := s.proxyRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
|
|
| proxyURL := proxy.URL() |
| exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) |
| if err != nil { |
| s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ |
| Success: false, |
| Message: err.Error(), |
| UpdatedAt: time.Now(), |
| }) |
| return &ProxyTestResult{ |
| Success: false, |
| Message: err.Error(), |
| }, nil |
| } |
|
|
| latency := latencyMs |
| s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ |
| Success: true, |
| LatencyMs: &latency, |
| Message: "Proxy is accessible", |
| IPAddress: exitInfo.IP, |
| Country: exitInfo.Country, |
| CountryCode: exitInfo.CountryCode, |
| Region: exitInfo.Region, |
| City: exitInfo.City, |
| UpdatedAt: time.Now(), |
| }) |
| return &ProxyTestResult{ |
| Success: true, |
| Message: "Proxy is accessible", |
| LatencyMs: latencyMs, |
| IPAddress: exitInfo.IP, |
| City: exitInfo.City, |
| Region: exitInfo.Region, |
| Country: exitInfo.Country, |
| CountryCode: exitInfo.CountryCode, |
| }, nil |
| } |
|
|
| func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { |
| proxy, err := s.proxyRepo.GetByID(ctx, id) |
| if err != nil { |
| return nil, err |
| } |
|
|
| result := &ProxyQualityCheckResult{ |
| ProxyID: id, |
| Score: 100, |
| Grade: "A", |
| CheckedAt: time.Now().Unix(), |
| Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), |
| } |
|
|
| proxyURL := proxy.URL() |
| if s.proxyProber == nil { |
| result.Items = append(result.Items, ProxyQualityCheckItem{ |
| Target: "base_connectivity", |
| Status: "fail", |
| Message: "代理探测服务未配置", |
| }) |
| result.FailedCount++ |
| finalizeProxyQualityResult(result) |
| s.saveProxyQualitySnapshot(ctx, id, result, nil) |
| return result, nil |
| } |
|
|
| exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) |
| if err != nil { |
| result.Items = append(result.Items, ProxyQualityCheckItem{ |
| Target: "base_connectivity", |
| Status: "fail", |
| LatencyMs: latencyMs, |
| Message: err.Error(), |
| }) |
| result.FailedCount++ |
| finalizeProxyQualityResult(result) |
| s.saveProxyQualitySnapshot(ctx, id, result, nil) |
| return result, nil |
| } |
|
|
| result.ExitIP = exitInfo.IP |
| result.Country = exitInfo.Country |
| result.CountryCode = exitInfo.CountryCode |
| result.BaseLatencyMs = latencyMs |
| result.Items = append(result.Items, ProxyQualityCheckItem{ |
| Target: "base_connectivity", |
| Status: "pass", |
| LatencyMs: latencyMs, |
| Message: "代理出口连通正常", |
| }) |
| result.PassedCount++ |
|
|
| client, err := httpclient.GetClient(httpclient.Options{ |
| ProxyURL: proxyURL, |
| Timeout: proxyQualityRequestTimeout, |
| ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, |
| }) |
| if err != nil { |
| result.Items = append(result.Items, ProxyQualityCheckItem{ |
| Target: "http_client", |
| Status: "fail", |
| Message: fmt.Sprintf("创建检测客户端失败: %v", err), |
| }) |
| result.FailedCount++ |
| finalizeProxyQualityResult(result) |
| s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) |
| return result, nil |
| } |
|
|
| for _, target := range proxyQualityTargets { |
| item := runProxyQualityTarget(ctx, client, target) |
| result.Items = append(result.Items, item) |
| switch item.Status { |
| case "pass": |
| result.PassedCount++ |
| case "warn": |
| result.WarnCount++ |
| case "challenge": |
| result.ChallengeCount++ |
| default: |
| result.FailedCount++ |
| } |
| } |
|
|
| finalizeProxyQualityResult(result) |
| s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) |
| return result, nil |
| } |
|
|
| func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { |
| item := ProxyQualityCheckItem{ |
| Target: target.Target, |
| } |
|
|
| req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) |
| if err != nil { |
| item.Status = "fail" |
| item.Message = fmt.Sprintf("构建请求失败: %v", err) |
| return item |
| } |
| req.Header.Set("Accept", "application/json,text/html,*/*") |
| req.Header.Set("User-Agent", proxyQualityClientUserAgent) |
|
|
| start := time.Now() |
| resp, err := client.Do(req) |
| if err != nil { |
| item.Status = "fail" |
| item.LatencyMs = time.Since(start).Milliseconds() |
| item.Message = fmt.Sprintf("请求失败: %v", err) |
| return item |
| } |
| defer func() { _ = resp.Body.Close() }() |
| item.LatencyMs = time.Since(start).Milliseconds() |
| item.HTTPStatus = resp.StatusCode |
|
|
| body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) |
| if readErr != nil { |
| item.Status = "fail" |
| item.Message = fmt.Sprintf("读取响应失败: %v", readErr) |
| return item |
| } |
| if int64(len(body)) > proxyQualityMaxBodyBytes { |
| body = body[:proxyQualityMaxBodyBytes] |
| } |
|
|
| if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { |
| item.Status = "challenge" |
| item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) |
| item.Message = "Sora 命中 Cloudflare challenge" |
| return item |
| } |
|
|
| if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { |
| if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { |
| item.Status = "pass" |
| item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) |
| } else { |
| item.Status = "warn" |
| item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) |
| } |
| return item |
| } |
|
|
| if resp.StatusCode == http.StatusTooManyRequests { |
| item.Status = "warn" |
| item.Message = "目标返回 429,可能存在频控" |
| return item |
| } |
|
|
| item.Status = "fail" |
| item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) |
| return item |
| } |
|
|
| func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { |
| if result == nil { |
| return |
| } |
| score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 |
| if score < 0 { |
| score = 0 |
| } |
| result.Score = score |
| result.Grade = proxyQualityGrade(score) |
| result.Summary = fmt.Sprintf( |
| "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", |
| result.PassedCount, |
| result.WarnCount, |
| result.FailedCount, |
| result.ChallengeCount, |
| ) |
| } |
|
|
// proxyQualityGrade maps a numeric score to a letter grade: "A" from 90,
// "B" from 75, "C" from 60, "D" from 40, and "F" for anything lower.
func proxyQualityGrade(score int) string {
	// Bands are ordered from the highest floor down; the first match wins.
	bands := [...]struct {
		floor int
		grade string
	}{
		{90, "A"},
		{75, "B"},
		{60, "C"},
		{40, "D"},
	}
	for _, band := range bands {
		if score >= band.floor {
			return band.grade
		}
	}
	return "F"
}
|
|
| func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { |
| if result == nil { |
| return "" |
| } |
| if result.ChallengeCount > 0 { |
| return "challenge" |
| } |
| if result.FailedCount > 0 { |
| return "failed" |
| } |
| if result.WarnCount > 0 { |
| return "warn" |
| } |
| if result.PassedCount > 0 { |
| return "healthy" |
| } |
| return "failed" |
| } |
|
|
| func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { |
| if result == nil { |
| return "" |
| } |
| for _, item := range result.Items { |
| if item.CFRay != "" { |
| return item.CFRay |
| } |
| } |
| return "" |
| } |
|
|
| func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { |
| if result == nil { |
| return false |
| } |
| for _, item := range result.Items { |
| if item.Target == "base_connectivity" { |
| return item.Status == "pass" |
| } |
| } |
| return false |
| } |
|
|
| func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { |
| if result == nil { |
| return |
| } |
| score := result.Score |
| checkedAt := result.CheckedAt |
| info := &ProxyLatencyInfo{ |
| Success: proxyQualityBaseConnectivityPass(result), |
| Message: result.Summary, |
| QualityStatus: proxyQualityOverallStatus(result), |
| QualityScore: &score, |
| QualityGrade: result.Grade, |
| QualitySummary: result.Summary, |
| QualityCheckedAt: &checkedAt, |
| QualityCFRay: proxyQualityFirstCFRay(result), |
| UpdatedAt: time.Now(), |
| } |
| if result.BaseLatencyMs > 0 { |
| latency := result.BaseLatencyMs |
| info.LatencyMs = &latency |
| } |
| if exitInfo != nil { |
| info.IPAddress = exitInfo.IP |
| info.Country = exitInfo.Country |
| info.CountryCode = exitInfo.CountryCode |
| info.Region = exitInfo.Region |
| info.City = exitInfo.City |
| } |
| s.saveProxyLatency(ctx, proxyID, info) |
| } |
|
|
| func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { |
| if s.proxyProber == nil || proxy == nil { |
| return |
| } |
| exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL()) |
| if err != nil { |
| s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ |
| Success: false, |
| Message: err.Error(), |
| UpdatedAt: time.Now(), |
| }) |
| return |
| } |
|
|
| latency := latencyMs |
| s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ |
| Success: true, |
| LatencyMs: &latency, |
| Message: "Proxy is accessible", |
| IPAddress: exitInfo.IP, |
| Country: exitInfo.Country, |
| CountryCode: exitInfo.CountryCode, |
| Region: exitInfo.Region, |
| City: exitInfo.City, |
| UpdatedAt: time.Now(), |
| }) |
| } |
|
|
| |
| |
| func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { |
| |
| currentPlatform := getAccountPlatform(currentAccountPlatform) |
| if currentPlatform == "" { |
| |
| return nil |
| } |
|
|
| |
| for _, groupID := range groupIDs { |
| accounts, err := s.accountRepo.ListByGroup(ctx, groupID) |
| if err != nil { |
| return fmt.Errorf("get accounts in group %d: %w", groupID, err) |
| } |
|
|
| |
| for _, account := range accounts { |
| if currentAccountID > 0 && account.ID == currentAccountID { |
| continue |
| } |
|
|
| otherPlatform := getAccountPlatform(account.Platform) |
| if otherPlatform == "" { |
| continue |
| } |
|
|
| |
| if currentPlatform != otherPlatform { |
| group, _ := s.groupRepo.GetByID(ctx, groupID) |
| groupName := fmt.Sprintf("Group %d", groupID) |
| if group != nil { |
| groupName = group.Name |
| } |
|
|
| return &MixedChannelError{ |
| GroupID: groupID, |
| GroupName: groupName, |
| CurrentPlatform: currentPlatform, |
| OtherPlatform: otherPlatform, |
| } |
| } |
| } |
| } |
|
|
| return nil |
| } |
|
|
| func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { |
| if len(groupIDs) == 0 { |
| return nil |
| } |
| if s.groupRepo == nil { |
| return errors.New("group repository not configured") |
| } |
|
|
| if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { |
| existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) |
| if err != nil { |
| return fmt.Errorf("check groups exists: %w", err) |
| } |
| for _, groupID := range groupIDs { |
| if groupID <= 0 || !existsByID[groupID] { |
| return fmt.Errorf("get group: %w", ErrGroupNotFound) |
| } |
| } |
| return nil |
| } |
|
|
| for _, groupID := range groupIDs { |
| if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { |
| return fmt.Errorf("get group: %w", err) |
| } |
| } |
| return nil |
| } |
|
|
| |
| func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { |
| return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) |
| } |
|
|
| func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { |
| if s.proxyLatencyCache == nil || len(proxies) == 0 { |
| return |
| } |
|
|
| ids := make([]int64, 0, len(proxies)) |
| for i := range proxies { |
| ids = append(ids, proxies[i].ID) |
| } |
|
|
| latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) |
| if err != nil { |
| logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) |
| return |
| } |
|
|
| for i := range proxies { |
| info := latencies[proxies[i].ID] |
| if info == nil { |
| continue |
| } |
| if info.Success { |
| proxies[i].LatencyStatus = "success" |
| proxies[i].LatencyMs = info.LatencyMs |
| } else { |
| proxies[i].LatencyStatus = "failed" |
| } |
| proxies[i].LatencyMessage = info.Message |
| proxies[i].IPAddress = info.IPAddress |
| proxies[i].Country = info.Country |
| proxies[i].CountryCode = info.CountryCode |
| proxies[i].Region = info.Region |
| proxies[i].City = info.City |
| proxies[i].QualityStatus = info.QualityStatus |
| proxies[i].QualityScore = info.QualityScore |
| proxies[i].QualityGrade = info.QualityGrade |
| proxies[i].QualitySummary = info.QualitySummary |
| proxies[i].QualityChecked = info.QualityCheckedAt |
| } |
| } |
|
|
| func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) { |
| if s.proxyLatencyCache == nil || info == nil { |
| return |
| } |
|
|
| merged := *info |
| if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { |
| if existing := latencies[proxyID]; existing != nil { |
| if merged.QualityCheckedAt == nil && |
| merged.QualityScore == nil && |
| merged.QualityGrade == "" && |
| merged.QualityStatus == "" && |
| merged.QualitySummary == "" && |
| merged.QualityCFRay == "" { |
| merged.QualityStatus = existing.QualityStatus |
| merged.QualityScore = existing.QualityScore |
| merged.QualityGrade = existing.QualityGrade |
| merged.QualitySummary = existing.QualitySummary |
| merged.QualityCheckedAt = existing.QualityCheckedAt |
| merged.QualityCFRay = existing.QualityCFRay |
| } |
| } |
| } |
|
|
| if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { |
| logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) |
| } |
| } |
|
|
| |
| func getAccountPlatform(accountPlatform string) string { |
| switch strings.ToLower(strings.TrimSpace(accountPlatform)) { |
| case PlatformAntigravity: |
| return "Antigravity" |
| case PlatformAnthropic, "claude": |
| return "Anthropic" |
| default: |
| return "" |
| } |
| } |
|
|
| |
// MixedChannelError describes a group that contains accounts of two different
// platforms.
type MixedChannelError struct {
	GroupID         int64  // ID of the affected group
	GroupName       string // group name shown in the error message
	CurrentPlatform string // platform of the account being checked
	OtherPlatform   string // differing platform found in the group
}

// Error implements the error interface. The message starts with the
// "mixed_channel_warning:" prefix and names the group and both platforms.
func (e *MixedChannelError) Error() string {
	const format = "mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages."
	return fmt.Sprintf(format, e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}
|
|
| func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error { |
| return s.accountRepo.ResetQuotaUsed(ctx, id) |
| } |
|
|
| |
| |
| func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string { |
| if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { |
| return "" |
| } |
| if s.privacyClientFactory == nil { |
| return "" |
| } |
| if account.Extra != nil { |
| if _, ok := account.Extra["privacy_mode"]; ok { |
| return "" |
| } |
| } |
|
|
| token, _ := account.Credentials["access_token"].(string) |
| if token == "" { |
| return "" |
| } |
|
|
| var proxyURL string |
| if account.ProxyID != nil { |
| if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil { |
| proxyURL = p.URL() |
| } |
| } |
|
|
| mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL) |
| if mode == "" { |
| return "" |
| } |
|
|
| _ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}) |
| return mode |
| } |
|
|