| package jetbrains |
|
|
| import ( |
| "context" |
| "fmt" |
| "github.com/go-resty/resty/v2" |
| "jetbrains-ai-proxy/internal/balancer" |
| "jetbrains-ai-proxy/internal/config" |
| "jetbrains-ai-proxy/internal/types" |
| "jetbrains-ai-proxy/internal/utils" |
| "log" |
| "sync" |
| ) |
|
|
| var ( |
| jwtBalancer balancer.JWTBalancer |
| healthChecker *balancer.HealthChecker |
| initOnce sync.Once |
| configManager *config.Manager |
| ) |
|
|
| |
| func InitializeFromConfig() error { |
| var initErr error |
|
|
| initOnce.Do(func() { |
| configManager = config.GetGlobalConfig() |
|
|
| |
| if err := configManager.LoadConfig(); err != nil { |
| initErr = fmt.Errorf("failed to load config: %v", err) |
| return |
| } |
|
|
| |
| cfg := configManager.GetConfig() |
| tokens := configManager.GetJWTTokens() |
|
|
| if len(tokens) == 0 { |
| initErr = fmt.Errorf("no JWT tokens configured") |
| return |
| } |
|
|
| |
| jwtBalancer = balancer.NewJWTBalancer(tokens, cfg.LoadBalanceStrategy) |
|
|
| |
| healthChecker = balancer.NewHealthChecker(jwtBalancer) |
| if cfg.HealthCheckInterval > 0 { |
| healthChecker.SetCheckInterval(cfg.HealthCheckInterval) |
| } |
| healthChecker.Start() |
|
|
| log.Printf("JWT balancer initialized from config:") |
| log.Printf(" - Tokens: %d", len(tokens)) |
| log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy) |
| log.Printf(" - Health check interval: %v", cfg.HealthCheckInterval) |
| }) |
|
|
| return initErr |
| } |
|
|
| |
| func InitializeBalancer(tokens []string, strategy string) error { |
| if len(tokens) == 0 { |
| return fmt.Errorf("no JWT tokens provided") |
| } |
|
|
| var balanceStrategy config.LoadBalanceStrategy |
| switch strategy { |
| case "random": |
| balanceStrategy = config.Random |
| case "round_robin", "": |
| balanceStrategy = config.RoundRobin |
| default: |
| balanceStrategy = config.RoundRobin |
| } |
|
|
| |
| jwtBalancer = balancer.NewJWTBalancer(tokens, balanceStrategy) |
|
|
| |
| healthChecker = balancer.NewHealthChecker(jwtBalancer) |
| healthChecker.Start() |
|
|
| log.Printf("JWT balancer initialized with %d tokens, strategy: %s", len(tokens), string(balanceStrategy)) |
| return nil |
| } |
|
|
| |
| func ReloadConfig() error { |
| if configManager == nil { |
| return fmt.Errorf("config manager not initialized") |
| } |
|
|
| |
| if err := configManager.LoadConfig(); err != nil { |
| return fmt.Errorf("failed to reload config: %v", err) |
| } |
|
|
| |
| cfg := configManager.GetConfig() |
| tokens := configManager.GetJWTTokens() |
|
|
| if len(tokens) == 0 { |
| return fmt.Errorf("no JWT tokens in reloaded config") |
| } |
|
|
| |
| if jwtBalancer != nil { |
| jwtBalancer.RefreshTokens(tokens) |
| } |
|
|
| |
| if healthChecker != nil && cfg.HealthCheckInterval > 0 { |
| healthChecker.SetCheckInterval(cfg.HealthCheckInterval) |
| } |
|
|
| log.Printf("Config reloaded successfully:") |
| log.Printf(" - Tokens: %d", len(tokens)) |
| log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy) |
|
|
| return nil |
| } |
|
|
| |
| func StopBalancer() { |
| if healthChecker != nil { |
| healthChecker.Stop() |
| } |
| } |
|
|
| |
| func GetConfigManager() *config.Manager { |
| return configManager |
| } |
|
|
| func SendJetbrainsRequest(ctx context.Context, req *types.JetbrainsRequest) (*resty.Response, error) { |
| |
| token, err := jwtBalancer.GetToken() |
| if err != nil { |
| log.Printf("failed to get JWT token: %v", err) |
| return nil, fmt.Errorf("no available JWT tokens: %v", err) |
| } |
|
|
| resp, err := utils.RestySSEClient.R(). |
| SetContext(ctx). |
| SetHeader(types.JwtTokenKey, token). |
| SetDoNotParseResponse(true). |
| SetBody(req). |
| Post(types.ChatStreamV7) |
|
|
| if err != nil { |
| log.Printf("jetbrains ai req error: %v", err) |
| |
| jwtBalancer.MarkTokenUnhealthy(token) |
| return nil, err |
| } |
|
|
| |
| if resp.StatusCode() == 401 { |
| |
| jwtBalancer.MarkTokenUnhealthy(token) |
| log.Printf("JWT token invalid (401): %s...", token[:min(len(token), 10)]) |
| return nil, fmt.Errorf("JWT token invalid") |
| } else if resp.StatusCode() == 200 { |
| |
| jwtBalancer.MarkTokenHealthy(token) |
| } |
|
|
| return resp, nil |
| } |
|
|
| |
| func GetBalancerStats() (int, int) { |
| if jwtBalancer == nil { |
| return 0, 0 |
| } |
| return jwtBalancer.GetHealthyTokenCount(), jwtBalancer.GetTotalTokenCount() |
| } |
|
|
| |
// min returns the smaller of two ints. It is used to bound the token prefix
// that is logged on a 401.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|