File size: 13,316 Bytes
f606b10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
// Package amp implements the Amp CLI routing module, providing OAuth-based
// integration with Amp CLI for ChatGPT and Anthropic subscriptions.
package amp

import (
	"fmt"
	"net/http/httputil"
	"strings"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
	log "github.com/sirupsen/logrus"
)

// Option configures the AmpModule.
type Option func(*AmpModule)

// AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
// It provides:
//   - Reverse proxy to Amp control plane for OAuth/management
//   - Provider-specific route aliases (/api/provider/{provider}/...)
//   - Automatic gzip decompression for misconfigured upstreams
//   - Model mapping for routing unavailable models to alternatives
type AmpModule struct {
	secretSource    SecretSource
	proxy           *httputil.ReverseProxy
	proxyMu         sync.RWMutex // protects proxy for hot-reload
	accessManager   *sdkaccess.Manager
	authMiddleware_ gin.HandlerFunc
	modelMapper     *DefaultModelMapper
	enabled         bool
	registerOnce    sync.Once

	// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
	restrictToLocalhost bool
	restrictMu          sync.RWMutex

	// configMu protects lastConfig for partial reload comparison
	configMu   sync.RWMutex
	lastConfig *config.AmpCode
}

// New creates a new Amp routing module with the given options.
// This is the preferred constructor using the Option pattern.
//
// Example:
//
//	ampModule := amp.New(
//	    amp.WithAccessManager(accessManager),
//	    amp.WithAuthMiddleware(authMiddleware),
//	    amp.WithSecretSource(customSecret),
//	)
func New(opts ...Option) *AmpModule {
	m := &AmpModule{
		secretSource: nil, // Will be created on demand if not provided
	}
	for _, opt := range opts {
		opt(m)
	}
	return m
}

// NewLegacy creates a new Amp routing module using the legacy constructor signature.
// This is provided for backwards compatibility.
//
// DEPRECATED: Use New with options instead.
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
	return New(
		WithAccessManager(accessManager),
		WithAuthMiddleware(authMiddleware),
	)
}

// WithSecretSource sets a custom secret source for the module.
func WithSecretSource(source SecretSource) Option {
	return func(m *AmpModule) {
		m.secretSource = source
	}
}

// WithAccessManager sets the access manager for the module.
func WithAccessManager(am *sdkaccess.Manager) Option {
	return func(m *AmpModule) {
		m.accessManager = am
	}
}

// WithAuthMiddleware sets the authentication middleware for provider routes.
func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
	return func(m *AmpModule) {
		m.authMiddleware_ = middleware
	}
}

// Name returns the module identifier
func (m *AmpModule) Name() string {
	return "amp-routing"
}

// forceModelMappings returns whether model mappings should take precedence over local API keys
func (m *AmpModule) forceModelMappings() bool {
	m.configMu.RLock()
	defer m.configMu.RUnlock()
	if m.lastConfig == nil {
		return false
	}
	return m.lastConfig.ForceModelMappings
}

// Register sets up Amp routes if configured.
// This implements the RouteModuleV2 interface with Context.
// Routes are registered only once via sync.Once for idempotent behavior.
func (m *AmpModule) Register(ctx modules.Context) error {
	settings := ctx.Config.AmpCode
	upstreamURL := strings.TrimSpace(settings.UpstreamURL)

	// Determine auth middleware (from module or context)
	auth := m.getAuthMiddleware(ctx)

	// Use registerOnce to ensure routes are only registered once
	var regErr error
	m.registerOnce.Do(func() {
		// Initialize model mapper from config (for routing unavailable models to alternatives)
		m.modelMapper = NewModelMapper(settings.ModelMappings)

		// Store initial config for partial reload comparison
		settingsCopy := settings
		m.lastConfig = &settingsCopy

		// Initialize localhost restriction setting (hot-reloadable)
		m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)

		// Always register provider aliases - these work without an upstream
		m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)

		// Register management proxy routes once; middleware will gate access when upstream is unavailable.
		// Pass auth middleware to require valid API key for all management routes.
		m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)

		// If no upstream URL, skip proxy routes but provider aliases are still available
		if upstreamURL == "" {
			log.Debug("amp upstream proxy disabled (no upstream URL configured)")
			log.Debug("amp provider alias routes registered")
			m.enabled = false
			return
		}

		if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
			regErr = fmt.Errorf("failed to create amp proxy: %w", err)
			return
		}

		log.Debug("amp provider alias routes registered")
	})

	return regErr
}

// getAuthMiddleware returns the authentication middleware, preferring the
// module's configured middleware, then the context middleware, then a fallback.
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
	if m.authMiddleware_ != nil {
		return m.authMiddleware_
	}
	if ctx.AuthMiddleware != nil {
		return ctx.AuthMiddleware
	}
	// Fallback: no authentication (should not happen in production)
	log.Warn("amp module: no auth middleware provided, allowing all requests")
	return func(c *gin.Context) {
		c.Next()
	}
}

