|
|
|
|
|
|
|
|
package amp |
|
|
|
|
|
import ( |
|
|
"fmt" |
|
|
"net/http/httputil" |
|
|
"strings" |
|
|
"sync" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" |
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
|
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" |
|
|
log "github.com/sirupsen/logrus" |
|
|
) |
|
|
|
|
|
|
|
|
// Option is a functional option that mutates an AmpModule during
// construction. Options are applied in order by New; see WithSecretSource,
// WithAccessManager and WithAuthMiddleware.
type Option func(*AmpModule)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// AmpModule is the "amp-routing" module. Register installs provider alias
// and management routes and, when an upstream URL is configured, builds a
// reverse proxy to that upstream (presumably consumed by route handlers
// defined elsewhere in the package — confirm). OnConfigUpdated applies
// configuration reloads to the already-registered module.
type AmpModule struct {
	// secretSource is handed to createReverseProxy when the proxy is built.
	// If it is nil when the proxy is first enabled, enableUpstreamProxy
	// creates a default source from AmpCode.UpstreamAPIKey/UpstreamAPIKeys.
	secretSource SecretSource

	// proxy is the current upstream reverse proxy. It is nil until
	// enableUpstreamProxy succeeds and is reset to nil when the upstream URL
	// is removed. Access it through getProxy/setProxy, which hold proxyMu.
	proxy *httputil.ReverseProxy

	// proxyMu guards proxy.
	proxyMu sync.RWMutex

	// accessManager is set by WithAccessManager; it is not referenced in
	// this file (presumably used by handlers elsewhere — verify).
	accessManager *sdkaccess.Manager

	// authMiddleware_ is an explicitly configured auth middleware. When set,
	// getAuthMiddleware prefers it over the one provided in modules.Context.
	authMiddleware_ gin.HandlerFunc

	// modelMapper is created in Register from AmpCode.ModelMappings and
	// updated in place by OnConfigUpdated.
	modelMapper *DefaultModelMapper

	// enabled reports whether the upstream proxy is currently active.
	// NOTE(review): read and written without a lock in this file.
	enabled bool

	// registerOnce ensures Register's setup runs at most once.
	registerOnce sync.Once

	// restrictToLocalhost mirrors AmpCode.RestrictManagementToLocalhost as
	// last applied; guarded by restrictMu.
	restrictToLocalhost bool

	// restrictMu guards restrictToLocalhost.
	restrictMu sync.RWMutex

	// configMu guards lastConfig.
	configMu sync.RWMutex

	// lastConfig is a private copy of the most recently applied AmpCode
	// settings; OnConfigUpdated diffs new settings against it.
	lastConfig *config.AmpCode
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func New(opts ...Option) *AmpModule { |
|
|
m := &AmpModule{ |
|
|
secretSource: nil, |
|
|
} |
|
|
for _, opt := range opts { |
|
|
opt(m) |
|
|
} |
|
|
return m |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule { |
|
|
return New( |
|
|
WithAccessManager(accessManager), |
|
|
WithAuthMiddleware(authMiddleware), |
|
|
) |
|
|
} |
|
|
|
|
|
|
|
|
func WithSecretSource(source SecretSource) Option { |
|
|
return func(m *AmpModule) { |
|
|
m.secretSource = source |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithAccessManager(am *sdkaccess.Manager) Option { |
|
|
return func(m *AmpModule) { |
|
|
m.accessManager = am |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func WithAuthMiddleware(middleware gin.HandlerFunc) Option { |
|
|
return func(m *AmpModule) { |
|
|
m.authMiddleware_ = middleware |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) Name() string { |
|
|
return "amp-routing" |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) forceModelMappings() bool { |
|
|
m.configMu.RLock() |
|
|
defer m.configMu.RUnlock() |
|
|
if m.lastConfig == nil { |
|
|
return false |
|
|
} |
|
|
return m.lastConfig.ForceModelMappings |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (m *AmpModule) Register(ctx modules.Context) error { |
|
|
settings := ctx.Config.AmpCode |
|
|
upstreamURL := strings.TrimSpace(settings.UpstreamURL) |
|
|
|
|
|
|
|
|
auth := m.getAuthMiddleware(ctx) |
|
|
|
|
|
|
|
|
var regErr error |
|
|
m.registerOnce.Do(func() { |
|
|
|
|
|
m.modelMapper = NewModelMapper(settings.ModelMappings) |
|
|
|
|
|
|
|
|
settingsCopy := settings |
|
|
m.lastConfig = &settingsCopy |
|
|
|
|
|
|
|
|
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost) |
|
|
|
|
|
|
|
|
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) |
|
|
|
|
|
|
|
|
|
|
|
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth) |
|
|
|
|
|
|
|
|
if upstreamURL == "" { |
|
|
log.Debug("amp upstream proxy disabled (no upstream URL configured)") |
|
|
log.Debug("amp provider alias routes registered") |
|
|
m.enabled = false |
|
|
return |
|
|
} |
|
|
|
|
|
if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil { |
|
|
regErr = fmt.Errorf("failed to create amp proxy: %w", err) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Debug("amp provider alias routes registered") |
|
|
}) |
|
|
|
|
|
return regErr |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { |
|
|
if m.authMiddleware_ != nil { |
|
|
return m.authMiddleware_ |
|
|
} |
|
|
if ctx.AuthMiddleware != nil { |
|
|
return ctx.AuthMiddleware |
|
|
} |
|
|
|
|
|
log.Warn("amp module: no auth middleware provided, allowing all requests") |
|
|
return func(c *gin.Context) { |
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// OnConfigUpdated applies a reloaded configuration to the running module by
// diffing cfg.AmpCode against the last applied settings (lastConfig).
//
// In order, it:
//  1. updates the localhost restriction if it changed (only when prior
//     settings exist);
//  2. enables the upstream proxy if it is currently disabled and an upstream
//     URL is now configured;
//  3. pushes changed model mappings into the model mapper;
//  4. while enabled, disables the proxy when the URL is removed or rebuilds
//     it when the URL changes, and propagates upstream API key changes into
//     the secret source;
//  5. stores a copy of the new settings as lastConfig.
//
// Failures are logged, not returned; the function always returns nil.
//
// NOTE(review): lastConfig is read and written under separate lock
// acquisitions, and m.enabled / m.modelMapper are accessed without a lock,
// so concurrent calls are presumably not expected — confirm with the caller.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
	newSettings := cfg.AmpCode

	// Snapshot the previously applied settings; nil if Register never ran.
	m.configMu.RLock()
	oldSettings := m.lastConfig
	m.configMu.RUnlock()

	if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
		m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
	}

	newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
	oldUpstreamURL := ""
	if oldSettings != nil {
		oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
	}

	// Turn the proxy on if it was off (no URL before, or an earlier enable
	// failed). On success this sets m.enabled, so the block further down
	// also runs in this same call.
	if !m.enabled && newUpstreamURL != "" {
		if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
			log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
		}
	}

	// Model mappings are applied regardless of whether the proxy is enabled.
	modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
	if modelMappingsChanged {
		if m.modelMapper != nil {
			m.modelMapper.UpdateMappings(newSettings.ModelMappings)
		} else if m.enabled {
			log.Warnf("amp model mapper not initialized, skipping model mapping update")
		}
	}

	// m.enabled is evaluated once here, after the enable step above; the
	// disable branch below clears it but the key-update code in this block
	// still runs for this call.
	if m.enabled {
		if newUpstreamURL == "" && oldUpstreamURL != "" {
			// Upstream URL removed: drop the proxy.
			m.setProxy(nil)
			m.enabled = false
		} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
			// Upstream URL changed: rebuild the proxy with the existing secret
			// source. On failure the previous proxy stays installed.
			proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
			if err != nil {
				log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
			} else {
				m.setProxy(proxy)
			}
		}

		// Push credential changes into the secret source in place, so
		// the proxy sees them without being rebuilt (the proxy holds
		// m.secretSource as passed to createReverseProxy).
		apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
		upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
		if apiKeyChanged || upstreamAPIKeysChanged {
			if m.secretSource != nil {
				if ms, ok := m.secretSource.(*MappedSecretSource); ok {
					if apiKeyChanged {
						ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
						ms.InvalidateCache()
					}
					if upstreamAPIKeysChanged {
						ms.UpdateMappings(newSettings.UpstreamAPIKeys)
					}
				} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
					// A bare MultiSourceSecret has no per-client mappings;
					// only the explicit key can be refreshed.
					ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
					ms.InvalidateCache()
				}
			}
		}
	}

	// Record a private copy so later reloads diff against these settings.
	m.configMu.Lock()
	settingsCopy := newSettings
	m.lastConfig = &settingsCopy
	m.configMu.Unlock()

	return nil
}
|
|
|
|
|
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error { |
|
|
if m.secretSource == nil { |
|
|
|
|
|
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 ) |
|
|
mappedSource := NewMappedSecretSource(defaultSource) |
|
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys) |
|
|
m.secretSource = mappedSource |
|
|
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok { |
|
|
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey) |
|
|
ms.InvalidateCache() |
|
|
ms.UpdateMappings(settings.UpstreamAPIKeys) |
|
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok { |
|
|
|
|
|
ms.UpdateExplicitKey(settings.UpstreamAPIKey) |
|
|
ms.InvalidateCache() |
|
|
mappedSource := NewMappedSecretSource(ms) |
|
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys) |
|
|
m.secretSource = mappedSource |
|
|
} |
|
|
|
|
|
proxy, err := createReverseProxy(upstreamURL, m.secretSource) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
m.setProxy(proxy) |
|
|
m.enabled = true |
|
|
|
|
|
log.Infof("amp upstream proxy enabled for: %s", upstreamURL) |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool { |
|
|
if old == nil { |
|
|
return len(new.ModelMappings) > 0 |
|
|
} |
|
|
|
|
|
if len(old.ModelMappings) != len(new.ModelMappings) { |
|
|
return true |
|
|
} |
|
|
|
|
|
|
|
|
type mappingInfo struct { |
|
|
to string |
|
|
regex bool |
|
|
} |
|
|
oldMap := make(map[string]mappingInfo, len(old.ModelMappings)) |
|
|
for _, mapping := range old.ModelMappings { |
|
|
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{ |
|
|
to: strings.TrimSpace(mapping.To), |
|
|
regex: mapping.Regex, |
|
|
} |
|
|
} |
|
|
|
|
|
for _, mapping := range new.ModelMappings { |
|
|
from := strings.TrimSpace(mapping.From) |
|
|
to := strings.TrimSpace(mapping.To) |
|
|
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex { |
|
|
return true |
|
|
} |
|
|
} |
|
|
|
|
|
return false |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool { |
|
|
oldKey := "" |
|
|
if old != nil { |
|
|
oldKey = strings.TrimSpace(old.UpstreamAPIKey) |
|
|
} |
|
|
newKey := strings.TrimSpace(new.UpstreamAPIKey) |
|
|
return oldKey != newKey |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool { |
|
|
if old == nil { |
|
|
return len(new.UpstreamAPIKeys) > 0 |
|
|
} |
|
|
|
|
|
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) { |
|
|
return true |
|
|
} |
|
|
|
|
|
|
|
|
type entryInfo struct { |
|
|
upstreamKey string |
|
|
clientKeys map[string]struct{} |
|
|
} |
|
|
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys)) |
|
|
for i, entry := range old.UpstreamAPIKeys { |
|
|
clientKeys := make(map[string]struct{}, len(entry.APIKeys)) |
|
|
for _, k := range entry.APIKeys { |
|
|
trimmed := strings.TrimSpace(k) |
|
|
if trimmed == "" { |
|
|
continue |
|
|
} |
|
|
clientKeys[trimmed] = struct{}{} |
|
|
} |
|
|
oldEntries[i] = entryInfo{ |
|
|
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey), |
|
|
clientKeys: clientKeys, |
|
|
} |
|
|
} |
|
|
|
|
|
for i, newEntry := range new.UpstreamAPIKeys { |
|
|
if i >= len(oldEntries) { |
|
|
return true |
|
|
} |
|
|
oldE := oldEntries[i] |
|
|
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey { |
|
|
return true |
|
|
} |
|
|
newKeys := make(map[string]struct{}, len(newEntry.APIKeys)) |
|
|
for _, k := range newEntry.APIKeys { |
|
|
trimmed := strings.TrimSpace(k) |
|
|
if trimmed == "" { |
|
|
continue |
|
|
} |
|
|
newKeys[trimmed] = struct{}{} |
|
|
} |
|
|
if len(newKeys) != len(oldE.clientKeys) { |
|
|
return true |
|
|
} |
|
|
for k := range newKeys { |
|
|
if _, ok := oldE.clientKeys[k]; !ok { |
|
|
return true |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return false |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) GetModelMapper() *DefaultModelMapper { |
|
|
return m.modelMapper |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) getProxy() *httputil.ReverseProxy { |
|
|
m.proxyMu.RLock() |
|
|
defer m.proxyMu.RUnlock() |
|
|
return m.proxy |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) { |
|
|
m.proxyMu.Lock() |
|
|
defer m.proxyMu.Unlock() |
|
|
m.proxy = proxy |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) IsRestrictedToLocalhost() bool { |
|
|
m.restrictMu.RLock() |
|
|
defer m.restrictMu.RUnlock() |
|
|
return m.restrictToLocalhost |
|
|
} |
|
|
|
|
|
|
|
|
func (m *AmpModule) setRestrictToLocalhost(restrict bool) { |
|
|
m.restrictMu.Lock() |
|
|
defer m.restrictMu.Unlock() |
|
|
m.restrictToLocalhost = restrict |
|
|
} |
|
|
|