|
|
|
|
|
|
|
|
|
|
|
|
|
|
package api |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"crypto/subtle" |
|
|
"errors" |
|
|
"fmt" |
|
|
"net/http" |
|
|
"os" |
|
|
"path/filepath" |
|
|
"strings" |
|
|
"sync" |
|
|
"sync/atomic" |
|
|
"time" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/access" |
|
|
managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" |
|
|
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" |
|
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" |
|
|
log "github.com/sirupsen/logrus" |
|
|
"gopkg.in/yaml.v3" |
|
|
) |
|
|
|
|
|
// oauthCallbackSuccessHTML is the HTML page returned by the OAuth callback
// routes. It tells the user authentication succeeded and closes the window
// through a 5000 ms script timer.
const oauthCallbackSuccessHTML = `<html><head><meta charset="utf-8"><title>Authentication successful</title>` +
	`<script>setTimeout(function(){window.close();},5000);</script></head>` +
	`<body><h1>Authentication successful!</h1><p>You can close this window.</p>` +
	`<p>This window will close automatically in 5 seconds.</p></body></html>`
|
|
|
|
|
type serverOptionConfig struct { |
|
|
extraMiddleware []gin.HandlerFunc |
|
|
engineConfigurator func(*gin.Engine) |
|
|
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config) |
|
|
requestLoggerFactory func(*config.Config, string) logging.RequestLogger |
|
|
localPassword string |
|
|
keepAliveEnabled bool |
|
|
keepAliveTimeout time.Duration |
|
|
keepAliveOnTimeout func() |
|
|
} |
|
|
|
|
|
|
|
|
type ServerOption func(*serverOptionConfig) |
|
|
|
|
|
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger { |
|
|
configDir := filepath.Dir(configPath) |
|
|
if base := util.WritablePath(); base != "" { |
|
|
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir) |
|
|
} |
|
|
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir) |
|
|
} |
|
|
|
|
|
|
|
|
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
cfg.extraMiddleware = append(cfg.extraMiddleware, mw...) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
cfg.engineConfigurator = fn |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
cfg.routerConfigurator = fn |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithLocalManagementPassword(password string) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
cfg.localPassword = password |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
if timeout <= 0 || onTimeout == nil { |
|
|
return |
|
|
} |
|
|
cfg.keepAliveEnabled = true |
|
|
cfg.keepAliveTimeout = timeout |
|
|
cfg.keepAliveOnTimeout = onTimeout |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption { |
|
|
return func(cfg *serverOptionConfig) { |
|
|
cfg.requestLoggerFactory = factory |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
type Server struct { |
|
|
|
|
|
engine *gin.Engine |
|
|
|
|
|
|
|
|
server *http.Server |
|
|
|
|
|
|
|
|
handlers *handlers.BaseAPIHandler |
|
|
|
|
|
|
|
|
cfg *config.Config |
|
|
|
|
|
|
|
|
|
|
|
oldConfigYaml []byte |
|
|
|
|
|
|
|
|
accessManager *sdkaccess.Manager |
|
|
|
|
|
|
|
|
requestLogger logging.RequestLogger |
|
|
loggerToggle func(bool) |
|
|
|
|
|
|
|
|
configFilePath string |
|
|
|
|
|
|
|
|
currentPath string |
|
|
|
|
|
|
|
|
wsRouteMu sync.Mutex |
|
|
wsRoutes map[string]struct{} |
|
|
wsAuthChanged func(bool, bool) |
|
|
wsAuthEnabled atomic.Bool |
|
|
|
|
|
|
|
|
mgmt *managementHandlers.Handler |
|
|
|
|
|
|
|
|
ampModule *ampmodule.AmpModule |
|
|
|
|
|
|
|
|
managementRoutesRegistered atomic.Bool |
|
|
|
|
|
managementRoutesEnabled atomic.Bool |
|
|
|
|
|
|
|
|
envManagementSecret bool |
|
|
|
|
|
localPassword string |
|
|
|
|
|
keepAliveEnabled bool |
|
|
keepAliveTimeout time.Duration |
|
|
keepAliveOnTimeout func() |
|
|
keepAliveHeartbeat chan struct{} |
|
|
keepAliveStop chan struct{} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server { |
|
|
optionState := &serverOptionConfig{ |
|
|
requestLoggerFactory: defaultRequestLoggerFactory, |
|
|
} |
|
|
for i := range opts { |
|
|
opts[i](optionState) |
|
|
} |
|
|
|
|
|
if !cfg.Debug { |
|
|
gin.SetMode(gin.ReleaseMode) |
|
|
} |
|
|
|
|
|
|
|
|
engine := gin.New() |
|
|
if optionState.engineConfigurator != nil { |
|
|
optionState.engineConfigurator(engine) |
|
|
} |
|
|
|
|
|
|
|
|
engine.Use(logging.GinLogrusLogger()) |
|
|
engine.Use(logging.GinLogrusRecovery()) |
|
|
for _, mw := range optionState.extraMiddleware { |
|
|
engine.Use(mw) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var requestLogger logging.RequestLogger |
|
|
var toggle func(bool) |
|
|
if !cfg.CommercialMode { |
|
|
if optionState.requestLoggerFactory != nil { |
|
|
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath) |
|
|
} |
|
|
if requestLogger != nil { |
|
|
engine.Use(middleware.RequestLoggingMiddleware(requestLogger)) |
|
|
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok { |
|
|
toggle = setter.SetEnabled |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
engine.Use(corsMiddleware()) |
|
|
wd, err := os.Getwd() |
|
|
if err != nil { |
|
|
wd = configFilePath |
|
|
} |
|
|
|
|
|
envAdminPassword, envAdminPasswordSet := os.LookupEnv("MANAGEMENT_PASSWORD") |
|
|
envAdminPassword = strings.TrimSpace(envAdminPassword) |
|
|
envManagementSecret := envAdminPasswordSet && envAdminPassword != "" |
|
|
|
|
|
|
|
|
s := &Server{ |
|
|
engine: engine, |
|
|
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager), |
|
|
cfg: cfg, |
|
|
accessManager: accessManager, |
|
|
requestLogger: requestLogger, |
|
|
loggerToggle: toggle, |
|
|
configFilePath: configFilePath, |
|
|
currentPath: wd, |
|
|
envManagementSecret: envManagementSecret, |
|
|
wsRoutes: make(map[string]struct{}), |
|
|
} |
|
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth) |
|
|
|
|
|
s.oldConfigYaml, _ = yaml.Marshal(cfg) |
|
|
s.applyAccessConfig(nil, cfg) |
|
|
if authManager != nil { |
|
|
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second) |
|
|
} |
|
|
managementasset.SetCurrentConfig(cfg) |
|
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling) |
|
|
|
|
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager) |
|
|
if optionState.localPassword != "" { |
|
|
s.mgmt.SetLocalPassword(optionState.localPassword) |
|
|
} |
|
|
logDir := filepath.Join(s.currentPath, "logs") |
|
|
if base := util.WritablePath(); base != "" { |
|
|
logDir = filepath.Join(base, "logs") |
|
|
} |
|
|
s.mgmt.SetLogDirectory(logDir) |
|
|
s.localPassword = optionState.localPassword |
|
|
|
|
|
|
|
|
s.setupRoutes() |
|
|
|
|
|
|
|
|
s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager)) |
|
|
ctx := modules.Context{ |
|
|
Engine: engine, |
|
|
BaseHandler: s.handlers, |
|
|
Config: cfg, |
|
|
AuthMiddleware: AuthMiddleware(accessManager), |
|
|
} |
|
|
if err := modules.RegisterModule(ctx, s.ampModule); err != nil { |
|
|
log.Errorf("Failed to register Amp module: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
if optionState.routerConfigurator != nil { |
|
|
optionState.routerConfigurator(engine, s.handlers, cfg) |
|
|
} |
|
|
|
|
|
|
|
|
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret |
|
|
s.managementRoutesEnabled.Store(hasManagementSecret) |
|
|
if hasManagementSecret { |
|
|
s.registerManagementRoutes() |
|
|
} |
|
|
|
|
|
if optionState.keepAliveEnabled { |
|
|
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout) |
|
|
} |
|
|
|
|
|
|
|
|
s.server = &http.Server{ |
|
|
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), |
|
|
Handler: engine, |
|
|
} |
|
|
|
|
|
return s |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) setupRoutes() { |
|
|
s.engine.GET("/management.html", s.serveManagementControlPanel) |
|
|
s.engine.GET("/status.html", s.serveStatusPage) |
|
|
|
|
|
|
|
|
staticPath := os.Getenv("MANAGEMENT_STATIC_PATH") |
|
|
if staticPath == "" { |
|
|
staticPath = filepath.Join(s.currentPath, "static") |
|
|
} |
|
|
s.engine.Static("/static", staticPath) |
|
|
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) |
|
|
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers) |
|
|
geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers) |
|
|
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers) |
|
|
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers) |
|
|
|
|
|
|
|
|
v1 := s.engine.Group("/v1") |
|
|
v1.Use(AuthMiddleware(s.accessManager)) |
|
|
{ |
|
|
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) |
|
|
v1.POST("/chat/completions", openaiHandlers.ChatCompletions) |
|
|
v1.POST("/completions", openaiHandlers.Completions) |
|
|
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) |
|
|
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) |
|
|
v1.POST("/responses", openaiResponsesHandlers.Responses) |
|
|
} |
|
|
|
|
|
|
|
|
v1beta := s.engine.Group("/v1beta") |
|
|
v1beta.Use(AuthMiddleware(s.accessManager)) |
|
|
{ |
|
|
v1beta.GET("/models", geminiHandlers.GeminiModels) |
|
|
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler) |
|
|
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler) |
|
|
} |
|
|
|
|
|
|
|
|
s.engine.GET("/", func(c *gin.Context) { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"message": "CLI Proxy API Server", |
|
|
"endpoints": []string{ |
|
|
"POST /v1/chat/completions", |
|
|
"POST /v1/completions", |
|
|
"GET /v1/models", |
|
|
}, |
|
|
}) |
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) { |
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"}) |
|
|
}) |
|
|
|
|
|
s.engine.POST("/api/provider/:provider/api/event_logging/batch", func(c *gin.Context) { |
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"}) |
|
|
}) |
|
|
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s.engine.GET("/anthropic/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
s.engine.GET("/codex/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
s.engine.GET("/google/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
s.engine.GET("/iflow/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
s.engine.GET("/antigravity/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
s.engine.GET("/kiro/callback", func(c *gin.Context) { |
|
|
code := c.Query("code") |
|
|
state := c.Query("state") |
|
|
errStr := c.Query("error") |
|
|
if errStr == "" { |
|
|
errStr = c.Query("error_description") |
|
|
} |
|
|
if state != "" { |
|
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr) |
|
|
} |
|
|
c.Header("Content-Type", "text/html; charset=utf-8") |
|
|
c.String(http.StatusOK, oauthCallbackSuccessHTML) |
|
|
}) |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) { |
|
|
if s == nil || s.engine == nil || handler == nil { |
|
|
return |
|
|
} |
|
|
trimmed := strings.TrimSpace(path) |
|
|
if trimmed == "" { |
|
|
trimmed = "/v1/ws" |
|
|
} |
|
|
if !strings.HasPrefix(trimmed, "/") { |
|
|
trimmed = "/" + trimmed |
|
|
} |
|
|
s.wsRouteMu.Lock() |
|
|
if _, exists := s.wsRoutes[trimmed]; exists { |
|
|
s.wsRouteMu.Unlock() |
|
|
return |
|
|
} |
|
|
s.wsRoutes[trimmed] = struct{}{} |
|
|
s.wsRouteMu.Unlock() |
|
|
|
|
|
authMiddleware := AuthMiddleware(s.accessManager) |
|
|
conditionalAuth := func(c *gin.Context) { |
|
|
if !s.wsAuthEnabled.Load() { |
|
|
c.Next() |
|
|
return |
|
|
} |
|
|
authMiddleware(c) |
|
|
} |
|
|
finalHandler := func(c *gin.Context) { |
|
|
handler.ServeHTTP(c.Writer, c.Request) |
|
|
c.Abort() |
|
|
} |
|
|
|
|
|
s.engine.GET(trimmed, conditionalAuth, finalHandler) |
|
|
} |
|
|
|
|
|
func (s *Server) registerManagementRoutes() { |
|
|
if s == nil || s.engine == nil || s.mgmt == nil { |
|
|
return |
|
|
} |
|
|
if !s.managementRoutesRegistered.CompareAndSwap(false, true) { |
|
|
return |
|
|
} |
|
|
|
|
|
log.Info("management routes registered after secret key configuration") |
|
|
|
|
|
mgmt := s.engine.Group("/v0/management") |
|
|
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware()) |
|
|
{ |
|
|
mgmt.GET("/usage", s.mgmt.GetUsageStatistics) |
|
|
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics) |
|
|
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics) |
|
|
mgmt.GET("/config", s.mgmt.GetConfig) |
|
|
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML) |
|
|
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML) |
|
|
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion) |
|
|
|
|
|
mgmt.GET("/debug", s.mgmt.GetDebug) |
|
|
mgmt.PUT("/debug", s.mgmt.PutDebug) |
|
|
mgmt.PATCH("/debug", s.mgmt.PutDebug) |
|
|
|
|
|
mgmt.GET("/logging-to-file", s.mgmt.GetLoggingToFile) |
|
|
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile) |
|
|
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile) |
|
|
|
|
|
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled) |
|
|
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled) |
|
|
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled) |
|
|
|
|
|
mgmt.GET("/proxy-url", s.mgmt.GetProxyURL) |
|
|
mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL) |
|
|
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL) |
|
|
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL) |
|
|
|
|
|
mgmt.POST("/api-call", s.mgmt.APICall) |
|
|
|
|
|
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject) |
|
|
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) |
|
|
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject) |
|
|
|
|
|
mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel) |
|
|
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) |
|
|
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel) |
|
|
|
|
|
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys) |
|
|
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys) |
|
|
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys) |
|
|
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys) |
|
|
|
|
|
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys) |
|
|
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys) |
|
|
mgmt.PATCH("/gemini-api-key", s.mgmt.PatchGeminiKey) |
|
|
mgmt.DELETE("/gemini-api-key", s.mgmt.DeleteGeminiKey) |
|
|
|
|
|
mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMb) |
|
|
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMb) |
|
|
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMb) |
|
|
|
|
|
mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy) |
|
|
mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy) |
|
|
mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy) |
|
|
|
|
|
mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix) |
|
|
mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix) |
|
|
mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix) |
|
|
|
|
|
mgmt.GET("/logs", s.mgmt.GetLogs) |
|
|
mgmt.DELETE("/logs", s.mgmt.DeleteLogs) |
|
|
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs) |
|
|
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog) |
|
|
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID) |
|
|
mgmt.GET("/request-log", s.mgmt.GetRequestLog) |
|
|
mgmt.PUT("/request-log", s.mgmt.PutRequestLog) |
|
|
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog) |
|
|
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth) |
|
|
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) |
|
|
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) |
|
|
|
|
|
mgmt.GET("/ampcode", s.mgmt.GetAmpCode) |
|
|
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL) |
|
|
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) |
|
|
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL) |
|
|
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL) |
|
|
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey) |
|
|
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) |
|
|
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey) |
|
|
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey) |
|
|
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost) |
|
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) |
|
|
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost) |
|
|
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings) |
|
|
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings) |
|
|
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings) |
|
|
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings) |
|
|
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings) |
|
|
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) |
|
|
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings) |
|
|
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys) |
|
|
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys) |
|
|
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys) |
|
|
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys) |
|
|
|
|
|
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) |
|
|
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) |
|
|
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) |
|
|
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval) |
|
|
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval) |
|
|
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval) |
|
|
|
|
|
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys) |
|
|
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys) |
|
|
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey) |
|
|
mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey) |
|
|
|
|
|
mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys) |
|
|
mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys) |
|
|
mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey) |
|
|
mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey) |
|
|
|
|
|
mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat) |
|
|
mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat) |
|
|
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat) |
|
|
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat) |
|
|
|
|
|
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels) |
|
|
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels) |
|
|
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels) |
|
|
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels) |
|
|
|
|
|
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) |
|
|
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels) |
|
|
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) |
|
|
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) |
|
|
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile) |
|
|
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential) |
|
|
|
|
|
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) |
|
|
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken) |
|
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) |
|
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) |
|
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) |
|
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) |
|
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) |
|
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) |
|
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) |
|
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
if !s.managementRoutesEnabled.Load() { |
|
|
c.AbortWithStatus(http.StatusNotFound) |
|
|
return |
|
|
} |
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *Server) serveManagementControlPanel(c *gin.Context) { |
|
|
cfg := s.cfg |
|
|
if cfg == nil || cfg.RemoteManagement.DisableControlPanel { |
|
|
c.AbortWithStatus(http.StatusNotFound) |
|
|
return |
|
|
} |
|
|
filePath := managementasset.FilePath(s.configFilePath) |
|
|
if strings.TrimSpace(filePath) == "" { |
|
|
c.AbortWithStatus(http.StatusNotFound) |
|
|
return |
|
|
} |
|
|
if _, err := os.Stat(filePath); err != nil { |
|
|
if os.IsNotExist(err) { |
|
|
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) |
|
|
c.AbortWithStatus(http.StatusNotFound) |
|
|
return |
|
|
} |
|
|
log.WithError(err).Error("failed to stat management control panel asset") |
|
|
c.AbortWithStatus(http.StatusInternalServerError) |
|
|
return |
|
|
} |
|
|
|
|
|
c.File(filePath) |
|
|
} |
|
|
|
|
|
func (s *Server) serveStatusPage(c *gin.Context) { |
|
|
staticPath := os.Getenv("MANAGEMENT_STATIC_PATH") |
|
|
if staticPath == "" { |
|
|
staticPath = filepath.Join(s.currentPath, "static") |
|
|
} |
|
|
filePath := filepath.Join(staticPath, "status.html") |
|
|
if _, err := os.Stat(filePath); err != nil { |
|
|
c.AbortWithStatus(http.StatusNotFound) |
|
|
return |
|
|
} |
|
|
c.File(filePath) |
|
|
} |
|
|
|
|
|
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) { |
|
|
if timeout <= 0 || onTimeout == nil { |
|
|
return |
|
|
} |
|
|
|
|
|
s.keepAliveEnabled = true |
|
|
s.keepAliveTimeout = timeout |
|
|
s.keepAliveOnTimeout = onTimeout |
|
|
s.keepAliveHeartbeat = make(chan struct{}, 1) |
|
|
s.keepAliveStop = make(chan struct{}, 1) |
|
|
|
|
|
s.engine.GET("/keep-alive", s.handleKeepAlive) |
|
|
|
|
|
go s.watchKeepAlive() |
|
|
} |
|
|
|
|
|
func (s *Server) handleKeepAlive(c *gin.Context) { |
|
|
if s.localPassword != "" { |
|
|
provided := strings.TrimSpace(c.GetHeader("Authorization")) |
|
|
if provided != "" { |
|
|
parts := strings.SplitN(provided, " ", 2) |
|
|
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") { |
|
|
provided = parts[1] |
|
|
} |
|
|
} |
|
|
if provided == "" { |
|
|
provided = strings.TrimSpace(c.GetHeader("X-Local-Password")) |
|
|
} |
|
|
if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 { |
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"}) |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
s.signalKeepAlive() |
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"}) |
|
|
} |
|
|
|
|
|
func (s *Server) signalKeepAlive() { |
|
|
if !s.keepAliveEnabled { |
|
|
return |
|
|
} |
|
|
select { |
|
|
case s.keepAliveHeartbeat <- struct{}{}: |
|
|
default: |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *Server) watchKeepAlive() { |
|
|
if !s.keepAliveEnabled { |
|
|
return |
|
|
} |
|
|
|
|
|
timer := time.NewTimer(s.keepAliveTimeout) |
|
|
defer timer.Stop() |
|
|
|
|
|
for { |
|
|
select { |
|
|
case <-timer.C: |
|
|
log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout) |
|
|
if s.keepAliveOnTimeout != nil { |
|
|
s.keepAliveOnTimeout() |
|
|
} |
|
|
return |
|
|
case <-s.keepAliveHeartbeat: |
|
|
if !timer.Stop() { |
|
|
select { |
|
|
case <-timer.C: |
|
|
default: |
|
|
} |
|
|
} |
|
|
timer.Reset(s.keepAliveTimeout) |
|
|
case <-s.keepAliveStop: |
|
|
return |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
userAgent := c.GetHeader("User-Agent") |
|
|
|
|
|
|
|
|
if strings.HasPrefix(userAgent, "claude-cli") { |
|
|
|
|
|
claudeHandler.ClaudeModels(c) |
|
|
} else { |
|
|
|
|
|
openaiHandler.OpenAIModels(c) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) Start() error { |
|
|
if s == nil || s.server == nil { |
|
|
return fmt.Errorf("failed to start HTTP server: server not initialized") |
|
|
} |
|
|
|
|
|
useTLS := s.cfg != nil && s.cfg.TLS.Enable |
|
|
if useTLS { |
|
|
cert := strings.TrimSpace(s.cfg.TLS.Cert) |
|
|
key := strings.TrimSpace(s.cfg.TLS.Key) |
|
|
if cert == "" || key == "" { |
|
|
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty") |
|
|
} |
|
|
log.Debugf("Starting API server on %s with TLS", s.server.Addr) |
|
|
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) { |
|
|
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS) |
|
|
} |
|
|
return nil |
|
|
} |
|
|
|
|
|
log.Debugf("Starting API server on %s", s.server.Addr) |
|
|
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) { |
|
|
return fmt.Errorf("failed to start HTTP server: %v", errServe) |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) Stop(ctx context.Context) error { |
|
|
log.Debug("Stopping API server...") |
|
|
|
|
|
if s.keepAliveEnabled { |
|
|
select { |
|
|
case s.keepAliveStop <- struct{}{}: |
|
|
default: |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if err := s.server.Shutdown(ctx); err != nil { |
|
|
return fmt.Errorf("failed to shutdown HTTP server: %v", err) |
|
|
} |
|
|
|
|
|
log.Debug("API server stopped") |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func corsMiddleware() gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
c.Header("Access-Control-Allow-Origin", "*") |
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") |
|
|
c.Header("Access-Control-Allow-Headers", "*") |
|
|
|
|
|
if c.Request.Method == "OPTIONS" { |
|
|
c.AbortWithStatus(http.StatusNoContent) |
|
|
return |
|
|
} |
|
|
|
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) { |
|
|
if s == nil || s.accessManager == nil || newCfg == nil { |
|
|
return |
|
|
} |
|
|
if _, err := access.ApplyAccessProviders(s.accessManager, oldCfg, newCfg); err != nil { |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (s *Server) UpdateClients(cfg *config.Config) { |
|
|
|
|
|
var oldCfg *config.Config |
|
|
if len(s.oldConfigYaml) > 0 { |
|
|
_ = yaml.Unmarshal(s.oldConfigYaml, &oldCfg) |
|
|
} |
|
|
|
|
|
|
|
|
previousRequestLog := false |
|
|
if oldCfg != nil { |
|
|
previousRequestLog = oldCfg.RequestLog |
|
|
} |
|
|
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) { |
|
|
if s.loggerToggle != nil { |
|
|
s.loggerToggle(cfg.RequestLog) |
|
|
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok { |
|
|
toggler.SetEnabled(cfg.RequestLog) |
|
|
} |
|
|
if oldCfg != nil { |
|
|
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog) |
|
|
} else { |
|
|
log.Debugf("request logging toggled to %t", cfg.RequestLog) |
|
|
} |
|
|
} |
|
|
|
|
|
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { |
|
|
if err := logging.ConfigureLogOutput(cfg); err != nil { |
|
|
log.Errorf("failed to reconfigure log output: %v", err) |
|
|
} else { |
|
|
if oldCfg == nil { |
|
|
log.Debug("log output configuration refreshed") |
|
|
} else { |
|
|
if oldCfg.LoggingToFile != cfg.LoggingToFile { |
|
|
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile) |
|
|
} |
|
|
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { |
|
|
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled { |
|
|
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) |
|
|
if oldCfg != nil { |
|
|
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled) |
|
|
} else { |
|
|
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled) |
|
|
} |
|
|
} |
|
|
|
|
|
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling { |
|
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling) |
|
|
if oldCfg != nil { |
|
|
log.Debugf("disable_cooling updated from %t to %t", oldCfg.DisableCooling, cfg.DisableCooling) |
|
|
} else { |
|
|
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling) |
|
|
} |
|
|
} |
|
|
if s.handlers != nil && s.handlers.AuthManager != nil { |
|
|
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second) |
|
|
} |
|
|
|
|
|
|
|
|
if oldCfg == nil || oldCfg.Debug != cfg.Debug { |
|
|
util.SetLogLevel(cfg) |
|
|
if oldCfg != nil { |
|
|
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug) |
|
|
} else { |
|
|
log.Debugf("debug mode toggled to %t", cfg.Debug) |
|
|
} |
|
|
} |
|
|
|
|
|
prevSecretEmpty := true |
|
|
if oldCfg != nil { |
|
|
prevSecretEmpty = oldCfg.RemoteManagement.SecretKey == "" |
|
|
} |
|
|
newSecretEmpty := cfg.RemoteManagement.SecretKey == "" |
|
|
if s.envManagementSecret { |
|
|
s.registerManagementRoutes() |
|
|
if s.managementRoutesEnabled.CompareAndSwap(false, true) { |
|
|
log.Info("management routes enabled via MANAGEMENT_PASSWORD") |
|
|
} else { |
|
|
s.managementRoutesEnabled.Store(true) |
|
|
} |
|
|
} else { |
|
|
switch { |
|
|
case prevSecretEmpty && !newSecretEmpty: |
|
|
s.registerManagementRoutes() |
|
|
if s.managementRoutesEnabled.CompareAndSwap(false, true) { |
|
|
log.Info("management routes enabled after secret key update") |
|
|
} else { |
|
|
s.managementRoutesEnabled.Store(true) |
|
|
} |
|
|
case !prevSecretEmpty && newSecretEmpty: |
|
|
if s.managementRoutesEnabled.CompareAndSwap(true, false) { |
|
|
log.Info("management routes disabled after secret key removal") |
|
|
} else { |
|
|
s.managementRoutesEnabled.Store(false) |
|
|
} |
|
|
default: |
|
|
s.managementRoutesEnabled.Store(!newSecretEmpty) |
|
|
} |
|
|
} |
|
|
|
|
|
s.applyAccessConfig(oldCfg, cfg) |
|
|
s.cfg = cfg |
|
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth) |
|
|
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth { |
|
|
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth) |
|
|
} |
|
|
managementasset.SetCurrentConfig(cfg) |
|
|
|
|
|
s.oldConfigYaml, _ = yaml.Marshal(cfg) |
|
|
|
|
|
s.handlers.UpdateClients(&cfg.SDKConfig) |
|
|
|
|
|
if !cfg.RemoteManagement.DisableControlPanel { |
|
|
staticDir := managementasset.StaticDir(s.configFilePath) |
|
|
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) |
|
|
} |
|
|
if s.mgmt != nil { |
|
|
s.mgmt.SetConfig(cfg) |
|
|
s.mgmt.SetAuthManager(s.handlers.AuthManager) |
|
|
} |
|
|
|
|
|
|
|
|
if s.ampModule != nil { |
|
|
log.Debugf("triggering amp module config update") |
|
|
if err := s.ampModule.OnConfigUpdated(cfg); err != nil { |
|
|
log.Errorf("failed to update Amp module config: %v", err) |
|
|
} |
|
|
} else { |
|
|
log.Warnf("amp module is nil, skipping config update") |
|
|
} |
|
|
|
|
|
|
|
|
authFiles := util.CountAuthFiles(cfg.AuthDir) |
|
|
geminiAPIKeyCount := len(cfg.GeminiKey) |
|
|
claudeAPIKeyCount := len(cfg.ClaudeKey) |
|
|
codexAPIKeyCount := len(cfg.CodexKey) |
|
|
vertexAICompatCount := len(cfg.VertexCompatAPIKey) |
|
|
openAICompatCount := 0 |
|
|
for i := range cfg.OpenAICompatibility { |
|
|
entry := cfg.OpenAICompatibility[i] |
|
|
openAICompatCount += len(entry.APIKeyEntries) |
|
|
} |
|
|
|
|
|
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount |
|
|
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n", |
|
|
total, |
|
|
authFiles, |
|
|
geminiAPIKeyCount, |
|
|
claudeAPIKeyCount, |
|
|
codexAPIKeyCount, |
|
|
vertexAICompatCount, |
|
|
openAICompatCount, |
|
|
) |
|
|
} |
|
|
|
|
|
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) { |
|
|
if s == nil { |
|
|
return |
|
|
} |
|
|
s.wsAuthChanged = fn |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
if manager == nil { |
|
|
c.Next() |
|
|
return |
|
|
} |
|
|
|
|
|
result, err := manager.Authenticate(c.Request.Context(), c.Request) |
|
|
if err == nil { |
|
|
if result != nil { |
|
|
c.Set("apiKey", result.Principal) |
|
|
c.Set("accessProvider", result.Provider) |
|
|
if len(result.Metadata) > 0 { |
|
|
c.Set("accessMetadata", result.Metadata) |
|
|
} |
|
|
} |
|
|
c.Next() |
|
|
return |
|
|
} |
|
|
|
|
|
switch { |
|
|
case errors.Is(err, sdkaccess.ErrNoCredentials): |
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"}) |
|
|
case errors.Is(err, sdkaccess.ErrInvalidCredential): |
|
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"}) |
|
|
default: |
|
|
log.Errorf("authentication middleware error: %v", err) |
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"}) |
|
|
} |
|
|
} |
|
|
} |
|
|
|