// OnConfigUpdated handles configuration updates with partial reload support.
// Only updates components that have actually changed to avoid unnecessary work.
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
	newSettings := cfg.AmpCode

	// Get previous config for comparison
	m.configMu.RLock()
	oldSettings := m.lastConfig
	m.configMu.RUnlock()

	if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
		m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
	}

	newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
	oldUpstreamURL := ""
	if oldSettings != nil {
		oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
	}

	if !m.enabled && newUpstreamURL != "" {
		if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
			log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
		}
	}

	// Check model mappings change
	modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
	if modelMappingsChanged {
		if m.modelMapper != nil {
			m.modelMapper.UpdateMappings(newSettings.ModelMappings)
		} else if m.enabled {
			log.Warnf("amp model mapper not initialized, skipping model mapping update")
		}
	}

	if m.enabled {
		// Check upstream URL change - now supports hot-reload
		if newUpstreamURL == "" && oldUpstreamURL != "" {
			m.setProxy(nil)
			m.enabled = false
		} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
			// Recreate proxy with new URL
			proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
			if err != nil {
				log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
			} else {
				m.setProxy(proxy)
			}
		}

		// Check API key change (both default and per-client mappings)
		apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
		upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
		if apiKeyChanged || upstreamAPIKeysChanged {
			if m.secretSource != nil {
				if ms, ok := m.secretSource.(*MappedSecretSource); ok {
					if apiKeyChanged {
						ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
						ms.InvalidateCache()
					}
					if upstreamAPIKeysChanged {
						ms.UpdateMappings(newSettings.UpstreamAPIKeys)
					}
				} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
					ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
					ms.InvalidateCache()
				}
			}
		}

	}

	// Store current config for next comparison
	m.configMu.Lock()
	settingsCopy := newSettings // copy struct
	m.lastConfig = &settingsCopy
	m.configMu.Unlock()

	return nil
}

func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
	if m.secretSource == nil {
		// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
		defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
		mappedSource := NewMappedSecretSource(defaultSource)
		mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
		m.secretSource = mappedSource
	} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
		ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
		ms.InvalidateCache()
		ms.UpdateMappings(settings.UpstreamAPIKeys)
	} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
		// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
		ms.UpdateExplicitKey(settings.UpstreamAPIKey)
		ms.InvalidateCache()
		mappedSource := NewMappedSecretSource(ms)
		mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
		m.secretSource = mappedSource
	}

	proxy, err := createReverseProxy(upstreamURL, m.secretSource)
	if err != nil {
		return err
	}

	m.setProxy(proxy)
	m.enabled = true

	log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
	return nil
}

// hasModelMappingsChanged compares old and new model mappings.
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
	if old == nil {
		return len(new.ModelMappings) > 0
	}

	if len(old.ModelMappings) != len(new.ModelMappings) {
		return true
	}

	// Build map for efficient and robust comparison
	type mappingInfo struct {
		to    string
		regex bool
	}
	oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
	for _, mapping := range old.ModelMappings {
		oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
			to:    strings.TrimSpace(mapping.To),
			regex: mapping.Regex,
		}
	}

	for _, mapping := range new.ModelMappings {
		from := strings.TrimSpace(mapping.From)
		to := strings.TrimSpace(mapping.To)
		if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
			return true
		}
	}

	return false
}

// hasAPIKeyChanged compares old and new API keys.
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
	oldKey := ""
	if old != nil {
		oldKey = strings.TrimSpace(old.UpstreamAPIKey)
	}
	newKey := strings.TrimSpace(new.UpstreamAPIKey)
	return oldKey != newKey
}

// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
	if old == nil {
		return len(new.UpstreamAPIKeys) > 0
	}

	if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
		return true
	}

	// Build map for comparison: upstreamKey -> set of clientKeys
	type entryInfo struct {
		upstreamKey string
		clientKeys  map[string]struct{}
	}
	oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
	for i, entry := range old.UpstreamAPIKeys {
		clientKeys := make(map[string]struct{}, len(entry.APIKeys))
		for _, k := range entry.APIKeys {
			trimmed := strings.TrimSpace(k)
			if trimmed == "" {
				continue
			}
			clientKeys[trimmed] = struct{}{}
		}
		oldEntries[i] = entryInfo{
			upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
			clientKeys:  clientKeys,
		}
	}

	for i, newEntry := range new.UpstreamAPIKeys {
		if i >= len(oldEntries) {
			return true
		}
		oldE := oldEntries[i]
		if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
			return true
		}
		newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
		for _, k := range newEntry.APIKeys {
			trimmed := strings.TrimSpace(k)
			if trimmed == "" {
				continue
			}
			newKeys[trimmed] = struct{}{}
		}
		if len(newKeys) != len(oldE.clientKeys) {
			return true
		}
		for k := range newKeys {
			if _, ok := oldE.clientKeys[k]; !ok {
				return true
			}
		}
	}

	return false
}

// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
	return m.modelMapper
}

// getProxy returns the current proxy instance (thread-safe for hot-reload).
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
	m.proxyMu.RLock()
	defer m.proxyMu.RUnlock()
	return m.proxy
}

// setProxy updates the proxy instance (thread-safe for hot-reload).
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
	m.proxyMu.Lock()
	defer m.proxyMu.Unlock()
	m.proxy = proxy
}

// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
func (m *AmpModule) IsRestrictedToLocalhost() bool {
	m.restrictMu.RLock()
	defer m.restrictMu.RUnlock()
	return m.restrictToLocalhost
}

// setRestrictToLocalhost updates the localhost restriction setting.
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
	m.restrictMu.Lock()
	defer m.restrictMu.Unlock()
	m.restrictToLocalhost = restrict
}