kek commited on
Commit
f606b10
·
0 Parent(s):

Fresh start: Go 1.23 + go-git/v5 compatibility

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +56 -0
  2. Dockerfile +30 -0
  3. Dockerfile.hf +27 -0
  4. LICENSE +22 -0
  5. README.md +16 -0
  6. cmd/server/main.go +535 -0
  7. config.example.yaml +281 -0
  8. config.yaml +75 -0
  9. go.mod +78 -0
  10. go.sum +196 -0
  11. internal/access/config_access/provider.go +112 -0
  12. internal/access/reconcile.go +270 -0
  13. internal/api/handlers/management/api_tools.go +538 -0
  14. internal/api/handlers/management/auth_files.go +2606 -0
  15. internal/api/handlers/management/config_basic.go +243 -0
  16. internal/api/handlers/management/config_lists.go +1090 -0
  17. internal/api/handlers/management/handler.go +277 -0
  18. internal/api/handlers/management/logs.go +592 -0
  19. internal/api/handlers/management/oauth_callback.go +100 -0
  20. internal/api/handlers/management/oauth_sessions.go +290 -0
  21. internal/api/handlers/management/quota.go +18 -0
  22. internal/api/handlers/management/usage.go +79 -0
  23. internal/api/handlers/management/vertex_import.go +156 -0
  24. internal/api/middleware/request_logging.go +122 -0
  25. internal/api/middleware/response_writer.go +382 -0
  26. internal/api/modules/amp/amp.go +428 -0
  27. internal/api/modules/amp/amp_test.go +352 -0
  28. internal/api/modules/amp/fallback_handlers.go +329 -0
  29. internal/api/modules/amp/fallback_handlers_test.go +73 -0
  30. internal/api/modules/amp/gemini_bridge.go +59 -0
  31. internal/api/modules/amp/gemini_bridge_test.go +93 -0
  32. internal/api/modules/amp/model_mapping.go +147 -0
  33. internal/api/modules/amp/model_mapping_test.go +283 -0
  34. internal/api/modules/amp/proxy.go +266 -0
  35. internal/api/modules/amp/proxy_test.go +657 -0
  36. internal/api/modules/amp/response_rewriter.go +160 -0
  37. internal/api/modules/amp/routes.go +334 -0
  38. internal/api/modules/amp/routes_test.go +381 -0
  39. internal/api/modules/amp/secret.go +248 -0
  40. internal/api/modules/amp/secret_test.go +366 -0
  41. internal/api/modules/modules.go +92 -0
  42. internal/api/server.go +1056 -0
  43. internal/api/server_test.go +111 -0
  44. internal/auth/claude/anthropic.go +32 -0
  45. internal/auth/claude/anthropic_auth.go +346 -0
  46. internal/auth/claude/errors.go +167 -0
  47. internal/auth/claude/html_templates.go +218 -0
  48. internal/auth/claude/oauth_server.go +331 -0
  49. internal/auth/claude/pkce.go +56 -0
  50. internal/auth/claude/token.go +73 -0
.gitignore ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Binaries
2
+ cli-proxy-api
3
+ cliproxy
4
+ *.exe
5
+
6
+ # Configuration
7
+ config.yaml
8
+ .env
9
+
10
+ # Generated content
11
+ bin/*
12
+ logs/*
13
+ conv/*
14
+ temp/*
15
+ refs/*
16
+
17
+ # Storage backends
18
+ pgstore/*
19
+ gitstore/*
20
+ objectstore/*
21
+
22
+ # Static assets
23
+ static/*
24
+
25
+ # Authentication data
26
+ auths/*
27
+ !auths/.gitkeep
28
+
29
+ # Documentation
30
+ docs/*
31
+ AGENTS.md
32
+ CLAUDE.md
33
+ GEMINI.md
34
+
35
+ # Tooling metadata
36
+ .vscode/*
37
+ .codex/*
38
+ .claude/*
39
+ .gemini/*
40
+ .serena/*
41
+ .agent/*
42
+ .agents/*
43
+ .agents/*
44
+ .opencode/*
45
+ .bmad/*
46
+ _bmad/*
47
+ _bmad-output/*
48
+ .mcp/cache/
49
+
50
+ # macOS
51
+ .DS_Store
52
+ ._*
53
+ cli-proxy-api-plus
54
+ CLIProxyAPIPlus_*.tar.gz
55
+ cli-proxy-api-plus
56
+ cli-proxy-api-plus
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. Schritt: Build des Go-Proxys
2
+ FROM golang:1.23-alpine AS builder
3
+ WORKDIR /app
4
+ COPY . .
5
+ RUN go mod download
6
+ RUN CGO_ENABLED=0 GOOS=linux go build -o /app/cliproxy ./cmd/server/
7
+
8
+ # 2. Schritt: Schlankes Runtime-Image
9
+ FROM alpine:latest
10
+ RUN apk add --no-cache ca-certificates bash
11
+
12
+ # Arbeitsverzeichnis
13
+ WORKDIR /app
14
+
15
+ # Kopiere den Proxy, die Config und den statischen Web-Ordner
16
+ COPY --from=builder /app/cliproxy /app/cliproxy
17
+ COPY config.yaml /app/config.yaml
18
+ COPY static /app/static
19
+
20
+ # Start-Skript
21
+ RUN echo "#!/bin/bash" > /start.sh && \
22
+ echo "Starting CLI Proxy API on Port 7860..." >> /start.sh && \
23
+ echo "exec /app/cliproxy -config /app/config.yaml" >> /start.sh && \
24
+ chmod +x /start.sh
25
+
26
+ # Port 7860 ist Pflicht für Hugging Face
27
+ EXPOSE 7860
28
+
29
+ # Proxy starten
30
+ ENTRYPOINT ["/start.sh"]
Dockerfile.hf ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. Schritt: Wir bauen den Go-Proxy (CLIProxyAPIPlus)
2
+ FROM golang:1.24-alpine AS builder
3
+ WORKDIR /app
4
+ COPY . .
5
+ RUN go mod download
6
+ RUN CGO_ENABLED=0 GOOS=linux go build -o /app/cliproxy ./cmd/server/
7
+
8
+ # 2. Schritt: Wir nehmen Puter (den Desktop)
9
+ FROM heyputer/puter:latest
10
+
11
+ USER root
12
+
13
+ # Proxy und Config kopieren
14
+ COPY --from=builder /app/cliproxy /usr/local/bin/cliproxy
15
+ COPY config.yaml /etc/cliproxy/config.yaml
16
+
17
+ # Start-Skript sauber erstellen (Alles in einer RUN-Anweisung)
18
+ RUN echo "#!/bin/bash" > /start.sh && \
19
+ echo "echo 'Starting CLI Proxy...'" >> /start.sh && \
20
+ echo "/usr/local/bin/cliproxy -config /etc/cliproxy/config.yaml &" >> /start.sh && \
21
+ echo "echo 'Starting Puter on Port 7860...'" >> /start.sh && \
22
+ echo "exec python3 /opt/puter/puter/server.py --port 7860" >> /start.sh && \
23
+ chmod +x /start.sh
24
+
25
+ EXPOSE 7860
26
+
27
+ ENTRYPOINT ["/start.sh"]
LICENSE ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025-2005.9 Luis Pater
4
+ Copyright (c) 2025.9-present Router-For.ME
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Socializer Admin
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # CLIProxyAPI Plus (Socializer Admin)
11
+
12
+ English | [Chinese](README_CN.md)
13
+
14
+ This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project.
15
+
16
+ Running on Hugging Face Spaces with Puter OS.
cmd/server/main.go ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package main provides the entry point for the CLI Proxy API server.
2
+ // This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
3
+ // for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
4
+ package main
5
+
6
+ import (
7
+ "context"
8
+ "errors"
9
+ "flag"
10
+ "fmt"
11
+ "io/fs"
12
+ "net/url"
13
+ "os"
14
+ "path/filepath"
15
+ "strings"
16
+ "time"
17
+
18
+ "github.com/joho/godotenv"
19
+ configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
20
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
21
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
22
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
23
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
24
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
25
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
26
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/store"
27
+ _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
28
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
29
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
30
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
31
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
32
+ log "github.com/sirupsen/logrus"
33
+ )
34
+
35
+ var (
36
+ Version = "dev"
37
+ Commit = "none"
38
+ BuildDate = "unknown"
39
+ DefaultConfigPath = ""
40
+ )
41
+
42
+ // init initializes the shared logger setup.
43
+ func init() {
44
+ logging.SetupBaseLogger()
45
+ buildinfo.Version = Version
46
+ buildinfo.Commit = Commit
47
+ buildinfo.BuildDate = BuildDate
48
+ }
49
+
50
+ // setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
51
+ // Kiro defaults to incognito mode for multi-account support.
52
+ // Users can explicitly override with --incognito or --no-incognito flags.
53
+ func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
54
+ if useIncognito {
55
+ cfg.IncognitoBrowser = true
56
+ } else if noIncognito {
57
+ cfg.IncognitoBrowser = false
58
+ } else {
59
+ cfg.IncognitoBrowser = true // Kiro default
60
+ }
61
+ }
62
+
63
+ // main is the entry point of the application.
64
+ // It parses command-line flags, loads configuration, and starts the appropriate
65
+ // service based on the provided flags (login, codex-login, or server mode).
66
+ func main() {
67
+ fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
68
+
69
+ // Command-line flags to control the application's behavior.
70
+ var login bool
71
+ var codexLogin bool
72
+ var claudeLogin bool
73
+ var qwenLogin bool
74
+ var iflowLogin bool
75
+ var iflowCookie bool
76
+ var noBrowser bool
77
+ var antigravityLogin bool
78
+ var kiroLogin bool
79
+ var kiroGoogleLogin bool
80
+ var kiroAWSLogin bool
81
+ var kiroAWSAuthCode bool
82
+ var kiroImport bool
83
+ var githubCopilotLogin bool
84
+ var projectID string
85
+ var vertexImport string
86
+ var configPath string
87
+ var password string
88
+ var noIncognito bool
89
+ var useIncognito bool
90
+
91
+ // Define command-line flags for different operation modes.
92
+ flag.BoolVar(&login, "login", false, "Login Google Account")
93
+ flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
94
+ flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
95
+ flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
96
+ flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
97
+ flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
98
+ flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
99
+ flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
100
+ flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
101
+ flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
102
+ flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
103
+ flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
104
+ flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
105
+ flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
106
+ flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
107
+ flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
108
+ flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
109
+ flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
110
+ flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
111
+ flag.StringVar(&password, "password", "", "")
112
+
113
+ flag.CommandLine.Usage = func() {
114
+ out := flag.CommandLine.Output()
115
+ _, _ = fmt.Fprintf(out, "Usage of %s\n", os.Args[0])
116
+ flag.CommandLine.VisitAll(func(f *flag.Flag) {
117
+ if f.Name == "password" {
118
+ return
119
+ }
120
+ s := fmt.Sprintf(" -%s", f.Name)
121
+ name, unquoteUsage := flag.UnquoteUsage(f)
122
+ if name != "" {
123
+ s += " " + name
124
+ }
125
+ if len(s) <= 4 {
126
+ s += " "
127
+ } else {
128
+ s += "\n "
129
+ }
130
+ if unquoteUsage != "" {
131
+ s += unquoteUsage
132
+ }
133
+ if f.DefValue != "" && f.DefValue != "false" && f.DefValue != "0" {
134
+ s += fmt.Sprintf(" (default %s)", f.DefValue)
135
+ }
136
+ _, _ = fmt.Fprint(out, s+"\n")
137
+ })
138
+ }
139
+
140
+ // Parse the command-line flags.
141
+ flag.Parse()
142
+
143
+ // Core application variables.
144
+ var err error
145
+ var cfg *config.Config
146
+ var isCloudDeploy bool
147
+ var (
148
+ usePostgresStore bool
149
+ pgStoreDSN string
150
+ pgStoreSchema string
151
+ pgStoreLocalPath string
152
+ pgStoreInst *store.PostgresStore
153
+ useGitStore bool
154
+ gitStoreRemoteURL string
155
+ gitStoreUser string
156
+ gitStorePassword string
157
+ gitStoreLocalPath string
158
+ gitStoreInst *store.GitTokenStore
159
+ gitStoreRoot string
160
+ useObjectStore bool
161
+ objectStoreEndpoint string
162
+ objectStoreAccess string
163
+ objectStoreSecret string
164
+ objectStoreBucket string
165
+ objectStoreLocalPath string
166
+ objectStoreInst *store.ObjectTokenStore
167
+ )
168
+
169
+ wd, err := os.Getwd()
170
+ if err != nil {
171
+ log.Errorf("failed to get working directory: %v", err)
172
+ return
173
+ }
174
+
175
+ // Load environment variables from .env if present.
176
+ if errLoad := godotenv.Load(filepath.Join(wd, ".env")); errLoad != nil {
177
+ if !errors.Is(errLoad, os.ErrNotExist) {
178
+ log.WithError(errLoad).Warn("failed to load .env file")
179
+ }
180
+ }
181
+
182
+ lookupEnv := func(keys ...string) (string, bool) {
183
+ for _, key := range keys {
184
+ if value, ok := os.LookupEnv(key); ok {
185
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
186
+ return trimmed, true
187
+ }
188
+ }
189
+ }
190
+ return "", false
191
+ }
192
+ writableBase := util.WritablePath()
193
+ if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
194
+ usePostgresStore = true
195
+ pgStoreDSN = value
196
+ }
197
+ if usePostgresStore {
198
+ if value, ok := lookupEnv("PGSTORE_SCHEMA", "pgstore_schema"); ok {
199
+ pgStoreSchema = value
200
+ }
201
+ if value, ok := lookupEnv("PGSTORE_LOCAL_PATH", "pgstore_local_path"); ok {
202
+ pgStoreLocalPath = value
203
+ }
204
+ if pgStoreLocalPath == "" {
205
+ if writableBase != "" {
206
+ pgStoreLocalPath = writableBase
207
+ } else {
208
+ pgStoreLocalPath = wd
209
+ }
210
+ }
211
+ useGitStore = false
212
+ }
213
+ if value, ok := lookupEnv("GITSTORE_GIT_URL", "gitstore_git_url"); ok {
214
+ useGitStore = true
215
+ gitStoreRemoteURL = value
216
+ }
217
+ if value, ok := lookupEnv("GITSTORE_GIT_USERNAME", "gitstore_git_username"); ok {
218
+ gitStoreUser = value
219
+ }
220
+ if value, ok := lookupEnv("GITSTORE_GIT_TOKEN", "gitstore_git_token"); ok {
221
+ gitStorePassword = value
222
+ }
223
+ if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
224
+ gitStoreLocalPath = value
225
+ }
226
+ if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
227
+ useObjectStore = true
228
+ objectStoreEndpoint = value
229
+ }
230
+ if value, ok := lookupEnv("OBJECTSTORE_ACCESS_KEY", "objectstore_access_key"); ok {
231
+ objectStoreAccess = value
232
+ }
233
+ if value, ok := lookupEnv("OBJECTSTORE_SECRET_KEY", "objectstore_secret_key"); ok {
234
+ objectStoreSecret = value
235
+ }
236
+ if value, ok := lookupEnv("OBJECTSTORE_BUCKET", "objectstore_bucket"); ok {
237
+ objectStoreBucket = value
238
+ }
239
+ if value, ok := lookupEnv("OBJECTSTORE_LOCAL_PATH", "objectstore_local_path"); ok {
240
+ objectStoreLocalPath = value
241
+ }
242
+
243
+ // Check for cloud deploy mode only on first execution
244
+ // Read env var name in uppercase: DEPLOY
245
+ deployEnv := os.Getenv("DEPLOY")
246
+ if deployEnv == "cloud" {
247
+ isCloudDeploy = true
248
+ }
249
+
250
+ // Determine and load the configuration file.
251
+ // Prefer the Postgres store when configured, otherwise fallback to git or local files.
252
+ var configFilePath string
253
+ if usePostgresStore {
254
+ if pgStoreLocalPath == "" {
255
+ pgStoreLocalPath = wd
256
+ }
257
+ pgStoreLocalPath = filepath.Join(pgStoreLocalPath, "pgstore")
258
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
259
+ pgStoreInst, err = store.NewPostgresStore(ctx, store.PostgresStoreConfig{
260
+ DSN: pgStoreDSN,
261
+ Schema: pgStoreSchema,
262
+ SpoolDir: pgStoreLocalPath,
263
+ })
264
+ cancel()
265
+ if err != nil {
266
+ log.Errorf("failed to initialize postgres token store: %v", err)
267
+ return
268
+ }
269
+ examplePath := filepath.Join(wd, "config.example.yaml")
270
+ ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
271
+ if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
272
+ cancel()
273
+ log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
274
+ return
275
+ }
276
+ cancel()
277
+ configFilePath = pgStoreInst.ConfigPath()
278
+ cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
279
+ if err == nil {
280
+ cfg.AuthDir = pgStoreInst.AuthDir()
281
+ log.Infof("postgres-backed token store enabled, workspace path: %s", pgStoreInst.WorkDir())
282
+ }
283
+ } else if useObjectStore {
284
+ if objectStoreLocalPath == "" {
285
+ if writableBase != "" {
286
+ objectStoreLocalPath = writableBase
287
+ } else {
288
+ objectStoreLocalPath = wd
289
+ }
290
+ }
291
+ objectStoreRoot := filepath.Join(objectStoreLocalPath, "objectstore")
292
+ resolvedEndpoint := strings.TrimSpace(objectStoreEndpoint)
293
+ useSSL := true
294
+ if strings.Contains(resolvedEndpoint, "://") {
295
+ parsed, errParse := url.Parse(resolvedEndpoint)
296
+ if errParse != nil {
297
+ log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
298
+ return
299
+ }
300
+ switch strings.ToLower(parsed.Scheme) {
301
+ case "http":
302
+ useSSL = false
303
+ case "https":
304
+ useSSL = true
305
+ default:
306
+ log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
307
+ return
308
+ }
309
+ if parsed.Host == "" {
310
+ log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
311
+ return
312
+ }
313
+ resolvedEndpoint = parsed.Host
314
+ if parsed.Path != "" && parsed.Path != "/" {
315
+ resolvedEndpoint = strings.TrimSuffix(parsed.Host+parsed.Path, "/")
316
+ }
317
+ }
318
+ resolvedEndpoint = strings.TrimRight(resolvedEndpoint, "/")
319
+ objCfg := store.ObjectStoreConfig{
320
+ Endpoint: resolvedEndpoint,
321
+ Bucket: objectStoreBucket,
322
+ AccessKey: objectStoreAccess,
323
+ SecretKey: objectStoreSecret,
324
+ LocalRoot: objectStoreRoot,
325
+ UseSSL: useSSL,
326
+ PathStyle: true,
327
+ }
328
+ objectStoreInst, err = store.NewObjectTokenStore(objCfg)
329
+ if err != nil {
330
+ log.Errorf("failed to initialize object token store: %v", err)
331
+ return
332
+ }
333
+ examplePath := filepath.Join(wd, "config.example.yaml")
334
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
335
+ if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
336
+ cancel()
337
+ log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
338
+ return
339
+ }
340
+ cancel()
341
+ configFilePath = objectStoreInst.ConfigPath()
342
+ cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
343
+ if err == nil {
344
+ if cfg == nil {
345
+ cfg = &config.Config{}
346
+ }
347
+ cfg.AuthDir = objectStoreInst.AuthDir()
348
+ log.Infof("object-backed token store enabled, bucket: %s", objectStoreBucket)
349
+ }
350
+ } else if useGitStore {
351
+ if gitStoreLocalPath == "" {
352
+ if writableBase != "" {
353
+ gitStoreLocalPath = writableBase
354
+ } else {
355
+ gitStoreLocalPath = wd
356
+ }
357
+ }
358
+ gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
359
+ authDir := filepath.Join(gitStoreRoot, "auths")
360
+ gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
361
+ gitStoreInst.SetBaseDir(authDir)
362
+ if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
363
+ log.Errorf("failed to prepare git token store: %v", errRepo)
364
+ return
365
+ }
366
+ configFilePath = gitStoreInst.ConfigPath()
367
+ if configFilePath == "" {
368
+ configFilePath = filepath.Join(gitStoreRoot, "config", "config.yaml")
369
+ }
370
+ if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
371
+ examplePath := filepath.Join(wd, "config.example.yaml")
372
+ if _, errExample := os.Stat(examplePath); errExample != nil {
373
+ log.Errorf("failed to find template config file: %v", errExample)
374
+ return
375
+ }
376
+ if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
377
+ log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
378
+ return
379
+ }
380
+ if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
381
+ log.Errorf("failed to commit initial git-backed config: %v", errCommit)
382
+ return
383
+ }
384
+ log.Infof("git-backed config initialized from template: %s", configFilePath)
385
+ } else if statErr != nil {
386
+ log.Errorf("failed to inspect git-backed config: %v", statErr)
387
+ return
388
+ }
389
+ cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
390
+ if err == nil {
391
+ cfg.AuthDir = gitStoreInst.AuthDir()
392
+ log.Infof("git-backed token store enabled, repository path: %s", gitStoreRoot)
393
+ }
394
+ } else if configPath != "" {
395
+ configFilePath = configPath
396
+ cfg, err = config.LoadConfigOptional(configPath, isCloudDeploy)
397
+ } else {
398
+ wd, err = os.Getwd()
399
+ if err != nil {
400
+ log.Errorf("failed to get working directory: %v", err)
401
+ return
402
+ }
403
+ configFilePath = filepath.Join(wd, "config.yaml")
404
+ cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
405
+ }
406
+ if err != nil {
407
+ log.Errorf("failed to load config: %v", err)
408
+ return
409
+ }
410
+ if cfg == nil {
411
+ cfg = &config.Config{}
412
+ }
413
+
414
+ // In cloud deploy mode, check if we have a valid configuration
415
+ var configFileExists bool
416
+ if isCloudDeploy {
417
+ if info, errStat := os.Stat(configFilePath); errStat != nil {
418
+ // Don't mislead: API server will not start until configuration is provided.
419
+ log.Info("Cloud deploy mode: No configuration file detected; standing by for configuration")
420
+ configFileExists = false
421
+ } else if info.IsDir() {
422
+ log.Info("Cloud deploy mode: Config path is a directory; standing by for configuration")
423
+ configFileExists = false
424
+ } else if cfg.Port == 0 {
425
+ // LoadConfigOptional returns empty config when file is empty or invalid.
426
+ // Config file exists but is empty or invalid; treat as missing config
427
+ log.Info("Cloud deploy mode: Configuration file is empty or invalid; standing by for valid configuration")
428
+ configFileExists = false
429
+ } else {
430
+ log.Info("Cloud deploy mode: Configuration file detected; starting service")
431
+ configFileExists = true
432
+ }
433
+ }
434
+ usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
435
+ coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
436
+
437
+ if err = logging.ConfigureLogOutput(cfg); err != nil {
438
+ log.Errorf("failed to configure log output: %v", err)
439
+ return
440
+ }
441
+
442
+ log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
443
+
444
+ // Set the log level based on the configuration.
445
+ util.SetLogLevel(cfg)
446
+
447
+ if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
448
+ log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
449
+ return
450
+ } else {
451
+ cfg.AuthDir = resolvedAuthDir
452
+ }
453
+ managementasset.SetCurrentConfig(cfg)
454
+
455
+ // Create login options to be used in authentication flows.
456
+ options := &cmd.LoginOptions{
457
+ NoBrowser: noBrowser,
458
+ }
459
+
460
+ // Register the shared token store once so all components use the same persistence backend.
461
+ if usePostgresStore {
462
+ sdkAuth.RegisterTokenStore(pgStoreInst)
463
+ } else if useObjectStore {
464
+ sdkAuth.RegisterTokenStore(objectStoreInst)
465
+ } else if useGitStore {
466
+ sdkAuth.RegisterTokenStore(gitStoreInst)
467
+ } else {
468
+ sdkAuth.RegisterTokenStore(sdkAuth.NewFileTokenStore())
469
+ }
470
+
471
+ // Register built-in access providers before constructing services.
472
+ configaccess.Register()
473
+
474
+ // Handle different command modes based on the provided flags.
475
+
476
+ if vertexImport != "" {
477
+ // Handle Vertex service account import
478
+ cmd.DoVertexImport(cfg, vertexImport)
479
+ } else if login {
480
+ // Handle Google/Gemini login
481
+ cmd.DoLogin(cfg, projectID, options)
482
+ } else if antigravityLogin {
483
+ // Handle Antigravity login
484
+ cmd.DoAntigravityLogin(cfg, options)
485
+ } else if githubCopilotLogin {
486
+ // Handle GitHub Copilot login
487
+ cmd.DoGitHubCopilotLogin(cfg, options)
488
+ } else if codexLogin {
489
+ // Handle Codex login
490
+ cmd.DoCodexLogin(cfg, options)
491
+ } else if claudeLogin {
492
+ // Handle Claude login
493
+ cmd.DoClaudeLogin(cfg, options)
494
+ } else if qwenLogin {
495
+ cmd.DoQwenLogin(cfg, options)
496
+ } else if iflowLogin {
497
+ cmd.DoIFlowLogin(cfg, options)
498
+ } else if iflowCookie {
499
+ cmd.DoIFlowCookieAuth(cfg, options)
500
+ } else if kiroLogin {
501
+ // For Kiro auth, default to incognito mode for multi-account support
502
+ // Users can explicitly override with --no-incognito
503
+ // Note: This config mutation is safe - auth commands exit after completion
504
+ // and don't share config with StartService (which is in the else branch)
505
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
506
+ cmd.DoKiroLogin(cfg, options)
507
+ } else if kiroGoogleLogin {
508
+ // For Kiro auth, default to incognito mode for multi-account support
509
+ // Users can explicitly override with --no-incognito
510
+ // Note: This config mutation is safe - auth commands exit after completion
511
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
512
+ cmd.DoKiroGoogleLogin(cfg, options)
513
+ } else if kiroAWSLogin {
514
+ // For Kiro auth, default to incognito mode for multi-account support
515
+ // Users can explicitly override with --no-incognito
516
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
517
+ cmd.DoKiroAWSLogin(cfg, options)
518
+ } else if kiroAWSAuthCode {
519
+ // For Kiro auth with authorization code flow (better UX)
520
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
521
+ cmd.DoKiroAWSAuthCodeLogin(cfg, options)
522
+ } else if kiroImport {
523
+ cmd.DoKiroImport(cfg, options)
524
+ } else {
525
+ // In cloud deploy mode without config file, just wait for shutdown signals
526
+ if isCloudDeploy && !configFileExists {
527
+ // No config file available, just wait for shutdown
528
+ cmd.WaitForCloudDeploy()
529
+ return
530
+ }
531
+ // Start the main proxy service
532
+ managementasset.StartAutoUpdater(context.Background(), configFilePath)
533
+ cmd.StartService(cfg, configFilePath, password)
534
+ }
535
+ }
config.example.yaml ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
2
+ # Use "127.0.0.1" or "localhost" to restrict access to local machine only.
3
+ host: ""
4
+
5
+ # Server port
6
+ port: 8317
7
+
8
+ # TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
9
+ tls:
10
+ enable: false
11
+ cert: ""
12
+ key: ""
13
+
14
+ # Management API settings
15
+ remote-management:
16
+ # Whether to allow remote (non-localhost) management access.
17
+ # When false, only localhost can access management endpoints (a key is still required).
18
+ allow-remote: false
19
+
20
+ # Management key. If a plaintext value is provided here, it will be hashed on startup.
21
+ # All management requests (even from localhost) require this key.
22
+ # Leave empty to disable the Management API entirely (404 for all /v0/management routes).
23
+ secret-key: ""
24
+
25
+ # Disable the bundled management control panel asset download and HTTP route when true.
26
+ disable-control-panel: false
27
+
28
+ # GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
29
+ panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
30
+
31
+ # Authentication directory (supports ~ for home directory)
32
+ auth-dir: "~/.cli-proxy-api"
33
+
34
+ # API keys for authentication
35
+ api-keys:
36
+ - "your-api-key-1"
37
+ - "your-api-key-2"
38
+ - "your-api-key-3"
39
+
40
+ # Enable debug logging
41
+ debug: false
42
+
43
+ # When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
44
+ commercial-mode: false
45
+
46
+ # Open OAuth URLs in incognito/private browser mode.
47
+ # Useful when you want to login with a different account without logging out from your current session.
48
+ # Default: false (but Kiro auth defaults to true for multi-account support)
49
+ incognito-browser: true
50
+
51
+ # When true, write application logs to rotating files instead of stdout
52
+ logging-to-file: false
53
+
54
+ # Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
55
+ # files are deleted until within the limit. Set to 0 to disable.
56
+ logs-max-total-size-mb: 0
57
+
58
+ # When false, disable in-memory usage statistics aggregation
59
+ usage-statistics-enabled: false
60
+
61
+ # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
62
+ proxy-url: ""
63
+
64
+ # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
65
+ force-model-prefix: false
66
+
67
+ # Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
68
+ request-retry: 3
69
+
70
+ # Maximum wait time in seconds for a cooled-down credential before triggering a retry.
71
+ max-retry-interval: 30
72
+
73
+ # Quota exceeded behavior
74
+ quota-exceeded:
75
+ switch-project: true # Whether to automatically switch to another project when a quota is exceeded
76
+ switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
77
+
78
+ # Routing strategy for selecting credentials when multiple match.
79
+ routing:
80
+ strategy: "round-robin" # round-robin (default), fill-first
81
+
82
+ # When true, enable authentication for the WebSocket API (/v1/ws).
83
+ ws-auth: false
84
+
85
+ # Streaming behavior (SSE keep-alives + safe bootstrap retries).
86
+ # streaming:
87
+ # keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
88
+ # bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
89
+
90
+ # Gemini API keys
91
+ # gemini-api-key:
92
+ # - api-key: "AIzaSy...01"
93
+ # prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
94
+ # base-url: "https://generativelanguage.googleapis.com"
95
+ # headers:
96
+ # X-Custom-Header: "custom-value"
97
+ # proxy-url: "socks5://proxy.example.com:1080"
98
+ # models:
99
+ # - name: "gemini-2.5-flash" # upstream model name
100
+ # alias: "gemini-flash" # client alias mapped to the upstream model
101
+ # excluded-models:
102
+ # - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
103
+ # - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
104
+ # - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
105
+ # - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
106
+ # - api-key: "AIzaSy...02"
107
+
108
+ # Codex API keys
109
+ # codex-api-key:
110
+ # - api-key: "sk-atSM..."
111
+ # prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
112
+ # base-url: "https://www.example.com" # use the custom codex API endpoint
113
+ # headers:
114
+ # X-Custom-Header: "custom-value"
115
+ # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
116
+ # models:
117
+ # - name: "gpt-5-codex" # upstream model name
118
+ # alias: "codex-latest" # client alias mapped to the upstream model
119
+ # excluded-models:
120
+ # - "gpt-5.1" # exclude specific models (exact match)
121
+ # - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
122
+ # - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini)
123
+ # - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low)
124
+
125
+ # Claude API keys
126
+ # claude-api-key:
127
+ # - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
128
+ # - api-key: "sk-atSM..."
129
+ # prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
130
+ # base-url: "https://www.example.com" # use the custom claude API endpoint
131
+ # headers:
132
+ # X-Custom-Header: "custom-value"
133
+ # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
134
+ # models:
135
+ # - name: "claude-3-5-sonnet-20241022" # upstream model name
136
+ # alias: "claude-sonnet-latest" # client alias mapped to the upstream model
137
+ # excluded-models:
138
+ # - "claude-opus-4-5-20251101" # exclude specific models (exact match)
139
+ # - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
140
+ # - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
141
+ # - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
142
+
143
+ # Kiro (AWS CodeWhisperer) configuration
144
+ # Note: Kiro API currently only operates in us-east-1 region
145
+ #kiro:
146
+ # - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
147
+ # agent-task-type: "" # optional: "vibe" or empty (API default)
148
+ # - access-token: "aoaAAAAA..." # or provide tokens directly
149
+ # refresh-token: "aorAAAAA..."
150
+ # profile-arn: "arn:aws:codewhisperer:us-east-1:..."
151
+ # proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
152
+
153
+ # OpenAI compatibility providers
154
+ # openai-compatibility:
155
+ # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
156
+ # prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
157
+ # base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
158
+ # headers:
159
+ # X-Custom-Header: "custom-value"
160
+ # api-key-entries:
161
+ # - api-key: "sk-or-v1-...b780"
162
+ # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
163
+ # - api-key: "sk-or-v1-...b781" # without proxy-url
164
+ # models: # The models supported by the provider.
165
+ # - name: "moonshotai/kimi-k2:free" # The actual model name.
166
+ # alias: "kimi-k2" # The alias used in the API.
167
+
168
+ # Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
169
+ # vertex-api-key:
170
+ # - api-key: "vk-123..." # x-goog-api-key header
171
+ # prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
172
+ # base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
173
+ # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
174
+ # headers:
175
+ # X-Custom-Header: "custom-value"
176
+ # models: # optional: map aliases to upstream model names
177
+ # - name: "gemini-2.5-flash" # upstream model name
178
+ # alias: "vertex-flash" # client-visible alias
179
+ # - name: "gemini-2.5-pro"
180
+ # alias: "vertex-pro"
181
+
182
+ # Amp Integration
183
+ # ampcode:
184
+ # # Configure upstream URL for Amp CLI OAuth and management features
185
+ # upstream-url: "https://ampcode.com"
186
+ # # Optional: Override API key for Amp upstream (otherwise uses env or file)
187
+ # upstream-api-key: ""
188
+ # # Per-client upstream API key mapping
189
+ # # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
190
+ # # Useful when different clients need to use different Amp accounts/quotas.
191
+ # # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
192
+ # upstream-api-keys:
193
+ # - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
194
+ # api-keys: # Client keys that use this upstream key
195
+ # - "your-api-key-1"
196
+ # - "your-api-key-2"
197
+ # - upstream-api-key: "amp_key_for_team_b"
198
+ # api-keys:
199
+ # - "your-api-key-3"
200
+ # # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
201
+ # restrict-management-to-localhost: false
202
+ # # Force model mappings to run before checking local API keys (default: false)
203
+ # force-model-mappings: false
204
+ # # Amp Model Mappings
205
+ # # Route unavailable Amp models to alternative models available in your local proxy.
206
+ # # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
207
+ # # but you have a similar model available (e.g., Claude Sonnet 4).
208
+ # model-mappings:
209
+ # - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
210
+ # to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
211
+ # - from: "claude-sonnet-4-5-20250929"
212
+ # to: "gemini-claude-sonnet-4-5-thinking"
213
+ # - from: "claude-haiku-4-5-20251001"
214
+ # to: "gemini-2.5-flash"
215
+
216
+ # Global OAuth model name mappings (per channel)
217
+ # These mappings rename model IDs for both model listing and request routing.
218
+ # Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
219
+ # NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
220
+ # oauth-model-mappings:
221
+ # gemini-cli:
222
+ # - name: "gemini-2.5-pro" # original model name under this channel
223
+ # alias: "g2.5p" # client-visible alias
224
+ # vertex:
225
+ # - name: "gemini-2.5-pro"
226
+ # alias: "g2.5p"
227
+ # aistudio:
228
+ # - name: "gemini-2.5-pro"
229
+ # alias: "g2.5p"
230
+ # antigravity:
231
+ # - name: "gemini-3-pro-preview"
232
+ # alias: "g3p"
233
+ # claude:
234
+ # - name: "claude-sonnet-4-5-20250929"
235
+ # alias: "cs4.5"
236
+ # codex:
237
+ # - name: "gpt-5"
238
+ # alias: "g5"
239
+ # qwen:
240
+ # - name: "qwen3-coder-plus"
241
+ # alias: "qwen-plus"
242
+ # iflow:
243
+ # - name: "glm-4.7"
244
+ # alias: "glm-god"
245
+
246
+ # OAuth provider excluded models
247
+ # oauth-excluded-models:
248
+ # gemini-cli:
249
+ # - "gemini-2.5-pro" # exclude specific models (exact match)
250
+ # - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
251
+ # - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
252
+ # - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
253
+ # vertex:
254
+ # - "gemini-3-pro-preview"
255
+ # aistudio:
256
+ # - "gemini-3-pro-preview"
257
+ # antigravity:
258
+ # - "gemini-3-pro-preview"
259
+ # claude:
260
+ # - "claude-3-5-haiku-20241022"
261
+ # codex:
262
+ # - "gpt-5-codex-mini"
263
+ # qwen:
264
+ # - "vision-model"
265
+ # iflow:
266
+ # - "tstars2.0"
267
+
268
+ # Optional payload configuration
269
+ # payload:
270
+ # default: # Default rules only set parameters when they are missing in the payload.
271
+ # - models:
272
+ # - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
273
+ # protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
274
+ # params: # JSON path (gjson/sjson syntax) -> value
275
+ # "generationConfig.thinkingConfig.thinkingBudget": 32768
276
+ # override: # Override rules always set parameters, overwriting any existing values.
277
+ # - models:
278
+ # - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
279
+ # protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
280
+ # params: # JSON path (gjson/sjson syntax) -> value
281
+ # "reasoning.effort": "high"
config.yaml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLIProxyAPI Plus - Ultimate Power Config
2
+ host: ""
3
+ port: 7860
4
+
5
+ # TLS settings (disabled for HF)
6
+ tls:
7
+ enable: false
8
+
9
+ # Management API settings (REMOTE ENABLED FOR HF)
10
+ remote-management:
11
+ allow-remote: true
12
+ secret-key: "$2a$10$Yt27TytUvABKw192YdTW2urLkQ5oQkHGuSz6PrzFFlsNJ5TE1EOFe"
13
+ disable-control-panel: false
14
+ panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
15
+
16
+ # Storage
17
+ auth-dir: "./auth"
18
+
19
+ # Client Keys for YOU to access the API
20
+ api-keys:
21
+ - "sk-admin-power-1"
22
+ - "sk-client-key-custom"
23
+
24
+ # Performance & Logging
25
+ debug: false
26
+ commercial-mode: false
27
+ incognito-browser: true
28
+ logging-to-file: false
29
+ usage-statistics-enabled: true
30
+
31
+ # Network
32
+ proxy-url: ""
33
+
34
+ # Smart Routing & Retries
35
+ request-retry: 5
36
+ max-retry-interval: 15
37
+ routing:
38
+ strategy: "round-robin" # Rotates through your 16 keys!
39
+
40
+ quota-exceeded:
41
+ switch-project: true
42
+ switch-preview-model: true
43
+
44
+ # --- PROVIDER SECTIONS (FILL THESE IN THE ADMIN PANEL) ---
45
+
46
+ # 1. Gemini API Keys (Google AI Studio)
47
+ gemini-api-key:
48
+ - api-key: "DEIN_GEMINI_KEY_1"
49
+ - api-key: "DEIN_GEMINI_KEY_2"
50
+
51
+ # 2. Claude API Keys (Anthropic)
52
+ claude-api-key:
53
+ - api-key: "DEIN_CLAUDE_KEY_1"
54
+
55
+ # 3. OpenAI / Compatibility (OpenRouter, DeepSeek, Groq, etc.)
56
+ openai-compatibility:
57
+ - name: "openrouter"
58
+ base-url: "https://openrouter.ai/api/v1"
59
+ api-key-entries:
60
+ - api-key: "DEIN_OPENROUTER_KEY"
61
+ - name: "deepseek"
62
+ base-url: "https://api.deepseek.com/v1"
63
+ api-key-entries:
64
+ - api-key: "DEIN_DEEPSEEK_KEY"
65
+
66
+ # 4. Global Model Mappings (Rename models for easier use)
67
+ oauth-model-mappings:
68
+ gemini-cli:
69
+ - name: "gemini-2.5-pro"
70
+ alias: "g-pro"
71
+ - name: "gemini-2.5-flash"
72
+ alias: "g-flash"
73
+ claude:
74
+ - name: "claude-3-5-sonnet-20241022"
75
+ alias: "sonnet"
go.mod ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module github.com/router-for-me/CLIProxyAPI/v6
2
+
3
+ go 1.23
4
+
5
+ require (
6
+ github.com/andybalholm/brotli v1.0.6
7
+ github.com/fsnotify/fsnotify v1.9.0
8
+ github.com/gin-gonic/gin v1.10.1
9
+ github.com/go-git/go-git/v5 v5.12.0
10
+ github.com/google/uuid v1.6.0
11
+ github.com/gorilla/websocket v1.5.3
12
+ github.com/jackc/pgx/v5 v5.7.0
13
+ github.com/joho/godotenv v1.5.1
14
+ github.com/klauspost/compress v1.17.4
15
+ github.com/minio/minio-go/v7 v7.0.66
16
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
17
+ github.com/sirupsen/logrus v1.9.3
18
+ github.com/tidwall/gjson v1.18.0
19
+ github.com/tidwall/sjson v1.2.5
20
+ github.com/tiktoken-go/tokenizer v0.7.0
21
+ golang.org/x/crypto v0.46.0
22
+ golang.org/x/net v0.48.0
23
+ golang.org/x/oauth2 v0.21.0
24
+ golang.org/x/term v0.38.0
25
+ gopkg.in/natefinch/lumberjack.v2 v2.2.1
26
+ gopkg.in/yaml.v3 v3.0.1
27
+ )
28
+
29
+ require (
30
+ cloud.google.com/go/compute/metadata v0.3.0 // indirect
31
+ github.com/Microsoft/go-winio v0.6.2 // indirect
32
+ github.com/ProtonMail/go-crypto v1.3.0 // indirect
33
+ github.com/bytedance/sonic v1.11.6 // indirect
34
+ github.com/bytedance/sonic/loader v0.1.1 // indirect
35
+ github.com/cloudflare/circl v1.6.1 // indirect
36
+ github.com/cloudwego/base64x v0.1.4 // indirect
37
+ github.com/cloudwego/iasm v0.2.0 // indirect
38
+ github.com/cyphar/filepath-securejoin v0.6.1 // indirect
39
+ github.com/dlclark/regexp2 v1.11.5 // indirect
40
+ github.com/dustin/go-humanize v1.0.1 // indirect
41
+ github.com/emirpasic/gods v1.18.1 // indirect
42
+ github.com/gabriel-vasile/mimetype v1.4.3 // indirect
43
+ github.com/gin-contrib/sse v0.1.0 // indirect
44
+ github.com/go-git/gcfg v2.0.2 // indirect
45
+ github.com/go-playground/locales v0.14.1 // indirect
46
+ github.com/go-playground/universal-translator v0.18.1 // indirect
47
+ github.com/go-playground/validator/v10 v10.20.0 // indirect
48
+ github.com/goccy/go-json v0.10.2 // indirect
49
+ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
50
+ github.com/google/go-cmp v0.6.0 // indirect
51
+ github.com/jackc/pgpassfile v1.0.0 // indirect
52
+ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
53
+ github.com/jackc/puddle/v2 v2.2.1 // indirect
54
+ github.com/json-iterator/go v1.1.12 // indirect
55
+ github.com/kevinburke/ssh_config v1.4.0 // indirect
56
+ github.com/klauspost/cpuid/v2 v2.3.0 // indirect
57
+ github.com/kr/text v0.2.0 // indirect
58
+ github.com/leodido/go-urn v1.4.0 // indirect
59
+ github.com/mattn/go-isatty v0.0.20 // indirect
60
+ github.com/minio/md5-simd v1.1.2 // indirect
61
+ github.com/minio/sha256-simd v1.0.1 // indirect
62
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
63
+ github.com/modern-go/reflect2 v1.0.2 // indirect
64
+ github.com/pelletier/go-toml/v2 v2.2.2 // indirect
65
+ github.com/pjbgf/sha1cd v0.5.0 // indirect
66
+ github.com/rs/xid v1.5.0 // indirect
67
+ github.com/sergi/go-diff v1.4.0 // indirect
68
+ github.com/tidwall/match v1.1.1 // indirect
69
+ github.com/tidwall/pretty v1.2.0 // indirect
70
+ github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
71
+ github.com/ugorji/go/codec v1.2.12 // indirect
72
+ golang.org/x/arch v0.8.0 // indirect
73
+ golang.org/x/sync v0.19.0 // indirect
74
+ golang.org/x/sys v0.39.0 // indirect
75
+ golang.org/x/text v0.32.0 // indirect
76
+ google.golang.org/protobuf v1.34.1 // indirect
77
+ gopkg.in/ini.v1 v1.67.0 // indirect
78
+ )
go.sum ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
2
+ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
3
+ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
4
+ github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
5
+ github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
6
+ github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
7
+ github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
8
+ github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
9
+ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
10
+ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
11
+ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
12
+ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
13
+ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
14
+ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
15
+ github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
16
+ github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
17
+ github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
18
+ github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
19
+ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
20
+ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
21
+ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
22
+ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
23
+ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
24
+ github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE=
25
+ github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc=
26
+ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
27
+ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
28
+ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
29
+ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
30
+ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
31
+ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
32
+ github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
33
+ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
34
+ github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
35
+ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
36
+ github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
37
+ github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
38
+ github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
39
+ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
40
+ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
41
+ github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
42
+ github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
43
+ github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
44
+ github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
45
+ github.com/go-git/gcfg/v2 v2.0.2 h1:MY5SIIfTGGEMhdA7d7JePuVVxtKL7Hp+ApGDJAJ7dpo=
46
+ github.com/go-git/gcfg/v2 v2.0.2/go.mod h1:/lv2NsxvhepuMrldsFilrgct6pxzpGdSRC13ydTLSLs=
47
+ github.com/go-git/go-billy/v6 v6.0.0-20251217170237-e9738f50a3cd h1:Gd/f9cGi/3h1JOPaa6er+CkKUGyGX2DBJdFbDKVO+R0=
48
+ github.com/go-git/go-billy/v6 v6.0.0-20251217170237-e9738f50a3cd/go.mod h1:d3XQcsHu1idnquxt48kAv+h+1MUiYKLH/e7LAzjP+pI=
49
+ github.com/go-git/go-git-fixtures/v5 v5.1.2-0.20251229094738-4b14af179146 h1:xYfxAopYyL44ot6dMBIb1Z1njFM0ZBQ99HdIB99KxLs=
50
+ github.com/go-git/go-git-fixtures/v5 v5.1.2-0.20251229094738-4b14af179146/go.mod h1:QE/75B8tBSLNGyUUbA9tw3EGHoFtYOtypa2h8YJxsWI=
51
+ github.com/go-git/go-git/v6 v6.0.0-20251231065035-29ae690a9f19 h1:0lz2eJScP8v5YZQsrEw+ggWC5jNySjg4bIZo5BIh6iI=
52
+ github.com/go-git/go-git/v6 v6.0.0-20251231065035-29ae690a9f19/go.mod h1:L+Evfcs7EdTqxwv854354cb6+++7TFL3hJn3Wy4g+3w=
53
+ github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
54
+ github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
55
+ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
56
+ github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
57
+ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
58
+ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
59
+ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
60
+ github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
61
+ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
62
+ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
63
+ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
64
+ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
65
+ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
66
+ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
67
+ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
68
+ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
69
+ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
70
+ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
71
+ github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
72
+ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
73
+ github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
74
+ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
75
+ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
76
+ github.com/jackc/pgx/v5 v5.7.0 h1:FG6VLIdzvAPhnYqP14sQ2xhFLkiUQHCs6ySqO91kF4g=
77
+ github.com/jackc/pgx/v5 v5.7.0/go.mod h1:awP1KNnjylvpxHuHP63gzjhnGkI1iw+PMoIwvoleN/8=
78
+ github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
79
+ github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
80
+ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
81
+ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
82
+ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
83
+ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
84
+ github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
85
+ github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
86
+ github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
87
+ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
88
+ github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
89
+ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
90
+ github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
91
+ github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
92
+ github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
93
+ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
94
+ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
95
+ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
96
+ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
97
+ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
98
+ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
99
+ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
100
+ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
101
+ github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
102
+ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
103
+ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
104
+ github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
105
+ github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
106
+ github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
107
+ github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
108
+ github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
109
+ github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
110
+ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
111
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
112
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
113
+ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
114
+ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
115
+ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
116
+ github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
117
+ github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
118
+ github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
119
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
120
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
121
+ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
122
+ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
123
+ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
124
+ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
125
+ github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
126
+ github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
127
+ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
128
+ github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
129
+ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
130
+ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
131
+ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
132
+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
133
+ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
134
+ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
135
+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
136
+ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
137
+ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
138
+ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
139
+ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
140
+ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
141
+ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
142
+ github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
143
+ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
144
+ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
145
+ github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
146
+ github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
147
+ github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
148
+ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
149
+ github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
150
+ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
151
+ github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
152
+ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
153
+ github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
154
+ github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
155
+ github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
156
+ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
157
+ github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
158
+ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
159
+ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
160
+ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
161
+ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
162
+ golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
163
+ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
164
+ golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
165
+ golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
166
+ golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
167
+ golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
168
+ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
169
+ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
170
+ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
171
+ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
172
+ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
173
+ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
174
+ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
175
+ golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
176
+ golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
177
+ golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
178
+ golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
179
+ golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
180
+ google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
181
+ google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
182
+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
183
+ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
184
+ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
185
+ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
186
+ gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
187
+ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
188
+ gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
189
+ gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
190
+ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
191
+ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
192
+ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
193
+ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
194
+ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
195
+ nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
196
+ rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
internal/access/config_access/provider.go ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package configaccess
2
+
3
+ import (
4
+ "context"
5
+ "net/http"
6
+ "strings"
7
+ "sync"
8
+
9
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
10
+ sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
11
+ )
12
+
13
+ var registerOnce sync.Once
14
+
15
+ // Register ensures the config-access provider is available to the access manager.
16
+ func Register() {
17
+ registerOnce.Do(func() {
18
+ sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
19
+ })
20
+ }
21
+
22
+ type provider struct {
23
+ name string
24
+ keys map[string]struct{}
25
+ }
26
+
27
+ func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
28
+ name := cfg.Name
29
+ if name == "" {
30
+ name = sdkconfig.DefaultAccessProviderName
31
+ }
32
+ keys := make(map[string]struct{}, len(cfg.APIKeys))
33
+ for _, key := range cfg.APIKeys {
34
+ if key == "" {
35
+ continue
36
+ }
37
+ keys[key] = struct{}{}
38
+ }
39
+ return &provider{name: name, keys: keys}, nil
40
+ }
41
+
42
+ func (p *provider) Identifier() string {
43
+ if p == nil || p.name == "" {
44
+ return sdkconfig.DefaultAccessProviderName
45
+ }
46
+ return p.name
47
+ }
48
+
49
+ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
50
+ if p == nil {
51
+ return nil, sdkaccess.ErrNotHandled
52
+ }
53
+ if len(p.keys) == 0 {
54
+ return nil, sdkaccess.ErrNotHandled
55
+ }
56
+ authHeader := r.Header.Get("Authorization")
57
+ authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
58
+ authHeaderAnthropic := r.Header.Get("X-Api-Key")
59
+ queryKey := ""
60
+ queryAuthToken := ""
61
+ if r.URL != nil {
62
+ queryKey = r.URL.Query().Get("key")
63
+ queryAuthToken = r.URL.Query().Get("auth_token")
64
+ }
65
+ if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
66
+ return nil, sdkaccess.ErrNoCredentials
67
+ }
68
+
69
+ apiKey := extractBearerToken(authHeader)
70
+
71
+ candidates := []struct {
72
+ value string
73
+ source string
74
+ }{
75
+ {apiKey, "authorization"},
76
+ {authHeaderGoogle, "x-goog-api-key"},
77
+ {authHeaderAnthropic, "x-api-key"},
78
+ {queryKey, "query-key"},
79
+ {queryAuthToken, "query-auth-token"},
80
+ }
81
+
82
+ for _, candidate := range candidates {
83
+ if candidate.value == "" {
84
+ continue
85
+ }
86
+ if _, ok := p.keys[candidate.value]; ok {
87
+ return &sdkaccess.Result{
88
+ Provider: p.Identifier(),
89
+ Principal: candidate.value,
90
+ Metadata: map[string]string{
91
+ "source": candidate.source,
92
+ },
93
+ }, nil
94
+ }
95
+ }
96
+
97
+ return nil, sdkaccess.ErrInvalidCredential
98
+ }
99
+
100
+ func extractBearerToken(header string) string {
101
+ if header == "" {
102
+ return ""
103
+ }
104
+ parts := strings.SplitN(header, " ", 2)
105
+ if len(parts) != 2 {
106
+ return header
107
+ }
108
+ if strings.ToLower(parts[0]) != "bearer" {
109
+ return header
110
+ }
111
+ return strings.TrimSpace(parts[1])
112
+ }
internal/access/reconcile.go ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package access
2
+
3
+ import (
4
+ "fmt"
5
+ "reflect"
6
+ "sort"
7
+ "strings"
8
+
9
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
10
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
11
+ sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
12
+ log "github.com/sirupsen/logrus"
13
+ )
14
+
15
+ // ReconcileProviders builds the desired provider list by reusing existing providers when possible
16
+ // and creating or removing providers only when their configuration changed. It returns the final
17
+ // ordered provider slice along with the identifiers of providers that were added, updated, or
18
+ // removed compared to the previous configuration.
19
+ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
20
+ if newCfg == nil {
21
+ return nil, nil, nil, nil, nil
22
+ }
23
+
24
+ existingMap := make(map[string]sdkaccess.Provider, len(existing))
25
+ for _, provider := range existing {
26
+ if provider == nil {
27
+ continue
28
+ }
29
+ existingMap[provider.Identifier()] = provider
30
+ }
31
+
32
+ oldCfgMap := accessProviderMap(oldCfg)
33
+ newEntries := collectProviderEntries(newCfg)
34
+
35
+ result = make([]sdkaccess.Provider, 0, len(newEntries))
36
+ finalIDs := make(map[string]struct{}, len(newEntries))
37
+
38
+ isInlineProvider := func(id string) bool {
39
+ return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
40
+ }
41
+ appendChange := func(list *[]string, id string) {
42
+ if isInlineProvider(id) {
43
+ return
44
+ }
45
+ *list = append(*list, id)
46
+ }
47
+
48
+ for _, providerCfg := range newEntries {
49
+ key := providerIdentifier(providerCfg)
50
+ if key == "" {
51
+ continue
52
+ }
53
+
54
+ forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
55
+ if oldCfgProvider, ok := oldCfgMap[key]; ok {
56
+ isAliased := oldCfgProvider == providerCfg
57
+ if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
58
+ if existingProvider, okExisting := existingMap[key]; okExisting {
59
+ result = append(result, existingProvider)
60
+ finalIDs[key] = struct{}{}
61
+ continue
62
+ }
63
+ }
64
+ }
65
+
66
+ provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
67
+ if buildErr != nil {
68
+ return nil, nil, nil, nil, buildErr
69
+ }
70
+ if _, ok := oldCfgMap[key]; ok {
71
+ if _, existed := existingMap[key]; existed {
72
+ appendChange(&updated, key)
73
+ } else {
74
+ appendChange(&added, key)
75
+ }
76
+ } else {
77
+ appendChange(&added, key)
78
+ }
79
+ result = append(result, provider)
80
+ finalIDs[key] = struct{}{}
81
+ }
82
+
83
+ if len(result) == 0 {
84
+ if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
85
+ key := providerIdentifier(inline)
86
+ if key != "" {
87
+ if oldCfgProvider, ok := oldCfgMap[key]; ok {
88
+ if providerConfigEqual(oldCfgProvider, inline) {
89
+ if existingProvider, okExisting := existingMap[key]; okExisting {
90
+ result = append(result, existingProvider)
91
+ finalIDs[key] = struct{}{}
92
+ goto inlineDone
93
+ }
94
+ }
95
+ }
96
+ provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
97
+ if buildErr != nil {
98
+ return nil, nil, nil, nil, buildErr
99
+ }
100
+ if _, existed := existingMap[key]; existed {
101
+ appendChange(&updated, key)
102
+ } else if _, hadOld := oldCfgMap[key]; hadOld {
103
+ appendChange(&updated, key)
104
+ } else {
105
+ appendChange(&added, key)
106
+ }
107
+ result = append(result, provider)
108
+ finalIDs[key] = struct{}{}
109
+ }
110
+ }
111
+ inlineDone:
112
+ }
113
+
114
+ removedSet := make(map[string]struct{})
115
+ for id := range existingMap {
116
+ if _, ok := finalIDs[id]; !ok {
117
+ if isInlineProvider(id) {
118
+ continue
119
+ }
120
+ removedSet[id] = struct{}{}
121
+ }
122
+ }
123
+
124
+ removed = make([]string, 0, len(removedSet))
125
+ for id := range removedSet {
126
+ removed = append(removed, id)
127
+ }
128
+
129
+ sort.Strings(added)
130
+ sort.Strings(updated)
131
+ sort.Strings(removed)
132
+
133
+ return result, added, updated, removed, nil
134
+ }
135
+
136
+ // ApplyAccessProviders reconciles the configured access providers against the
137
+ // currently registered providers and updates the manager. It logs a concise
138
+ // summary of the detected changes and returns whether any provider changed.
139
+ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Config) (bool, error) {
140
+ if manager == nil || newCfg == nil {
141
+ return false, nil
142
+ }
143
+
144
+ existing := manager.Providers()
145
+ providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
146
+ if err != nil {
147
+ log.Errorf("failed to reconcile request auth providers: %v", err)
148
+ return false, fmt.Errorf("reconciling access providers: %w", err)
149
+ }
150
+
151
+ manager.SetProviders(providers)
152
+
153
+ if len(added)+len(updated)+len(removed) > 0 {
154
+ log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
155
+ log.Debugf("auth providers changes details - added=%v updated=%v removed=%v", added, updated, removed)
156
+ return true, nil
157
+ }
158
+
159
+ log.Debug("auth providers unchanged after config update")
160
+ return false, nil
161
+ }
162
+
163
+ func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
164
+ result := make(map[string]*sdkConfig.AccessProvider)
165
+ if cfg == nil {
166
+ return result
167
+ }
168
+ for i := range cfg.Access.Providers {
169
+ providerCfg := &cfg.Access.Providers[i]
170
+ if providerCfg.Type == "" {
171
+ continue
172
+ }
173
+ key := providerIdentifier(providerCfg)
174
+ if key == "" {
175
+ continue
176
+ }
177
+ result[key] = providerCfg
178
+ }
179
+ if len(result) == 0 && len(cfg.APIKeys) > 0 {
180
+ if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
181
+ if key := providerIdentifier(provider); key != "" {
182
+ result[key] = provider
183
+ }
184
+ }
185
+ }
186
+ return result
187
+ }
188
+
189
+ func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
190
+ entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
191
+ for i := range cfg.Access.Providers {
192
+ providerCfg := &cfg.Access.Providers[i]
193
+ if providerCfg.Type == "" {
194
+ continue
195
+ }
196
+ if key := providerIdentifier(providerCfg); key != "" {
197
+ entries = append(entries, providerCfg)
198
+ }
199
+ }
200
+ if len(entries) == 0 && len(cfg.APIKeys) > 0 {
201
+ if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
202
+ entries = append(entries, inline)
203
+ }
204
+ }
205
+ return entries
206
+ }
207
+
208
+ func providerIdentifier(provider *sdkConfig.AccessProvider) string {
209
+ if provider == nil {
210
+ return ""
211
+ }
212
+ if name := strings.TrimSpace(provider.Name); name != "" {
213
+ return name
214
+ }
215
+ typ := strings.TrimSpace(provider.Type)
216
+ if typ == "" {
217
+ return ""
218
+ }
219
+ if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
220
+ return sdkConfig.DefaultAccessProviderName
221
+ }
222
+ return typ
223
+ }
224
+
225
+ func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
226
+ if a == nil || b == nil {
227
+ return a == nil && b == nil
228
+ }
229
+ if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
230
+ return false
231
+ }
232
+ if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
233
+ return false
234
+ }
235
+ if !stringSetEqual(a.APIKeys, b.APIKeys) {
236
+ return false
237
+ }
238
+ if len(a.Config) != len(b.Config) {
239
+ return false
240
+ }
241
+ if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
242
+ return false
243
+ }
244
+ return true
245
+ }
246
+
247
+ func stringSetEqual(a, b []string) bool {
248
+ if len(a) != len(b) {
249
+ return false
250
+ }
251
+ if len(a) == 0 {
252
+ return true
253
+ }
254
+ seen := make(map[string]int, len(a))
255
+ for _, val := range a {
256
+ seen[val]++
257
+ }
258
+ for _, val := range b {
259
+ count := seen[val]
260
+ if count == 0 {
261
+ return false
262
+ }
263
+ if count == 1 {
264
+ delete(seen, val)
265
+ } else {
266
+ seen[val] = count - 1
267
+ }
268
+ }
269
+ return len(seen) == 0
270
+ }
internal/api/handlers/management/api_tools.go ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "fmt"
7
+ "io"
8
+ "net"
9
+ "net/http"
10
+ "net/url"
11
+ "strings"
12
+ "time"
13
+
14
+ "github.com/gin-gonic/gin"
15
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
16
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
17
+ log "github.com/sirupsen/logrus"
18
+ "golang.org/x/net/proxy"
19
+ "golang.org/x/oauth2"
20
+ "golang.org/x/oauth2/google"
21
+ )
22
+
23
+ const defaultAPICallTimeout = 60 * time.Second
24
+
25
+ const (
26
+ geminiOAuthClientID = "YOUR_CLIENT_ID"
27
+ geminiOAuthClientSecret = "YOUR_CLIENT_SECRET"
28
+ )
29
+
30
+ var geminiOAuthScopes = []string{
31
+ "https://www.googleapis.com/auth/cloud-platform",
32
+ "https://www.googleapis.com/auth/userinfo.email",
33
+ "https://www.googleapis.com/auth/userinfo.profile",
34
+ }
35
+
36
+ type apiCallRequest struct {
37
+ AuthIndexSnake *string `json:"auth_index"`
38
+ AuthIndexCamel *string `json:"authIndex"`
39
+ AuthIndexPascal *string `json:"AuthIndex"`
40
+ Method string `json:"method"`
41
+ URL string `json:"url"`
42
+ Header map[string]string `json:"header"`
43
+ Data string `json:"data"`
44
+ }
45
+
46
+ type apiCallResponse struct {
47
+ StatusCode int `json:"status_code"`
48
+ Header map[string][]string `json:"header"`
49
+ Body string `json:"body"`
50
+ }
51
+
52
+ // APICall makes a generic HTTP request on behalf of the management API caller.
53
+ // It is protected by the management middleware.
54
+ //
55
+ // Endpoint:
56
+ //
57
+ // POST /v0/management/api-call
58
+ //
59
+ // Authentication:
60
+ //
61
+ // Same as other management APIs (requires a management key and remote-management rules).
62
+ // You can provide the key via:
63
+ // - Authorization: Bearer <key>
64
+ // - X-Management-Key: <key>
65
+ //
66
+ // Request JSON:
67
+ // - auth_index / authIndex / AuthIndex (optional):
68
+ // The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
69
+ // If omitted or not found, credential-specific proxy/token substitution is skipped.
70
+ // - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
71
+ // - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
72
+ // - header (optional): Request headers map.
73
+ // Supports magic variable "$TOKEN$" which is replaced using the selected credential:
74
+ // 1) metadata.access_token
75
+ // 2) attributes.api_key
76
+ // 3) metadata.token / metadata.id_token / metadata.cookie
77
+ // Example: {"Authorization":"Bearer $TOKEN$"}.
78
+ // Note: if you need to override the HTTP Host header, set header["Host"].
79
+ // - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
80
+ //
81
+ // Proxy selection (highest priority first):
82
+ // 1. Selected credential proxy_url
83
+ // 2. Global config proxy-url
84
+ // 3. Direct connect (environment proxies are not used)
85
+ //
86
+ // Response JSON (returned with HTTP 200 when the APICall itself succeeds):
87
+ // - status_code: Upstream HTTP status code.
88
+ // - header: Upstream response headers.
89
+ // - body: Upstream response body as string.
90
+ //
91
+ // Example:
92
+ //
93
+ // curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
94
+ // -H "Authorization: Bearer <MANAGEMENT_KEY>" \
95
+ // -H "Content-Type: application/json" \
96
+ // -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
97
+ //
98
+ // curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
99
+ // -H "Authorization: Bearer 831227" \
100
+ // -H "Content-Type: application/json" \
101
+ // -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
102
+ func (h *Handler) APICall(c *gin.Context) {
103
+ var body apiCallRequest
104
+ if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
105
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
106
+ return
107
+ }
108
+
109
+ method := strings.ToUpper(strings.TrimSpace(body.Method))
110
+ if method == "" {
111
+ c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
112
+ return
113
+ }
114
+
115
+ urlStr := strings.TrimSpace(body.URL)
116
+ if urlStr == "" {
117
+ c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
118
+ return
119
+ }
120
+ parsedURL, errParseURL := url.Parse(urlStr)
121
+ if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
122
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
123
+ return
124
+ }
125
+
126
+ authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
127
+ auth := h.authByIndex(authIndex)
128
+
129
+ reqHeaders := body.Header
130
+ if reqHeaders == nil {
131
+ reqHeaders = map[string]string{}
132
+ }
133
+
134
+ var hostOverride string
135
+ var token string
136
+ var tokenResolved bool
137
+ var tokenErr error
138
+ for key, value := range reqHeaders {
139
+ if !strings.Contains(value, "$TOKEN$") {
140
+ continue
141
+ }
142
+ if !tokenResolved {
143
+ token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
144
+ tokenResolved = true
145
+ }
146
+ if auth != nil && token == "" {
147
+ if tokenErr != nil {
148
+ c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
149
+ return
150
+ }
151
+ c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
152
+ return
153
+ }
154
+ if token == "" {
155
+ continue
156
+ }
157
+ reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
158
+ }
159
+
160
+ var requestBody io.Reader
161
+ if body.Data != "" {
162
+ requestBody = strings.NewReader(body.Data)
163
+ }
164
+
165
+ req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
166
+ if errNewRequest != nil {
167
+ c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
168
+ return
169
+ }
170
+
171
+ for key, value := range reqHeaders {
172
+ if strings.EqualFold(key, "host") {
173
+ hostOverride = strings.TrimSpace(value)
174
+ continue
175
+ }
176
+ req.Header.Set(key, value)
177
+ }
178
+ if hostOverride != "" {
179
+ req.Host = hostOverride
180
+ }
181
+
182
+ httpClient := &http.Client{
183
+ Timeout: defaultAPICallTimeout,
184
+ }
185
+ httpClient.Transport = h.apiCallTransport(auth)
186
+
187
+ resp, errDo := httpClient.Do(req)
188
+ if errDo != nil {
189
+ log.WithError(errDo).Debug("management APICall request failed")
190
+ c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
191
+ return
192
+ }
193
+ defer func() {
194
+ if errClose := resp.Body.Close(); errClose != nil {
195
+ log.Errorf("response body close error: %v", errClose)
196
+ }
197
+ }()
198
+
199
+ respBody, errReadAll := io.ReadAll(resp.Body)
200
+ if errReadAll != nil {
201
+ c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
202
+ return
203
+ }
204
+
205
+ c.JSON(http.StatusOK, apiCallResponse{
206
+ StatusCode: resp.StatusCode,
207
+ Header: resp.Header,
208
+ Body: string(respBody),
209
+ })
210
+ }
211
+
212
+ func firstNonEmptyString(values ...*string) string {
213
+ for _, v := range values {
214
+ if v == nil {
215
+ continue
216
+ }
217
+ if out := strings.TrimSpace(*v); out != "" {
218
+ return out
219
+ }
220
+ }
221
+ return ""
222
+ }
223
+
224
+ func tokenValueForAuth(auth *coreauth.Auth) string {
225
+ if auth == nil {
226
+ return ""
227
+ }
228
+ if v := tokenValueFromMetadata(auth.Metadata); v != "" {
229
+ return v
230
+ }
231
+ if auth.Attributes != nil {
232
+ if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
233
+ return v
234
+ }
235
+ }
236
+ if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
237
+ if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
238
+ return v
239
+ }
240
+ }
241
+ return ""
242
+ }
243
+
244
+ func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
245
+ if auth == nil {
246
+ return "", nil
247
+ }
248
+
249
+ provider := strings.ToLower(strings.TrimSpace(auth.Provider))
250
+ if provider == "gemini-cli" {
251
+ token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
252
+ return token, errToken
253
+ }
254
+
255
+ return tokenValueForAuth(auth), nil
256
+ }
257
+
258
+ func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
259
+ if ctx == nil {
260
+ ctx = context.Background()
261
+ }
262
+ if auth == nil {
263
+ return "", nil
264
+ }
265
+
266
+ metadata, updater := geminiOAuthMetadata(auth)
267
+ if len(metadata) == 0 {
268
+ return "", fmt.Errorf("gemini oauth metadata missing")
269
+ }
270
+
271
+ base := make(map[string]any)
272
+ if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
273
+ base = cloneMap(tokenRaw)
274
+ }
275
+
276
+ var token oauth2.Token
277
+ if len(base) > 0 {
278
+ if raw, errMarshal := json.Marshal(base); errMarshal == nil {
279
+ _ = json.Unmarshal(raw, &token)
280
+ }
281
+ }
282
+
283
+ if token.AccessToken == "" {
284
+ token.AccessToken = stringValue(metadata, "access_token")
285
+ }
286
+ if token.RefreshToken == "" {
287
+ token.RefreshToken = stringValue(metadata, "refresh_token")
288
+ }
289
+ if token.TokenType == "" {
290
+ token.TokenType = stringValue(metadata, "token_type")
291
+ }
292
+ if token.Expiry.IsZero() {
293
+ if expiry := stringValue(metadata, "expiry"); expiry != "" {
294
+ if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
295
+ token.Expiry = ts
296
+ }
297
+ }
298
+ }
299
+
300
+ conf := &oauth2.Config{
301
+ ClientID: geminiOAuthClientID,
302
+ ClientSecret: geminiOAuthClientSecret,
303
+ Scopes: geminiOAuthScopes,
304
+ Endpoint: google.Endpoint,
305
+ }
306
+
307
+ ctxToken := ctx
308
+ httpClient := &http.Client{
309
+ Timeout: defaultAPICallTimeout,
310
+ Transport: h.apiCallTransport(auth),
311
+ }
312
+ ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
313
+
314
+ src := conf.TokenSource(ctxToken, &token)
315
+ currentToken, errToken := src.Token()
316
+ if errToken != nil {
317
+ return "", errToken
318
+ }
319
+
320
+ merged := buildOAuthTokenMap(base, currentToken)
321
+ fields := buildOAuthTokenFields(currentToken, merged)
322
+ if updater != nil {
323
+ updater(fields)
324
+ }
325
+ return strings.TrimSpace(currentToken.AccessToken), nil
326
+ }
327
+
328
+ func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
329
+ if auth == nil {
330
+ return nil, nil
331
+ }
332
+ if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
333
+ snapshot := shared.MetadataSnapshot()
334
+ return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
335
+ }
336
+ return auth.Metadata, func(fields map[string]any) {
337
+ if auth.Metadata == nil {
338
+ auth.Metadata = make(map[string]any)
339
+ }
340
+ for k, v := range fields {
341
+ auth.Metadata[k] = v
342
+ }
343
+ }
344
+ }
345
+
346
+ func stringValue(metadata map[string]any, key string) string {
347
+ if len(metadata) == 0 || key == "" {
348
+ return ""
349
+ }
350
+ if v, ok := metadata[key].(string); ok {
351
+ return strings.TrimSpace(v)
352
+ }
353
+ return ""
354
+ }
355
+
356
+ func cloneMap(in map[string]any) map[string]any {
357
+ if len(in) == 0 {
358
+ return nil
359
+ }
360
+ out := make(map[string]any, len(in))
361
+ for k, v := range in {
362
+ out[k] = v
363
+ }
364
+ return out
365
+ }
366
+
367
+ func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
368
+ merged := cloneMap(base)
369
+ if merged == nil {
370
+ merged = make(map[string]any)
371
+ }
372
+ if tok == nil {
373
+ return merged
374
+ }
375
+ if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
376
+ var tokenMap map[string]any
377
+ if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
378
+ for k, v := range tokenMap {
379
+ merged[k] = v
380
+ }
381
+ }
382
+ }
383
+ return merged
384
+ }
385
+
386
+ func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
387
+ fields := make(map[string]any, 5)
388
+ if tok != nil && tok.AccessToken != "" {
389
+ fields["access_token"] = tok.AccessToken
390
+ }
391
+ if tok != nil && tok.TokenType != "" {
392
+ fields["token_type"] = tok.TokenType
393
+ }
394
+ if tok != nil && tok.RefreshToken != "" {
395
+ fields["refresh_token"] = tok.RefreshToken
396
+ }
397
+ if tok != nil && !tok.Expiry.IsZero() {
398
+ fields["expiry"] = tok.Expiry.Format(time.RFC3339)
399
+ }
400
+ if len(merged) > 0 {
401
+ fields["token"] = cloneMap(merged)
402
+ }
403
+ return fields
404
+ }
405
+
406
+ func tokenValueFromMetadata(metadata map[string]any) string {
407
+ if len(metadata) == 0 {
408
+ return ""
409
+ }
410
+ if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
411
+ return strings.TrimSpace(v)
412
+ }
413
+ if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
414
+ return strings.TrimSpace(v)
415
+ }
416
+ if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
417
+ switch typed := tokenRaw.(type) {
418
+ case string:
419
+ if v := strings.TrimSpace(typed); v != "" {
420
+ return v
421
+ }
422
+ case map[string]any:
423
+ if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
424
+ return strings.TrimSpace(v)
425
+ }
426
+ if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
427
+ return strings.TrimSpace(v)
428
+ }
429
+ case map[string]string:
430
+ if v := strings.TrimSpace(typed["access_token"]); v != "" {
431
+ return v
432
+ }
433
+ if v := strings.TrimSpace(typed["accessToken"]); v != "" {
434
+ return v
435
+ }
436
+ }
437
+ }
438
+ if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
439
+ return strings.TrimSpace(v)
440
+ }
441
+ if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
442
+ return strings.TrimSpace(v)
443
+ }
444
+ if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
445
+ return strings.TrimSpace(v)
446
+ }
447
+ return ""
448
+ }
449
+
450
+ func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
451
+ authIndex = strings.TrimSpace(authIndex)
452
+ if authIndex == "" || h == nil || h.authManager == nil {
453
+ return nil
454
+ }
455
+ auths := h.authManager.List()
456
+ for _, auth := range auths {
457
+ if auth == nil {
458
+ continue
459
+ }
460
+ auth.EnsureIndex()
461
+ if auth.Index == authIndex {
462
+ return auth
463
+ }
464
+ }
465
+ return nil
466
+ }
467
+
468
+ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
469
+ var proxyCandidates []string
470
+ if auth != nil {
471
+ if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
472
+ proxyCandidates = append(proxyCandidates, proxyStr)
473
+ }
474
+ }
475
+ if h != nil && h.cfg != nil {
476
+ if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
477
+ proxyCandidates = append(proxyCandidates, proxyStr)
478
+ }
479
+ }
480
+
481
+ for _, proxyStr := range proxyCandidates {
482
+ if transport := buildProxyTransport(proxyStr); transport != nil {
483
+ return transport
484
+ }
485
+ }
486
+
487
+ transport, ok := http.DefaultTransport.(*http.Transport)
488
+ if !ok || transport == nil {
489
+ return &http.Transport{Proxy: nil}
490
+ }
491
+ clone := transport.Clone()
492
+ clone.Proxy = nil
493
+ return clone
494
+ }
495
+
496
+ func buildProxyTransport(proxyStr string) *http.Transport {
497
+ proxyStr = strings.TrimSpace(proxyStr)
498
+ if proxyStr == "" {
499
+ return nil
500
+ }
501
+
502
+ proxyURL, errParse := url.Parse(proxyStr)
503
+ if errParse != nil {
504
+ log.WithError(errParse).Debug("parse proxy URL failed")
505
+ return nil
506
+ }
507
+ if proxyURL.Scheme == "" || proxyURL.Host == "" {
508
+ log.Debug("proxy URL missing scheme/host")
509
+ return nil
510
+ }
511
+
512
+ if proxyURL.Scheme == "socks5" {
513
+ var proxyAuth *proxy.Auth
514
+ if proxyURL.User != nil {
515
+ username := proxyURL.User.Username()
516
+ password, _ := proxyURL.User.Password()
517
+ proxyAuth = &proxy.Auth{User: username, Password: password}
518
+ }
519
+ dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
520
+ if errSOCKS5 != nil {
521
+ log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
522
+ return nil
523
+ }
524
+ return &http.Transport{
525
+ Proxy: nil,
526
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
527
+ return dialer.Dial(network, addr)
528
+ },
529
+ }
530
+ }
531
+
532
+ if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
533
+ return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
534
+ }
535
+
536
+ log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
537
+ return nil
538
+ }
internal/api/handlers/management/auth_files.go ADDED
@@ -0,0 +1,2606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "bytes"
5
+ "context"
6
+ "crypto/rand"
7
+ "crypto/sha256"
8
+ "encoding/base64"
9
+ "encoding/json"
10
+ "errors"
11
+ "fmt"
12
+ "io"
13
+ "net"
14
+ "net/http"
15
+ "net/url"
16
+ "os"
17
+ "path/filepath"
18
+ "sort"
19
+ "strconv"
20
+ "strings"
21
+ "sync"
22
+ "time"
23
+
24
+ "github.com/gin-gonic/gin"
25
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
26
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
27
+ geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
28
+ iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
29
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
30
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
31
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
32
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
33
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
34
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
35
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
36
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
37
+ log "github.com/sirupsen/logrus"
38
+ "github.com/tidwall/gjson"
39
+ "golang.org/x/oauth2"
40
+ "golang.org/x/oauth2/google"
41
+ )
42
+
43
+ var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
44
+
45
+ const (
46
+ anthropicCallbackPort = 54545
47
+ geminiCallbackPort = 8085
48
+ codexCallbackPort = 1455
49
+ geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
50
+ geminiCLIVersion = "v1internal"
51
+ geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
52
+ geminiCLIApiClient = "gl-node/22.17.0"
53
+ geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
54
+ )
55
+
56
+ type callbackForwarder struct {
57
+ provider string
58
+ server *http.Server
59
+ done chan struct{}
60
+ }
61
+
62
+ var (
63
+ callbackForwardersMu sync.Mutex
64
+ callbackForwarders = make(map[int]*callbackForwarder)
65
+ )
66
+
67
+ func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
68
+ if len(meta) == 0 {
69
+ return time.Time{}, false
70
+ }
71
+ for _, key := range lastRefreshKeys {
72
+ if val, ok := meta[key]; ok {
73
+ if ts, ok1 := parseLastRefreshValue(val); ok1 {
74
+ return ts, true
75
+ }
76
+ }
77
+ }
78
+ return time.Time{}, false
79
+ }
80
+
81
+ func parseLastRefreshValue(v any) (time.Time, bool) {
82
+ switch val := v.(type) {
83
+ case string:
84
+ s := strings.TrimSpace(val)
85
+ if s == "" {
86
+ return time.Time{}, false
87
+ }
88
+ layouts := []string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z07:00"}
89
+ for _, layout := range layouts {
90
+ if ts, err := time.Parse(layout, s); err == nil {
91
+ return ts.UTC(), true
92
+ }
93
+ }
94
+ if unix, err := strconv.ParseInt(s, 10, 64); err == nil && unix > 0 {
95
+ return time.Unix(unix, 0).UTC(), true
96
+ }
97
+ case float64:
98
+ if val <= 0 {
99
+ return time.Time{}, false
100
+ }
101
+ return time.Unix(int64(val), 0).UTC(), true
102
+ case int64:
103
+ if val <= 0 {
104
+ return time.Time{}, false
105
+ }
106
+ return time.Unix(val, 0).UTC(), true
107
+ case int:
108
+ if val <= 0 {
109
+ return time.Time{}, false
110
+ }
111
+ return time.Unix(int64(val), 0).UTC(), true
112
+ case json.Number:
113
+ if i, err := val.Int64(); err == nil && i > 0 {
114
+ return time.Unix(i, 0).UTC(), true
115
+ }
116
+ }
117
+ return time.Time{}, false
118
+ }
119
+
120
+ func isWebUIRequest(c *gin.Context) bool {
121
+ raw := strings.TrimSpace(c.Query("is_webui"))
122
+ if raw == "" {
123
+ return false
124
+ }
125
+ switch strings.ToLower(raw) {
126
+ case "1", "true", "yes", "on":
127
+ return true
128
+ default:
129
+ return false
130
+ }
131
+ }
132
+
133
+ func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
134
+ callbackForwardersMu.Lock()
135
+ prev := callbackForwarders[port]
136
+ if prev != nil {
137
+ delete(callbackForwarders, port)
138
+ }
139
+ callbackForwardersMu.Unlock()
140
+
141
+ if prev != nil {
142
+ stopForwarderInstance(port, prev)
143
+ }
144
+
145
+ addr := fmt.Sprintf("127.0.0.1:%d", port)
146
+ ln, err := net.Listen("tcp", addr)
147
+ if err != nil {
148
+ return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
149
+ }
150
+
151
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152
+ target := targetBase
153
+ if raw := r.URL.RawQuery; raw != "" {
154
+ if strings.Contains(target, "?") {
155
+ target = target + "&" + raw
156
+ } else {
157
+ target = target + "?" + raw
158
+ }
159
+ }
160
+ w.Header().Set("Cache-Control", "no-store")
161
+ http.Redirect(w, r, target, http.StatusFound)
162
+ })
163
+
164
+ srv := &http.Server{
165
+ Handler: handler,
166
+ ReadHeaderTimeout: 5 * time.Second,
167
+ WriteTimeout: 5 * time.Second,
168
+ }
169
+ done := make(chan struct{})
170
+
171
+ go func() {
172
+ if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
173
+ log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
174
+ }
175
+ close(done)
176
+ }()
177
+
178
+ forwarder := &callbackForwarder{
179
+ provider: provider,
180
+ server: srv,
181
+ done: done,
182
+ }
183
+
184
+ callbackForwardersMu.Lock()
185
+ callbackForwarders[port] = forwarder
186
+ callbackForwardersMu.Unlock()
187
+
188
+ log.Infof("callback forwarder for %s listening on %s", provider, addr)
189
+
190
+ return forwarder, nil
191
+ }
192
+
193
+ func stopCallbackForwarder(port int) {
194
+ callbackForwardersMu.Lock()
195
+ forwarder := callbackForwarders[port]
196
+ if forwarder != nil {
197
+ delete(callbackForwarders, port)
198
+ }
199
+ callbackForwardersMu.Unlock()
200
+
201
+ stopForwarderInstance(port, forwarder)
202
+ }
203
+
204
+ func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
205
+ if forwarder == nil {
206
+ return
207
+ }
208
+ callbackForwardersMu.Lock()
209
+ if current := callbackForwarders[port]; current == forwarder {
210
+ delete(callbackForwarders, port)
211
+ }
212
+ callbackForwardersMu.Unlock()
213
+
214
+ stopForwarderInstance(port, forwarder)
215
+ }
216
+
217
+ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
218
+ if forwarder == nil || forwarder.server == nil {
219
+ return
220
+ }
221
+
222
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
223
+ defer cancel()
224
+
225
+ if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
226
+ log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
227
+ }
228
+
229
+ select {
230
+ case <-forwarder.done:
231
+ case <-time.After(2 * time.Second):
232
+ }
233
+
234
+ log.Infof("callback forwarder on port %d stopped", port)
235
+ }
236
+
237
+ func sanitizeAntigravityFileName(email string) string {
238
+ if strings.TrimSpace(email) == "" {
239
+ return "antigravity.json"
240
+ }
241
+ replacer := strings.NewReplacer("@", "_", ".", "_")
242
+ return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
243
+ }
244
+
245
+ func (h *Handler) managementCallbackURL(path string) (string, error) {
246
+ if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
247
+ return "", fmt.Errorf("server port is not configured")
248
+ }
249
+ if !strings.HasPrefix(path, "/") {
250
+ path = "/" + path
251
+ }
252
+ scheme := "http"
253
+ if h.cfg.TLS.Enable {
254
+ scheme = "https"
255
+ }
256
+ return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
257
+ }
258
+
259
+ func (h *Handler) ListAuthFiles(c *gin.Context) {
260
+ if h == nil {
261
+ c.JSON(500, gin.H{"error": "handler not initialized"})
262
+ return
263
+ }
264
+ if h.authManager == nil {
265
+ h.listAuthFilesFromDisk(c)
266
+ return
267
+ }
268
+ auths := h.authManager.List()
269
+ files := make([]gin.H, 0, len(auths))
270
+ for _, auth := range auths {
271
+ if entry := h.buildAuthFileEntry(auth); entry != nil {
272
+ files = append(files, entry)
273
+ }
274
+ }
275
+ sort.Slice(files, func(i, j int) bool {
276
+ nameI, _ := files[i]["name"].(string)
277
+ nameJ, _ := files[j]["name"].(string)
278
+ return strings.ToLower(nameI) < strings.ToLower(nameJ)
279
+ })
280
+ c.JSON(200, gin.H{"files": files})
281
+ }
282
+
283
+ // GetAuthFileModels returns the models supported by a specific auth file
284
+ func (h *Handler) GetAuthFileModels(c *gin.Context) {
285
+ name := c.Query("name")
286
+ if name == "" {
287
+ c.JSON(400, gin.H{"error": "name is required"})
288
+ return
289
+ }
290
+
291
+ // Try to find auth ID via authManager
292
+ var authID string
293
+ if h.authManager != nil {
294
+ auths := h.authManager.List()
295
+ for _, auth := range auths {
296
+ if auth.FileName == name || auth.ID == name {
297
+ authID = auth.ID
298
+ break
299
+ }
300
+ }
301
+ }
302
+
303
+ if authID == "" {
304
+ authID = name // fallback to filename as ID
305
+ }
306
+
307
+ // Get models from registry
308
+ reg := registry.GetGlobalRegistry()
309
+ models := reg.GetModelsForClient(authID)
310
+
311
+ result := make([]gin.H, 0, len(models))
312
+ for _, m := range models {
313
+ entry := gin.H{
314
+ "id": m.ID,
315
+ }
316
+ if m.DisplayName != "" {
317
+ entry["display_name"] = m.DisplayName
318
+ }
319
+ if m.Type != "" {
320
+ entry["type"] = m.Type
321
+ }
322
+ if m.OwnedBy != "" {
323
+ entry["owned_by"] = m.OwnedBy
324
+ }
325
+ result = append(result, entry)
326
+ }
327
+
328
+ c.JSON(200, gin.H{"models": result})
329
+ }
330
+
331
+ // List auth files from disk when the auth manager is unavailable.
332
+ func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
333
+ entries, err := os.ReadDir(h.cfg.AuthDir)
334
+ if err != nil {
335
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
336
+ return
337
+ }
338
+ files := make([]gin.H, 0)
339
+ for _, e := range entries {
340
+ if e.IsDir() {
341
+ continue
342
+ }
343
+ name := e.Name()
344
+ if !strings.HasSuffix(strings.ToLower(name), ".json") {
345
+ continue
346
+ }
347
+ if info, errInfo := e.Info(); errInfo == nil {
348
+ fileData := gin.H{"name": name, "size": info.Size(), "modtime": info.ModTime()}
349
+
350
+ // Read file to get type field
351
+ full := filepath.Join(h.cfg.AuthDir, name)
352
+ if data, errRead := os.ReadFile(full); errRead == nil {
353
+ typeValue := gjson.GetBytes(data, "type").String()
354
+ emailValue := gjson.GetBytes(data, "email").String()
355
+ fileData["type"] = typeValue
356
+ fileData["email"] = emailValue
357
+ }
358
+
359
+ files = append(files, fileData)
360
+ }
361
+ }
362
+ c.JSON(200, gin.H{"files": files})
363
+ }
364
+
365
+ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
366
+ if auth == nil {
367
+ return nil
368
+ }
369
+ auth.EnsureIndex()
370
+ runtimeOnly := isRuntimeOnlyAuth(auth)
371
+ if runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled) {
372
+ return nil
373
+ }
374
+ path := strings.TrimSpace(authAttribute(auth, "path"))
375
+ if path == "" && !runtimeOnly {
376
+ return nil
377
+ }
378
+ name := strings.TrimSpace(auth.FileName)
379
+ if name == "" {
380
+ name = auth.ID
381
+ }
382
+ entry := gin.H{
383
+ "id": auth.ID,
384
+ "auth_index": auth.Index,
385
+ "name": name,
386
+ "type": strings.TrimSpace(auth.Provider),
387
+ "provider": strings.TrimSpace(auth.Provider),
388
+ "label": auth.Label,
389
+ "status": auth.Status,
390
+ "status_message": auth.StatusMessage,
391
+ "disabled": auth.Disabled,
392
+ "unavailable": auth.Unavailable,
393
+ "runtime_only": runtimeOnly,
394
+ "source": "memory",
395
+ "size": int64(0),
396
+ }
397
+ if email := authEmail(auth); email != "" {
398
+ entry["email"] = email
399
+ }
400
+ if accountType, account := auth.AccountInfo(); accountType != "" || account != "" {
401
+ if accountType != "" {
402
+ entry["account_type"] = accountType
403
+ }
404
+ if account != "" {
405
+ entry["account"] = account
406
+ }
407
+ }
408
+ if !auth.CreatedAt.IsZero() {
409
+ entry["created_at"] = auth.CreatedAt
410
+ }
411
+ if !auth.UpdatedAt.IsZero() {
412
+ entry["modtime"] = auth.UpdatedAt
413
+ entry["updated_at"] = auth.UpdatedAt
414
+ }
415
+ if !auth.LastRefreshedAt.IsZero() {
416
+ entry["last_refresh"] = auth.LastRefreshedAt
417
+ }
418
+ if path != "" {
419
+ entry["path"] = path
420
+ entry["source"] = "file"
421
+ if info, err := os.Stat(path); err == nil {
422
+ entry["size"] = info.Size()
423
+ entry["modtime"] = info.ModTime()
424
+ } else if os.IsNotExist(err) {
425
+ // Hide credentials removed from disk but still lingering in memory.
426
+ if !runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled || strings.EqualFold(strings.TrimSpace(auth.StatusMessage), "removed via management api")) {
427
+ return nil
428
+ }
429
+ entry["source"] = "memory"
430
+ } else {
431
+ log.WithError(err).Warnf("failed to stat auth file %s", path)
432
+ }
433
+ }
434
+ if claims := extractCodexIDTokenClaims(auth); claims != nil {
435
+ entry["id_token"] = claims
436
+ }
437
+ return entry
438
+ }
439
+
440
+ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
441
+ if auth == nil || auth.Metadata == nil {
442
+ return nil
443
+ }
444
+ if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
445
+ return nil
446
+ }
447
+ idTokenRaw, ok := auth.Metadata["id_token"].(string)
448
+ if !ok {
449
+ return nil
450
+ }
451
+ idToken := strings.TrimSpace(idTokenRaw)
452
+ if idToken == "" {
453
+ return nil
454
+ }
455
+ claims, err := codex.ParseJWTToken(idToken)
456
+ if err != nil || claims == nil {
457
+ return nil
458
+ }
459
+
460
+ result := gin.H{}
461
+ if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
462
+ result["chatgpt_account_id"] = v
463
+ }
464
+ if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
465
+ result["plan_type"] = v
466
+ }
467
+
468
+ if len(result) == 0 {
469
+ return nil
470
+ }
471
+ return result
472
+ }
473
+
474
+ func authEmail(auth *coreauth.Auth) string {
475
+ if auth == nil {
476
+ return ""
477
+ }
478
+ if auth.Metadata != nil {
479
+ if v, ok := auth.Metadata["email"].(string); ok {
480
+ return strings.TrimSpace(v)
481
+ }
482
+ }
483
+ if auth.Attributes != nil {
484
+ if v := strings.TrimSpace(auth.Attributes["email"]); v != "" {
485
+ return v
486
+ }
487
+ if v := strings.TrimSpace(auth.Attributes["account_email"]); v != "" {
488
+ return v
489
+ }
490
+ }
491
+ return ""
492
+ }
493
+
494
+ func authAttribute(auth *coreauth.Auth, key string) string {
495
+ if auth == nil || len(auth.Attributes) == 0 {
496
+ return ""
497
+ }
498
+ return auth.Attributes[key]
499
+ }
500
+
501
+ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
502
+ if auth == nil || len(auth.Attributes) == 0 {
503
+ return false
504
+ }
505
+ return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
506
+ }
507
+
508
+ // Download single auth file by name
509
+ func (h *Handler) DownloadAuthFile(c *gin.Context) {
510
+ name := c.Query("name")
511
+ if name == "" || strings.Contains(name, string(os.PathSeparator)) {
512
+ c.JSON(400, gin.H{"error": "invalid name"})
513
+ return
514
+ }
515
+ if !strings.HasSuffix(strings.ToLower(name), ".json") {
516
+ c.JSON(400, gin.H{"error": "name must end with .json"})
517
+ return
518
+ }
519
+ full := filepath.Join(h.cfg.AuthDir, name)
520
+ data, err := os.ReadFile(full)
521
+ if err != nil {
522
+ if os.IsNotExist(err) {
523
+ c.JSON(404, gin.H{"error": "file not found"})
524
+ } else {
525
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
526
+ }
527
+ return
528
+ }
529
+ c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", name))
530
+ c.Data(200, "application/json", data)
531
+ }
532
+
533
+ // Upload auth file: multipart or raw JSON with ?name=
534
+ func (h *Handler) UploadAuthFile(c *gin.Context) {
535
+ if h.authManager == nil {
536
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
537
+ return
538
+ }
539
+ ctx := c.Request.Context()
540
+ if file, err := c.FormFile("file"); err == nil && file != nil {
541
+ name := filepath.Base(file.Filename)
542
+ if !strings.HasSuffix(strings.ToLower(name), ".json") {
543
+ c.JSON(400, gin.H{"error": "file must be .json"})
544
+ return
545
+ }
546
+ dst := filepath.Join(h.cfg.AuthDir, name)
547
+ if !filepath.IsAbs(dst) {
548
+ if abs, errAbs := filepath.Abs(dst); errAbs == nil {
549
+ dst = abs
550
+ }
551
+ }
552
+ if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
553
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
554
+ return
555
+ }
556
+ data, errRead := os.ReadFile(dst)
557
+ if errRead != nil {
558
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
559
+ return
560
+ }
561
+ if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
562
+ c.JSON(500, gin.H{"error": errReg.Error()})
563
+ return
564
+ }
565
+ c.JSON(200, gin.H{"status": "ok"})
566
+ return
567
+ }
568
+ name := c.Query("name")
569
+ if name == "" || strings.Contains(name, string(os.PathSeparator)) {
570
+ c.JSON(400, gin.H{"error": "invalid name"})
571
+ return
572
+ }
573
+ if !strings.HasSuffix(strings.ToLower(name), ".json") {
574
+ c.JSON(400, gin.H{"error": "name must end with .json"})
575
+ return
576
+ }
577
+ data, err := io.ReadAll(c.Request.Body)
578
+ if err != nil {
579
+ c.JSON(400, gin.H{"error": "failed to read body"})
580
+ return
581
+ }
582
+ dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
583
+ if !filepath.IsAbs(dst) {
584
+ if abs, errAbs := filepath.Abs(dst); errAbs == nil {
585
+ dst = abs
586
+ }
587
+ }
588
+ if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
589
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
590
+ return
591
+ }
592
+ if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
593
+ c.JSON(500, gin.H{"error": err.Error()})
594
+ return
595
+ }
596
+ c.JSON(200, gin.H{"status": "ok"})
597
+ }
598
+
599
+ // Delete auth files: single by name or all
600
+ func (h *Handler) DeleteAuthFile(c *gin.Context) {
601
+ if h.authManager == nil {
602
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
603
+ return
604
+ }
605
+ ctx := c.Request.Context()
606
+ if all := c.Query("all"); all == "true" || all == "1" || all == "*" {
607
+ entries, err := os.ReadDir(h.cfg.AuthDir)
608
+ if err != nil {
609
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
610
+ return
611
+ }
612
+ deleted := 0
613
+ for _, e := range entries {
614
+ if e.IsDir() {
615
+ continue
616
+ }
617
+ name := e.Name()
618
+ if !strings.HasSuffix(strings.ToLower(name), ".json") {
619
+ continue
620
+ }
621
+ full := filepath.Join(h.cfg.AuthDir, name)
622
+ if !filepath.IsAbs(full) {
623
+ if abs, errAbs := filepath.Abs(full); errAbs == nil {
624
+ full = abs
625
+ }
626
+ }
627
+ if err = os.Remove(full); err == nil {
628
+ if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
629
+ c.JSON(500, gin.H{"error": errDel.Error()})
630
+ return
631
+ }
632
+ deleted++
633
+ h.disableAuth(ctx, full)
634
+ }
635
+ }
636
+ c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
637
+ return
638
+ }
639
+ name := c.Query("name")
640
+ if name == "" || strings.Contains(name, string(os.PathSeparator)) {
641
+ c.JSON(400, gin.H{"error": "invalid name"})
642
+ return
643
+ }
644
+ full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
645
+ if !filepath.IsAbs(full) {
646
+ if abs, errAbs := filepath.Abs(full); errAbs == nil {
647
+ full = abs
648
+ }
649
+ }
650
+ if err := os.Remove(full); err != nil {
651
+ if os.IsNotExist(err) {
652
+ c.JSON(404, gin.H{"error": "file not found"})
653
+ } else {
654
+ c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
655
+ }
656
+ return
657
+ }
658
+ if err := h.deleteTokenRecord(ctx, full); err != nil {
659
+ c.JSON(500, gin.H{"error": err.Error()})
660
+ return
661
+ }
662
+ h.disableAuth(ctx, full)
663
+ c.JSON(200, gin.H{"status": "ok"})
664
+ }
665
+
666
+ func (h *Handler) authIDForPath(path string) string {
667
+ path = strings.TrimSpace(path)
668
+ if path == "" {
669
+ return ""
670
+ }
671
+ if h == nil || h.cfg == nil {
672
+ return path
673
+ }
674
+ authDir := strings.TrimSpace(h.cfg.AuthDir)
675
+ if authDir == "" {
676
+ return path
677
+ }
678
+ if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
679
+ return rel
680
+ }
681
+ return path
682
+ }
683
+
684
+ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
685
+ if h.authManager == nil {
686
+ return nil
687
+ }
688
+ if path == "" {
689
+ return fmt.Errorf("auth path is empty")
690
+ }
691
+ if data == nil {
692
+ var err error
693
+ data, err = os.ReadFile(path)
694
+ if err != nil {
695
+ return fmt.Errorf("failed to read auth file: %w", err)
696
+ }
697
+ }
698
+ metadata := make(map[string]any)
699
+ if err := json.Unmarshal(data, &metadata); err != nil {
700
+ return fmt.Errorf("invalid auth file: %w", err)
701
+ }
702
+ provider, _ := metadata["type"].(string)
703
+ if provider == "" {
704
+ provider = "unknown"
705
+ }
706
+ label := provider
707
+ if email, ok := metadata["email"].(string); ok && email != "" {
708
+ label = email
709
+ }
710
+ lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
711
+
712
+ authID := h.authIDForPath(path)
713
+ if authID == "" {
714
+ authID = path
715
+ }
716
+ attr := map[string]string{
717
+ "path": path,
718
+ "source": path,
719
+ }
720
+ auth := &coreauth.Auth{
721
+ ID: authID,
722
+ Provider: provider,
723
+ FileName: filepath.Base(path),
724
+ Label: label,
725
+ Status: coreauth.StatusActive,
726
+ Attributes: attr,
727
+ Metadata: metadata,
728
+ CreatedAt: time.Now(),
729
+ UpdatedAt: time.Now(),
730
+ }
731
+ if hasLastRefresh {
732
+ auth.LastRefreshedAt = lastRefresh
733
+ }
734
+ if existing, ok := h.authManager.GetByID(authID); ok {
735
+ auth.CreatedAt = existing.CreatedAt
736
+ if !hasLastRefresh {
737
+ auth.LastRefreshedAt = existing.LastRefreshedAt
738
+ }
739
+ auth.NextRefreshAfter = existing.NextRefreshAfter
740
+ auth.Runtime = existing.Runtime
741
+ _, err := h.authManager.Update(ctx, auth)
742
+ return err
743
+ }
744
+ _, err := h.authManager.Register(ctx, auth)
745
+ return err
746
+ }
747
+
748
+ func (h *Handler) disableAuth(ctx context.Context, id string) {
749
+ if h == nil || h.authManager == nil {
750
+ return
751
+ }
752
+ authID := h.authIDForPath(id)
753
+ if authID == "" {
754
+ authID = strings.TrimSpace(id)
755
+ }
756
+ if authID == "" {
757
+ return
758
+ }
759
+ if auth, ok := h.authManager.GetByID(authID); ok {
760
+ auth.Disabled = true
761
+ auth.Status = coreauth.StatusDisabled
762
+ auth.StatusMessage = "removed via management API"
763
+ auth.UpdatedAt = time.Now()
764
+ _, _ = h.authManager.Update(ctx, auth)
765
+ }
766
+ }
767
+
768
+ func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
769
+ if strings.TrimSpace(path) == "" {
770
+ return fmt.Errorf("auth path is empty")
771
+ }
772
+ store := h.tokenStoreWithBaseDir()
773
+ if store == nil {
774
+ return fmt.Errorf("token store unavailable")
775
+ }
776
+ return store.Delete(ctx, path)
777
+ }
778
+
779
+ func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
780
+ if h == nil {
781
+ return nil
782
+ }
783
+ store := h.tokenStore
784
+ if store == nil {
785
+ store = sdkAuth.GetTokenStore()
786
+ h.tokenStore = store
787
+ }
788
+ if h.cfg != nil {
789
+ if dirSetter, ok := store.(interface{ SetBaseDir(string) }); ok {
790
+ dirSetter.SetBaseDir(h.cfg.AuthDir)
791
+ }
792
+ }
793
+ return store
794
+ }
795
+
796
+ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
797
+ if record == nil {
798
+ return "", fmt.Errorf("token record is nil")
799
+ }
800
+ store := h.tokenStoreWithBaseDir()
801
+ if store == nil {
802
+ return "", fmt.Errorf("token store unavailable")
803
+ }
804
+ return store.Save(ctx, record)
805
+ }
806
+
807
+ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
808
+ ctx := context.Background()
809
+
810
+ fmt.Println("Initializing Claude authentication...")
811
+
812
+ // Generate PKCE codes
813
+ pkceCodes, err := claude.GeneratePKCECodes()
814
+ if err != nil {
815
+ log.Errorf("Failed to generate PKCE codes: %v", err)
816
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
817
+ return
818
+ }
819
+
820
+ // Generate random state parameter
821
+ state, err := misc.GenerateRandomState()
822
+ if err != nil {
823
+ log.Errorf("Failed to generate state parameter: %v", err)
824
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
825
+ return
826
+ }
827
+
828
+ // Initialize Claude auth service
829
+ anthropicAuth := claude.NewClaudeAuth(h.cfg)
830
+
831
+ // Generate authorization URL (then override redirect_uri to reuse server port)
832
+ authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
833
+ if err != nil {
834
+ log.Errorf("Failed to generate authorization URL: %v", err)
835
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
836
+ return
837
+ }
838
+
839
+ RegisterOAuthSession(state, "anthropic")
840
+
841
+ isWebUI := isWebUIRequest(c)
842
+ var forwarder *callbackForwarder
843
+ if isWebUI {
844
+ targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
845
+ if errTarget != nil {
846
+ log.WithError(errTarget).Error("failed to compute anthropic callback target")
847
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
848
+ return
849
+ }
850
+ var errStart error
851
+ if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
852
+ log.WithError(errStart).Error("failed to start anthropic callback forwarder")
853
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
854
+ return
855
+ }
856
+ }
857
+
858
+ go func() {
859
+ if isWebUI {
860
+ defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
861
+ }
862
+
863
+ // Helper: wait for callback file
864
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
865
+ waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
866
+ deadline := time.Now().Add(timeout)
867
+ for {
868
+ if !IsOAuthSessionPending(state, "anthropic") {
869
+ return nil, errOAuthSessionNotPending
870
+ }
871
+ if time.Now().After(deadline) {
872
+ SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
873
+ return nil, fmt.Errorf("timeout waiting for OAuth callback")
874
+ }
875
+ data, errRead := os.ReadFile(path)
876
+ if errRead == nil {
877
+ var m map[string]string
878
+ _ = json.Unmarshal(data, &m)
879
+ _ = os.Remove(path)
880
+ return m, nil
881
+ }
882
+ time.Sleep(500 * time.Millisecond)
883
+ }
884
+ }
885
+
886
+ fmt.Println("Waiting for authentication callback...")
887
+ // Wait up to 5 minutes
888
+ resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
889
+ if errWait != nil {
890
+ if errors.Is(errWait, errOAuthSessionNotPending) {
891
+ return
892
+ }
893
+ authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
894
+ log.Error(claude.GetUserFriendlyMessage(authErr))
895
+ return
896
+ }
897
+ if errStr := resultMap["error"]; errStr != "" {
898
+ oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
899
+ log.Error(claude.GetUserFriendlyMessage(oauthErr))
900
+ SetOAuthSessionError(state, "Bad request")
901
+ return
902
+ }
903
+ if resultMap["state"] != state {
904
+ authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
905
+ log.Error(claude.GetUserFriendlyMessage(authErr))
906
+ SetOAuthSessionError(state, "State code error")
907
+ return
908
+ }
909
+
910
+ // Parse code (Claude may append state after '#')
911
+ rawCode := resultMap["code"]
912
+ code := strings.Split(rawCode, "#")[0]
913
+
914
+ // Exchange code for tokens (replicate logic using updated redirect_uri)
915
+ // Extract client_id from the modified auth URL
916
+ clientID := ""
917
+ if u2, errP := url.Parse(authURL); errP == nil {
918
+ clientID = u2.Query().Get("client_id")
919
+ }
920
+ // Build request
921
+ bodyMap := map[string]any{
922
+ "code": code,
923
+ "state": state,
924
+ "grant_type": "authorization_code",
925
+ "client_id": clientID,
926
+ "redirect_uri": "http://localhost:54545/callback",
927
+ "code_verifier": pkceCodes.CodeVerifier,
928
+ }
929
+ bodyJSON, _ := json.Marshal(bodyMap)
930
+
931
+ httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
932
+ req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
933
+ req.Header.Set("Content-Type", "application/json")
934
+ req.Header.Set("Accept", "application/json")
935
+ resp, errDo := httpClient.Do(req)
936
+ if errDo != nil {
937
+ authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
938
+ log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
939
+ SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
940
+ return
941
+ }
942
+ defer func() {
943
+ if errClose := resp.Body.Close(); errClose != nil {
944
+ log.Errorf("failed to close response body: %v", errClose)
945
+ }
946
+ }()
947
+ respBody, _ := io.ReadAll(resp.Body)
948
+ if resp.StatusCode != http.StatusOK {
949
+ log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
950
+ SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
951
+ return
952
+ }
953
+ var tResp struct {
954
+ AccessToken string `json:"access_token"`
955
+ RefreshToken string `json:"refresh_token"`
956
+ ExpiresIn int `json:"expires_in"`
957
+ Account struct {
958
+ EmailAddress string `json:"email_address"`
959
+ } `json:"account"`
960
+ }
961
+ if errU := json.Unmarshal(respBody, &tResp); errU != nil {
962
+ log.Errorf("failed to parse token response: %v", errU)
963
+ SetOAuthSessionError(state, "Failed to parse token response")
964
+ return
965
+ }
966
+ bundle := &claude.ClaudeAuthBundle{
967
+ TokenData: claude.ClaudeTokenData{
968
+ AccessToken: tResp.AccessToken,
969
+ RefreshToken: tResp.RefreshToken,
970
+ Email: tResp.Account.EmailAddress,
971
+ Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
972
+ },
973
+ LastRefresh: time.Now().Format(time.RFC3339),
974
+ }
975
+
976
+ // Create token storage
977
+ tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
978
+ record := &coreauth.Auth{
979
+ ID: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
980
+ Provider: "claude",
981
+ FileName: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
982
+ Storage: tokenStorage,
983
+ Metadata: map[string]any{"email": tokenStorage.Email},
984
+ }
985
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
986
+ if errSave != nil {
987
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
988
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
989
+ return
990
+ }
991
+
992
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
993
+ if bundle.APIKey != "" {
994
+ fmt.Println("API key obtained and saved")
995
+ }
996
+ fmt.Println("You can now use Claude services through this CLI")
997
+ CompleteOAuthSession(state)
998
+ CompleteOAuthSessionsByProvider("anthropic")
999
+ }()
1000
+
1001
+ c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
1002
+ }
1003
+
1004
+ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
1005
+ ctx := context.Background()
1006
+ proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
1007
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
1008
+
1009
+ // Optional project ID from query
1010
+ projectID := c.Query("project_id")
1011
+
1012
+ fmt.Println("Initializing Google authentication...")
1013
+
1014
+ // OAuth2 configuration (mirrors internal/auth/gemini)
1015
+ conf := &oauth2.Config{
1016
+ ClientID: "YOUR_CLIENT_ID",
1017
+ ClientSecret: "YOUR_CLIENT_SECRET",
1018
+ RedirectURL: "http://localhost:8085/oauth2callback",
1019
+ Scopes: []string{
1020
+ "https://www.googleapis.com/auth/cloud-platform",
1021
+ "https://www.googleapis.com/auth/userinfo.email",
1022
+ "https://www.googleapis.com/auth/userinfo.profile",
1023
+ },
1024
+ Endpoint: google.Endpoint,
1025
+ }
1026
+
1027
+ // Build authorization URL and return it immediately
1028
+ state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
1029
+ authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
1030
+
1031
+ RegisterOAuthSession(state, "gemini")
1032
+
1033
+ isWebUI := isWebUIRequest(c)
1034
+ var forwarder *callbackForwarder
1035
+ if isWebUI {
1036
+ targetURL, errTarget := h.managementCallbackURL("/google/callback")
1037
+ if errTarget != nil {
1038
+ log.WithError(errTarget).Error("failed to compute gemini callback target")
1039
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
1040
+ return
1041
+ }
1042
+ var errStart error
1043
+ if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
1044
+ log.WithError(errStart).Error("failed to start gemini callback forwarder")
1045
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
1046
+ return
1047
+ }
1048
+ }
1049
+
1050
+ go func() {
1051
+ if isWebUI {
1052
+ defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
1053
+ }
1054
+
1055
+ // Wait for callback file written by server route
1056
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
1057
+ fmt.Println("Waiting for authentication callback...")
1058
+ deadline := time.Now().Add(5 * time.Minute)
1059
+ var authCode string
1060
+ for {
1061
+ if !IsOAuthSessionPending(state, "gemini") {
1062
+ return
1063
+ }
1064
+ if time.Now().After(deadline) {
1065
+ log.Error("oauth flow timed out")
1066
+ SetOAuthSessionError(state, "OAuth flow timed out")
1067
+ return
1068
+ }
1069
+ if data, errR := os.ReadFile(waitFile); errR == nil {
1070
+ var m map[string]string
1071
+ _ = json.Unmarshal(data, &m)
1072
+ _ = os.Remove(waitFile)
1073
+ if errStr := m["error"]; errStr != "" {
1074
+ log.Errorf("Authentication failed: %s", errStr)
1075
+ SetOAuthSessionError(state, "Authentication failed")
1076
+ return
1077
+ }
1078
+ authCode = m["code"]
1079
+ if authCode == "" {
1080
+ log.Errorf("Authentication failed: code not found")
1081
+ SetOAuthSessionError(state, "Authentication failed: code not found")
1082
+ return
1083
+ }
1084
+ break
1085
+ }
1086
+ time.Sleep(500 * time.Millisecond)
1087
+ }
1088
+
1089
+ // Exchange authorization code for token
1090
+ token, err := conf.Exchange(ctx, authCode)
1091
+ if err != nil {
1092
+ log.Errorf("Failed to exchange token: %v", err)
1093
+ SetOAuthSessionError(state, "Failed to exchange token")
1094
+ return
1095
+ }
1096
+
1097
+ requestedProjectID := strings.TrimSpace(projectID)
1098
+
1099
+ // Create token storage (mirrors internal/auth/gemini createTokenStorage)
1100
+ authHTTPClient := conf.Client(ctx, token)
1101
+ req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
1102
+ if errNewRequest != nil {
1103
+ log.Errorf("Could not get user info: %v", errNewRequest)
1104
+ SetOAuthSessionError(state, "Could not get user info")
1105
+ return
1106
+ }
1107
+ req.Header.Set("Content-Type", "application/json")
1108
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
1109
+
1110
+ resp, errDo := authHTTPClient.Do(req)
1111
+ if errDo != nil {
1112
+ log.Errorf("Failed to execute request: %v", errDo)
1113
+ SetOAuthSessionError(state, "Failed to execute request")
1114
+ return
1115
+ }
1116
+ defer func() {
1117
+ if errClose := resp.Body.Close(); errClose != nil {
1118
+ log.Printf("warn: failed to close response body: %v", errClose)
1119
+ }
1120
+ }()
1121
+
1122
+ bodyBytes, _ := io.ReadAll(resp.Body)
1123
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1124
+ log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
1125
+ SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
1126
+ return
1127
+ }
1128
+
1129
+ email := gjson.GetBytes(bodyBytes, "email").String()
1130
+ if email != "" {
1131
+ fmt.Printf("Authenticated user email: %s\n", email)
1132
+ } else {
1133
+ fmt.Println("Failed to get user email from token")
1134
+ }
1135
+
1136
+ // Marshal/unmarshal oauth2.Token to generic map and enrich fields
1137
+ var ifToken map[string]any
1138
+ jsonData, _ := json.Marshal(token)
1139
+ if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
1140
+ log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
1141
+ SetOAuthSessionError(state, "Failed to unmarshal token")
1142
+ return
1143
+ }
1144
+
1145
+ ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
1146
+ ifToken["client_id"] = "YOUR_CLIENT_ID"
1147
+ ifToken["client_secret"] = "YOUR_CLIENT_SECRET"
1148
+ ifToken["scopes"] = []string{
1149
+ "https://www.googleapis.com/auth/cloud-platform",
1150
+ "https://www.googleapis.com/auth/userinfo.email",
1151
+ "https://www.googleapis.com/auth/userinfo.profile",
1152
+ }
1153
+ ifToken["universe_domain"] = "googleapis.com"
1154
+
1155
+ ts := geminiAuth.GeminiTokenStorage{
1156
+ Token: ifToken,
1157
+ ProjectID: requestedProjectID,
1158
+ Email: email,
1159
+ Auto: requestedProjectID == "",
1160
+ }
1161
+
1162
+ // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
1163
+ gemAuth := geminiAuth.NewGeminiAuth()
1164
+ gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
1165
+ NoBrowser: true,
1166
+ })
1167
+ if errGetClient != nil {
1168
+ log.Errorf("failed to get authenticated client: %v", errGetClient)
1169
+ SetOAuthSessionError(state, "Failed to get authenticated client")
1170
+ return
1171
+ }
1172
+ fmt.Println("Authentication successful.")
1173
+
1174
+ if strings.EqualFold(requestedProjectID, "ALL") {
1175
+ ts.Auto = false
1176
+ projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
1177
+ if errAll != nil {
1178
+ log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
1179
+ SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
1180
+ return
1181
+ }
1182
+ if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
1183
+ log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
1184
+ SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
1185
+ return
1186
+ }
1187
+ ts.ProjectID = strings.Join(projects, ",")
1188
+ ts.Checked = true
1189
+ } else {
1190
+ if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
1191
+ log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
1192
+ SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
1193
+ return
1194
+ }
1195
+
1196
+ if strings.TrimSpace(ts.ProjectID) == "" {
1197
+ log.Error("Onboarding did not return a project ID")
1198
+ SetOAuthSessionError(state, "Failed to resolve project ID")
1199
+ return
1200
+ }
1201
+
1202
+ isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
1203
+ if errCheck != nil {
1204
+ log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
1205
+ SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
1206
+ return
1207
+ }
1208
+ ts.Checked = isChecked
1209
+ if !isChecked {
1210
+ log.Error("Cloud AI API is not enabled for the selected project")
1211
+ SetOAuthSessionError(state, "Cloud AI API not enabled")
1212
+ return
1213
+ }
1214
+ }
1215
+
1216
+ recordMetadata := map[string]any{
1217
+ "email": ts.Email,
1218
+ "project_id": ts.ProjectID,
1219
+ "auto": ts.Auto,
1220
+ "checked": ts.Checked,
1221
+ }
1222
+
1223
+ fileName := geminiAuth.CredentialFileName(ts.Email, ts.ProjectID, true)
1224
+ record := &coreauth.Auth{
1225
+ ID: fileName,
1226
+ Provider: "gemini",
1227
+ FileName: fileName,
1228
+ Storage: &ts,
1229
+ Metadata: recordMetadata,
1230
+ }
1231
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1232
+ if errSave != nil {
1233
+ log.Errorf("Failed to save token to file: %v", errSave)
1234
+ SetOAuthSessionError(state, "Failed to save token to file")
1235
+ return
1236
+ }
1237
+
1238
+ CompleteOAuthSession(state)
1239
+ CompleteOAuthSessionsByProvider("gemini")
1240
+ fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
1241
+ }()
1242
+
1243
+ c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
1244
+ }
1245
+
1246
+ func (h *Handler) RequestCodexToken(c *gin.Context) {
1247
+ ctx := context.Background()
1248
+
1249
+ fmt.Println("Initializing Codex authentication...")
1250
+
1251
+ // Generate PKCE codes
1252
+ pkceCodes, err := codex.GeneratePKCECodes()
1253
+ if err != nil {
1254
+ log.Errorf("Failed to generate PKCE codes: %v", err)
1255
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
1256
+ return
1257
+ }
1258
+
1259
+ // Generate random state parameter
1260
+ state, err := misc.GenerateRandomState()
1261
+ if err != nil {
1262
+ log.Errorf("Failed to generate state parameter: %v", err)
1263
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
1264
+ return
1265
+ }
1266
+
1267
+ // Initialize Codex auth service
1268
+ openaiAuth := codex.NewCodexAuth(h.cfg)
1269
+
1270
+ // Generate authorization URL
1271
+ authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
1272
+ if err != nil {
1273
+ log.Errorf("Failed to generate authorization URL: %v", err)
1274
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
1275
+ return
1276
+ }
1277
+
1278
+ RegisterOAuthSession(state, "codex")
1279
+
1280
+ isWebUI := isWebUIRequest(c)
1281
+ var forwarder *callbackForwarder
1282
+ if isWebUI {
1283
+ targetURL, errTarget := h.managementCallbackURL("/codex/callback")
1284
+ if errTarget != nil {
1285
+ log.WithError(errTarget).Error("failed to compute codex callback target")
1286
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
1287
+ return
1288
+ }
1289
+ var errStart error
1290
+ if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
1291
+ log.WithError(errStart).Error("failed to start codex callback forwarder")
1292
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
1293
+ return
1294
+ }
1295
+ }
1296
+
1297
+ go func() {
1298
+ if isWebUI {
1299
+ defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
1300
+ }
1301
+
1302
+ // Wait for callback file
1303
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
1304
+ deadline := time.Now().Add(5 * time.Minute)
1305
+ var code string
1306
+ for {
1307
+ if !IsOAuthSessionPending(state, "codex") {
1308
+ return
1309
+ }
1310
+ if time.Now().After(deadline) {
1311
+ authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
1312
+ log.Error(codex.GetUserFriendlyMessage(authErr))
1313
+ SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
1314
+ return
1315
+ }
1316
+ if data, errR := os.ReadFile(waitFile); errR == nil {
1317
+ var m map[string]string
1318
+ _ = json.Unmarshal(data, &m)
1319
+ _ = os.Remove(waitFile)
1320
+ if errStr := m["error"]; errStr != "" {
1321
+ oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
1322
+ log.Error(codex.GetUserFriendlyMessage(oauthErr))
1323
+ SetOAuthSessionError(state, "Bad Request")
1324
+ return
1325
+ }
1326
+ if m["state"] != state {
1327
+ authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
1328
+ SetOAuthSessionError(state, "State code error")
1329
+ log.Error(codex.GetUserFriendlyMessage(authErr))
1330
+ return
1331
+ }
1332
+ code = m["code"]
1333
+ break
1334
+ }
1335
+ time.Sleep(500 * time.Millisecond)
1336
+ }
1337
+
1338
+ log.Debug("Authorization code received, exchanging for tokens...")
1339
+ // Extract client_id from authURL
1340
+ clientID := ""
1341
+ if u2, errP := url.Parse(authURL); errP == nil {
1342
+ clientID = u2.Query().Get("client_id")
1343
+ }
1344
+ // Exchange code for tokens with redirect equal to mgmtRedirect
1345
+ form := url.Values{
1346
+ "grant_type": {"authorization_code"},
1347
+ "client_id": {clientID},
1348
+ "code": {code},
1349
+ "redirect_uri": {"http://localhost:1455/auth/callback"},
1350
+ "code_verifier": {pkceCodes.CodeVerifier},
1351
+ }
1352
+ httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
1353
+ req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
1354
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1355
+ req.Header.Set("Accept", "application/json")
1356
+ resp, errDo := httpClient.Do(req)
1357
+ if errDo != nil {
1358
+ authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
1359
+ SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
1360
+ log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
1361
+ return
1362
+ }
1363
+ defer func() { _ = resp.Body.Close() }()
1364
+ respBody, _ := io.ReadAll(resp.Body)
1365
+ if resp.StatusCode != http.StatusOK {
1366
+ SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
1367
+ log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
1368
+ return
1369
+ }
1370
+ var tokenResp struct {
1371
+ AccessToken string `json:"access_token"`
1372
+ RefreshToken string `json:"refresh_token"`
1373
+ IDToken string `json:"id_token"`
1374
+ ExpiresIn int `json:"expires_in"`
1375
+ }
1376
+ if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
1377
+ SetOAuthSessionError(state, "Failed to parse token response")
1378
+ log.Errorf("failed to parse token response: %v", errU)
1379
+ return
1380
+ }
1381
+ claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
1382
+ email := ""
1383
+ accountID := ""
1384
+ if claims != nil {
1385
+ email = claims.GetUserEmail()
1386
+ accountID = claims.GetAccountID()
1387
+ }
1388
+ // Build bundle compatible with existing storage
1389
+ bundle := &codex.CodexAuthBundle{
1390
+ TokenData: codex.CodexTokenData{
1391
+ IDToken: tokenResp.IDToken,
1392
+ AccessToken: tokenResp.AccessToken,
1393
+ RefreshToken: tokenResp.RefreshToken,
1394
+ AccountID: accountID,
1395
+ Email: email,
1396
+ Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
1397
+ },
1398
+ LastRefresh: time.Now().Format(time.RFC3339),
1399
+ }
1400
+
1401
+ // Create token storage and persist
1402
+ tokenStorage := openaiAuth.CreateTokenStorage(bundle)
1403
+ record := &coreauth.Auth{
1404
+ ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
1405
+ Provider: "codex",
1406
+ FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
1407
+ Storage: tokenStorage,
1408
+ Metadata: map[string]any{
1409
+ "email": tokenStorage.Email,
1410
+ "account_id": tokenStorage.AccountID,
1411
+ },
1412
+ }
1413
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1414
+ if errSave != nil {
1415
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
1416
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
1417
+ return
1418
+ }
1419
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
1420
+ if bundle.APIKey != "" {
1421
+ fmt.Println("API key obtained and saved")
1422
+ }
1423
+ fmt.Println("You can now use Codex services through this CLI")
1424
+ CompleteOAuthSession(state)
1425
+ CompleteOAuthSessionsByProvider("codex")
1426
+ }()
1427
+
1428
+ c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
1429
+ }
1430
+
1431
+ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
1432
+ const (
1433
+ antigravityCallbackPort = 51121
1434
+ antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
1435
+ antigravityClientSecret = "YOUR_ANTIGRAVITY_CLIENT_SECRET"
1436
+ )
1437
+ var antigravityScopes = []string{
1438
+ "https://www.googleapis.com/auth/cloud-platform",
1439
+ "https://www.googleapis.com/auth/userinfo.email",
1440
+ "https://www.googleapis.com/auth/userinfo.profile",
1441
+ "https://www.googleapis.com/auth/cclog",
1442
+ "https://www.googleapis.com/auth/experimentsandconfigs",
1443
+ }
1444
+
1445
+ ctx := context.Background()
1446
+
1447
+ fmt.Println("Initializing Antigravity authentication...")
1448
+
1449
+ state, errState := misc.GenerateRandomState()
1450
+ if errState != nil {
1451
+ log.Errorf("Failed to generate state parameter: %v", errState)
1452
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
1453
+ return
1454
+ }
1455
+
1456
+ redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
1457
+
1458
+ params := url.Values{}
1459
+ params.Set("access_type", "offline")
1460
+ params.Set("client_id", antigravityClientID)
1461
+ params.Set("prompt", "consent")
1462
+ params.Set("redirect_uri", redirectURI)
1463
+ params.Set("response_type", "code")
1464
+ params.Set("scope", strings.Join(antigravityScopes, " "))
1465
+ params.Set("state", state)
1466
+ authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
1467
+
1468
+ RegisterOAuthSession(state, "antigravity")
1469
+
1470
+ isWebUI := isWebUIRequest(c)
1471
+ var forwarder *callbackForwarder
1472
+ if isWebUI {
1473
+ targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
1474
+ if errTarget != nil {
1475
+ log.WithError(errTarget).Error("failed to compute antigravity callback target")
1476
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
1477
+ return
1478
+ }
1479
+ var errStart error
1480
+ if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
1481
+ log.WithError(errStart).Error("failed to start antigravity callback forwarder")
1482
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
1483
+ return
1484
+ }
1485
+ }
1486
+
1487
+ go func() {
1488
+ if isWebUI {
1489
+ defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
1490
+ }
1491
+
1492
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
1493
+ deadline := time.Now().Add(5 * time.Minute)
1494
+ var authCode string
1495
+ for {
1496
+ if !IsOAuthSessionPending(state, "antigravity") {
1497
+ return
1498
+ }
1499
+ if time.Now().After(deadline) {
1500
+ log.Error("oauth flow timed out")
1501
+ SetOAuthSessionError(state, "OAuth flow timed out")
1502
+ return
1503
+ }
1504
+ if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
1505
+ var payload map[string]string
1506
+ _ = json.Unmarshal(data, &payload)
1507
+ _ = os.Remove(waitFile)
1508
+ if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
1509
+ log.Errorf("Authentication failed: %s", errStr)
1510
+ SetOAuthSessionError(state, "Authentication failed")
1511
+ return
1512
+ }
1513
+ if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
1514
+ log.Errorf("Authentication failed: state mismatch")
1515
+ SetOAuthSessionError(state, "Authentication failed: state mismatch")
1516
+ return
1517
+ }
1518
+ authCode = strings.TrimSpace(payload["code"])
1519
+ if authCode == "" {
1520
+ log.Error("Authentication failed: code not found")
1521
+ SetOAuthSessionError(state, "Authentication failed: code not found")
1522
+ return
1523
+ }
1524
+ break
1525
+ }
1526
+ time.Sleep(500 * time.Millisecond)
1527
+ }
1528
+
1529
+ httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
1530
+ form := url.Values{}
1531
+ form.Set("code", authCode)
1532
+ form.Set("client_id", antigravityClientID)
1533
+ form.Set("client_secret", antigravityClientSecret)
1534
+ form.Set("redirect_uri", redirectURI)
1535
+ form.Set("grant_type", "authorization_code")
1536
+
1537
+ req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
1538
+ if errNewRequest != nil {
1539
+ log.Errorf("Failed to build token request: %v", errNewRequest)
1540
+ SetOAuthSessionError(state, "Failed to build token request")
1541
+ return
1542
+ }
1543
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1544
+
1545
+ resp, errDo := httpClient.Do(req)
1546
+ if errDo != nil {
1547
+ log.Errorf("Failed to execute token request: %v", errDo)
1548
+ SetOAuthSessionError(state, "Failed to exchange token")
1549
+ return
1550
+ }
1551
+ defer func() {
1552
+ if errClose := resp.Body.Close(); errClose != nil {
1553
+ log.Errorf("antigravity token exchange close error: %v", errClose)
1554
+ }
1555
+ }()
1556
+
1557
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
1558
+ bodyBytes, _ := io.ReadAll(resp.Body)
1559
+ log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
1560
+ SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
1561
+ return
1562
+ }
1563
+
1564
+ var tokenResp struct {
1565
+ AccessToken string `json:"access_token"`
1566
+ RefreshToken string `json:"refresh_token"`
1567
+ ExpiresIn int64 `json:"expires_in"`
1568
+ TokenType string `json:"token_type"`
1569
+ }
1570
+ if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
1571
+ log.Errorf("Failed to parse token response: %v", errDecode)
1572
+ SetOAuthSessionError(state, "Failed to parse token response")
1573
+ return
1574
+ }
1575
+
1576
+ email := ""
1577
+ if strings.TrimSpace(tokenResp.AccessToken) != "" {
1578
+ infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
1579
+ if errInfoReq != nil {
1580
+ log.Errorf("Failed to build user info request: %v", errInfoReq)
1581
+ SetOAuthSessionError(state, "Failed to build user info request")
1582
+ return
1583
+ }
1584
+ infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
1585
+
1586
+ infoResp, errInfo := httpClient.Do(infoReq)
1587
+ if errInfo != nil {
1588
+ log.Errorf("Failed to execute user info request: %v", errInfo)
1589
+ SetOAuthSessionError(state, "Failed to execute user info request")
1590
+ return
1591
+ }
1592
+ defer func() {
1593
+ if errClose := infoResp.Body.Close(); errClose != nil {
1594
+ log.Errorf("antigravity user info close error: %v", errClose)
1595
+ }
1596
+ }()
1597
+
1598
+ if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
1599
+ var infoPayload struct {
1600
+ Email string `json:"email"`
1601
+ }
1602
+ if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
1603
+ email = strings.TrimSpace(infoPayload.Email)
1604
+ }
1605
+ } else {
1606
+ bodyBytes, _ := io.ReadAll(infoResp.Body)
1607
+ log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
1608
+ SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
1609
+ return
1610
+ }
1611
+ }
1612
+
1613
+ projectID := ""
1614
+ if strings.TrimSpace(tokenResp.AccessToken) != "" {
1615
+ fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
1616
+ if errProject != nil {
1617
+ log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
1618
+ } else {
1619
+ projectID = fetchedProjectID
1620
+ log.Infof("antigravity: obtained project ID %s", projectID)
1621
+ }
1622
+ }
1623
+
1624
+ now := time.Now()
1625
+ metadata := map[string]any{
1626
+ "type": "antigravity",
1627
+ "access_token": tokenResp.AccessToken,
1628
+ "refresh_token": tokenResp.RefreshToken,
1629
+ "expires_in": tokenResp.ExpiresIn,
1630
+ "timestamp": now.UnixMilli(),
1631
+ "expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
1632
+ }
1633
+ if email != "" {
1634
+ metadata["email"] = email
1635
+ }
1636
+ if projectID != "" {
1637
+ metadata["project_id"] = projectID
1638
+ }
1639
+
1640
+ fileName := sanitizeAntigravityFileName(email)
1641
+ label := strings.TrimSpace(email)
1642
+ if label == "" {
1643
+ label = "antigravity"
1644
+ }
1645
+
1646
+ record := &coreauth.Auth{
1647
+ ID: fileName,
1648
+ Provider: "antigravity",
1649
+ FileName: fileName,
1650
+ Label: label,
1651
+ Metadata: metadata,
1652
+ }
1653
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1654
+ if errSave != nil {
1655
+ log.Errorf("Failed to save token to file: %v", errSave)
1656
+ SetOAuthSessionError(state, "Failed to save token to file")
1657
+ return
1658
+ }
1659
+
1660
+ CompleteOAuthSession(state)
1661
+ CompleteOAuthSessionsByProvider("antigravity")
1662
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
1663
+ if projectID != "" {
1664
+ fmt.Printf("Using GCP project: %s\n", projectID)
1665
+ }
1666
+ fmt.Println("You can now use Antigravity services through this CLI")
1667
+ }()
1668
+
1669
+ c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
1670
+ }
1671
+
1672
+ func (h *Handler) RequestQwenToken(c *gin.Context) {
1673
+ ctx := context.Background()
1674
+
1675
+ fmt.Println("Initializing Qwen authentication...")
1676
+
1677
+ state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
1678
+ // Initialize Qwen auth service
1679
+ qwenAuth := qwen.NewQwenAuth(h.cfg)
1680
+
1681
+ // Generate authorization URL
1682
+ deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
1683
+ if err != nil {
1684
+ log.Errorf("Failed to generate authorization URL: %v", err)
1685
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
1686
+ return
1687
+ }
1688
+ authURL := deviceFlow.VerificationURIComplete
1689
+
1690
+ RegisterOAuthSession(state, "qwen")
1691
+
1692
+ go func() {
1693
+ fmt.Println("Waiting for authentication...")
1694
+ tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
1695
+ if errPollForToken != nil {
1696
+ SetOAuthSessionError(state, "Authentication failed")
1697
+ fmt.Printf("Authentication failed: %v\n", errPollForToken)
1698
+ return
1699
+ }
1700
+
1701
+ // Create token storage
1702
+ tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
1703
+
1704
+ tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
1705
+ record := &coreauth.Auth{
1706
+ ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
1707
+ Provider: "qwen",
1708
+ FileName: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
1709
+ Storage: tokenStorage,
1710
+ Metadata: map[string]any{"email": tokenStorage.Email},
1711
+ }
1712
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1713
+ if errSave != nil {
1714
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
1715
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
1716
+ return
1717
+ }
1718
+
1719
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
1720
+ fmt.Println("You can now use Qwen services through this CLI")
1721
+ CompleteOAuthSession(state)
1722
+ }()
1723
+
1724
+ c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
1725
+ }
1726
+
1727
+ func (h *Handler) RequestIFlowToken(c *gin.Context) {
1728
+ ctx := context.Background()
1729
+
1730
+ fmt.Println("Initializing iFlow authentication...")
1731
+
1732
+ state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
1733
+ authSvc := iflowauth.NewIFlowAuth(h.cfg)
1734
+ authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
1735
+
1736
+ RegisterOAuthSession(state, "iflow")
1737
+
1738
+ isWebUI := isWebUIRequest(c)
1739
+ var forwarder *callbackForwarder
1740
+ if isWebUI {
1741
+ targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
1742
+ if errTarget != nil {
1743
+ log.WithError(errTarget).Error("failed to compute iflow callback target")
1744
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
1745
+ return
1746
+ }
1747
+ var errStart error
1748
+ if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
1749
+ log.WithError(errStart).Error("failed to start iflow callback forwarder")
1750
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
1751
+ return
1752
+ }
1753
+ }
1754
+
1755
+ go func() {
1756
+ if isWebUI {
1757
+ defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
1758
+ }
1759
+ fmt.Println("Waiting for authentication...")
1760
+
1761
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
1762
+ deadline := time.Now().Add(5 * time.Minute)
1763
+ var resultMap map[string]string
1764
+ for {
1765
+ if !IsOAuthSessionPending(state, "iflow") {
1766
+ return
1767
+ }
1768
+ if time.Now().After(deadline) {
1769
+ SetOAuthSessionError(state, "Authentication failed")
1770
+ fmt.Println("Authentication failed: timeout waiting for callback")
1771
+ return
1772
+ }
1773
+ if data, errR := os.ReadFile(waitFile); errR == nil {
1774
+ _ = os.Remove(waitFile)
1775
+ _ = json.Unmarshal(data, &resultMap)
1776
+ break
1777
+ }
1778
+ time.Sleep(500 * time.Millisecond)
1779
+ }
1780
+
1781
+ if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
1782
+ SetOAuthSessionError(state, "Authentication failed")
1783
+ fmt.Printf("Authentication failed: %s\n", errStr)
1784
+ return
1785
+ }
1786
+ if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
1787
+ SetOAuthSessionError(state, "Authentication failed")
1788
+ fmt.Println("Authentication failed: state mismatch")
1789
+ return
1790
+ }
1791
+
1792
+ code := strings.TrimSpace(resultMap["code"])
1793
+ if code == "" {
1794
+ SetOAuthSessionError(state, "Authentication failed")
1795
+ fmt.Println("Authentication failed: code missing")
1796
+ return
1797
+ }
1798
+
1799
+ tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
1800
+ if errExchange != nil {
1801
+ SetOAuthSessionError(state, "Authentication failed")
1802
+ fmt.Printf("Authentication failed: %v\n", errExchange)
1803
+ return
1804
+ }
1805
+
1806
+ tokenStorage := authSvc.CreateTokenStorage(tokenData)
1807
+ identifier := strings.TrimSpace(tokenStorage.Email)
1808
+ if identifier == "" {
1809
+ identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
1810
+ tokenStorage.Email = identifier
1811
+ }
1812
+ record := &coreauth.Auth{
1813
+ ID: fmt.Sprintf("iflow-%s.json", identifier),
1814
+ Provider: "iflow",
1815
+ FileName: fmt.Sprintf("iflow-%s.json", identifier),
1816
+ Storage: tokenStorage,
1817
+ Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey},
1818
+ Attributes: map[string]string{"api_key": tokenStorage.APIKey},
1819
+ }
1820
+
1821
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1822
+ if errSave != nil {
1823
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
1824
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
1825
+ return
1826
+ }
1827
+
1828
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
1829
+ if tokenStorage.APIKey != "" {
1830
+ fmt.Println("API key obtained and saved")
1831
+ }
1832
+ fmt.Println("You can now use iFlow services through this CLI")
1833
+ CompleteOAuthSession(state)
1834
+ CompleteOAuthSessionsByProvider("iflow")
1835
+ }()
1836
+
1837
+ c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
1838
+ }
1839
+
1840
+ func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
1841
+ ctx := context.Background()
1842
+
1843
+ var payload struct {
1844
+ Cookie string `json:"cookie"`
1845
+ }
1846
+ if err := c.ShouldBindJSON(&payload); err != nil {
1847
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
1848
+ return
1849
+ }
1850
+
1851
+ cookieValue := strings.TrimSpace(payload.Cookie)
1852
+
1853
+ if cookieValue == "" {
1854
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
1855
+ return
1856
+ }
1857
+
1858
+ cookieValue, errNormalize := iflowauth.NormalizeCookie(cookieValue)
1859
+ if errNormalize != nil {
1860
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errNormalize.Error()})
1861
+ return
1862
+ }
1863
+
1864
+ // Check for duplicate BXAuth before authentication
1865
+ bxAuth := iflowauth.ExtractBXAuth(cookieValue)
1866
+ if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
1867
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
1868
+ return
1869
+ } else if existingFile != "" {
1870
+ existingFileName := filepath.Base(existingFile)
1871
+ c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
1872
+ return
1873
+ }
1874
+
1875
+ authSvc := iflowauth.NewIFlowAuth(h.cfg)
1876
+ tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
1877
+ if errAuth != nil {
1878
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errAuth.Error()})
1879
+ return
1880
+ }
1881
+
1882
+ tokenData.Cookie = cookieValue
1883
+
1884
+ tokenStorage := authSvc.CreateCookieTokenStorage(tokenData)
1885
+ email := strings.TrimSpace(tokenStorage.Email)
1886
+ if email == "" {
1887
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "failed to extract email from token"})
1888
+ return
1889
+ }
1890
+
1891
+ fileName := iflowauth.SanitizeIFlowFileName(email)
1892
+ if fileName == "" {
1893
+ fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
1894
+ }
1895
+
1896
+ tokenStorage.Email = email
1897
+ timestamp := time.Now().Unix()
1898
+
1899
+ record := &coreauth.Auth{
1900
+ ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
1901
+ Provider: "iflow",
1902
+ FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
1903
+ Storage: tokenStorage,
1904
+ Metadata: map[string]any{
1905
+ "email": email,
1906
+ "api_key": tokenStorage.APIKey,
1907
+ "expired": tokenStorage.Expire,
1908
+ "cookie": tokenStorage.Cookie,
1909
+ "type": tokenStorage.Type,
1910
+ "last_refresh": tokenStorage.LastRefresh,
1911
+ },
1912
+ Attributes: map[string]string{
1913
+ "api_key": tokenStorage.APIKey,
1914
+ },
1915
+ }
1916
+
1917
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
1918
+ if errSave != nil {
1919
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
1920
+ return
1921
+ }
1922
+
1923
+ fmt.Printf("iFlow cookie authentication successful. Token saved to %s\n", savedPath)
1924
+ c.JSON(http.StatusOK, gin.H{
1925
+ "status": "ok",
1926
+ "saved_path": savedPath,
1927
+ "email": email,
1928
+ "expired": tokenStorage.Expire,
1929
+ "type": tokenStorage.Type,
1930
+ })
1931
+ }
1932
+
1933
+ type projectSelectionRequiredError struct{}
1934
+
1935
+ func (e *projectSelectionRequiredError) Error() string {
1936
+ return "gemini cli: project selection required"
1937
+ }
1938
+
1939
+ func ensureGeminiProjectAndOnboard(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
1940
+ if storage == nil {
1941
+ return fmt.Errorf("gemini storage is nil")
1942
+ }
1943
+
1944
+ trimmedRequest := strings.TrimSpace(requestedProject)
1945
+ if trimmedRequest == "" {
1946
+ projects, errProjects := fetchGCPProjects(ctx, httpClient)
1947
+ if errProjects != nil {
1948
+ return fmt.Errorf("fetch project list: %w", errProjects)
1949
+ }
1950
+ if len(projects) == 0 {
1951
+ return fmt.Errorf("no Google Cloud projects available for this account")
1952
+ }
1953
+ trimmedRequest = strings.TrimSpace(projects[0].ProjectID)
1954
+ if trimmedRequest == "" {
1955
+ return fmt.Errorf("resolved project id is empty")
1956
+ }
1957
+ storage.Auto = true
1958
+ } else {
1959
+ storage.Auto = false
1960
+ }
1961
+
1962
+ if err := performGeminiCLISetup(ctx, httpClient, storage, trimmedRequest); err != nil {
1963
+ return err
1964
+ }
1965
+
1966
+ if strings.TrimSpace(storage.ProjectID) == "" {
1967
+ storage.ProjectID = trimmedRequest
1968
+ }
1969
+
1970
+ return nil
1971
+ }
1972
+
1973
+ func onboardAllGeminiProjects(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage) ([]string, error) {
1974
+ projects, errProjects := fetchGCPProjects(ctx, httpClient)
1975
+ if errProjects != nil {
1976
+ return nil, fmt.Errorf("fetch project list: %w", errProjects)
1977
+ }
1978
+ if len(projects) == 0 {
1979
+ return nil, fmt.Errorf("no Google Cloud projects available for this account")
1980
+ }
1981
+ activated := make([]string, 0, len(projects))
1982
+ seen := make(map[string]struct{}, len(projects))
1983
+ for _, project := range projects {
1984
+ candidate := strings.TrimSpace(project.ProjectID)
1985
+ if candidate == "" {
1986
+ continue
1987
+ }
1988
+ if _, dup := seen[candidate]; dup {
1989
+ continue
1990
+ }
1991
+ if err := performGeminiCLISetup(ctx, httpClient, storage, candidate); err != nil {
1992
+ return nil, fmt.Errorf("onboard project %s: %w", candidate, err)
1993
+ }
1994
+ finalID := strings.TrimSpace(storage.ProjectID)
1995
+ if finalID == "" {
1996
+ finalID = candidate
1997
+ }
1998
+ activated = append(activated, finalID)
1999
+ seen[candidate] = struct{}{}
2000
+ }
2001
+ if len(activated) == 0 {
2002
+ return nil, fmt.Errorf("no Google Cloud projects available for this account")
2003
+ }
2004
+ return activated, nil
2005
+ }
2006
+
2007
+ func ensureGeminiProjectsEnabled(ctx context.Context, httpClient *http.Client, projectIDs []string) error {
2008
+ for _, pid := range projectIDs {
2009
+ trimmed := strings.TrimSpace(pid)
2010
+ if trimmed == "" {
2011
+ continue
2012
+ }
2013
+ isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, trimmed)
2014
+ if errCheck != nil {
2015
+ return fmt.Errorf("project %s: %w", trimmed, errCheck)
2016
+ }
2017
+ if !isChecked {
2018
+ return fmt.Errorf("project %s: Cloud AI API not enabled", trimmed)
2019
+ }
2020
+ }
2021
+ return nil
2022
+ }
2023
+
2024
+ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
2025
+ metadata := map[string]string{
2026
+ "ideType": "IDE_UNSPECIFIED",
2027
+ "platform": "PLATFORM_UNSPECIFIED",
2028
+ "pluginType": "GEMINI",
2029
+ }
2030
+
2031
+ trimmedRequest := strings.TrimSpace(requestedProject)
2032
+ explicitProject := trimmedRequest != ""
2033
+
2034
+ loadReqBody := map[string]any{
2035
+ "metadata": metadata,
2036
+ }
2037
+ if explicitProject {
2038
+ loadReqBody["cloudaicompanionProject"] = trimmedRequest
2039
+ }
2040
+
2041
+ var loadResp map[string]any
2042
+ if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil {
2043
+ return fmt.Errorf("load code assist: %w", errLoad)
2044
+ }
2045
+
2046
+ tierID := "legacy-tier"
2047
+ if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
2048
+ for _, rawTier := range tiers {
2049
+ tier, okTier := rawTier.(map[string]any)
2050
+ if !okTier {
2051
+ continue
2052
+ }
2053
+ if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
2054
+ if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
2055
+ tierID = strings.TrimSpace(id)
2056
+ break
2057
+ }
2058
+ }
2059
+ }
2060
+ }
2061
+
2062
+ projectID := trimmedRequest
2063
+ if projectID == "" {
2064
+ if id, okProject := loadResp["cloudaicompanionProject"].(string); okProject {
2065
+ projectID = strings.TrimSpace(id)
2066
+ }
2067
+ if projectID == "" {
2068
+ if projectMap, okProject := loadResp["cloudaicompanionProject"].(map[string]any); okProject {
2069
+ if id, okID := projectMap["id"].(string); okID {
2070
+ projectID = strings.TrimSpace(id)
2071
+ }
2072
+ }
2073
+ }
2074
+ }
2075
+ if projectID == "" {
2076
+ return &projectSelectionRequiredError{}
2077
+ }
2078
+
2079
+ onboardReqBody := map[string]any{
2080
+ "tierId": tierID,
2081
+ "metadata": metadata,
2082
+ "cloudaicompanionProject": projectID,
2083
+ }
2084
+
2085
+ storage.ProjectID = projectID
2086
+
2087
+ for {
2088
+ var onboardResp map[string]any
2089
+ if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil {
2090
+ return fmt.Errorf("onboard user: %w", errOnboard)
2091
+ }
2092
+
2093
+ if done, okDone := onboardResp["done"].(bool); okDone && done {
2094
+ responseProjectID := ""
2095
+ if resp, okResp := onboardResp["response"].(map[string]any); okResp {
2096
+ switch projectValue := resp["cloudaicompanionProject"].(type) {
2097
+ case map[string]any:
2098
+ if id, okID := projectValue["id"].(string); okID {
2099
+ responseProjectID = strings.TrimSpace(id)
2100
+ }
2101
+ case string:
2102
+ responseProjectID = strings.TrimSpace(projectValue)
2103
+ }
2104
+ }
2105
+
2106
+ finalProjectID := projectID
2107
+ if responseProjectID != "" {
2108
+ if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
2109
+ log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
2110
+ } else {
2111
+ finalProjectID = responseProjectID
2112
+ }
2113
+ }
2114
+
2115
+ storage.ProjectID = strings.TrimSpace(finalProjectID)
2116
+ if storage.ProjectID == "" {
2117
+ storage.ProjectID = strings.TrimSpace(projectID)
2118
+ }
2119
+ if storage.ProjectID == "" {
2120
+ return fmt.Errorf("onboard user completed without project id")
2121
+ }
2122
+ log.Infof("Onboarding complete. Using Project ID: %s", storage.ProjectID)
2123
+ return nil
2124
+ }
2125
+
2126
+ log.Println("Onboarding in progress, waiting 5 seconds...")
2127
+ time.Sleep(5 * time.Second)
2128
+ }
2129
+ }
2130
+
2131
+ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error {
2132
+ endPointURL := fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint)
2133
+ if strings.HasPrefix(endpoint, "operations/") {
2134
+ endPointURL = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint)
2135
+ }
2136
+
2137
+ var reader io.Reader
2138
+ if body != nil {
2139
+ rawBody, errMarshal := json.Marshal(body)
2140
+ if errMarshal != nil {
2141
+ return fmt.Errorf("marshal request body: %w", errMarshal)
2142
+ }
2143
+ reader = bytes.NewReader(rawBody)
2144
+ }
2145
+
2146
+ req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, endPointURL, reader)
2147
+ if errRequest != nil {
2148
+ return fmt.Errorf("create request: %w", errRequest)
2149
+ }
2150
+ req.Header.Set("Content-Type", "application/json")
2151
+ req.Header.Set("User-Agent", geminiCLIUserAgent)
2152
+ req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
2153
+ req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
2154
+
2155
+ resp, errDo := httpClient.Do(req)
2156
+ if errDo != nil {
2157
+ return fmt.Errorf("execute request: %w", errDo)
2158
+ }
2159
+ defer func() {
2160
+ if errClose := resp.Body.Close(); errClose != nil {
2161
+ log.Errorf("response body close error: %v", errClose)
2162
+ }
2163
+ }()
2164
+
2165
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
2166
+ bodyBytes, _ := io.ReadAll(resp.Body)
2167
+ return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
2168
+ }
2169
+
2170
+ if result == nil {
2171
+ _, _ = io.Copy(io.Discard, resp.Body)
2172
+ return nil
2173
+ }
2174
+
2175
+ if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil {
2176
+ return fmt.Errorf("decode response body: %w", errDecode)
2177
+ }
2178
+
2179
+ return nil
2180
+ }
2181
+
2182
+ func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) {
2183
+ req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
2184
+ if errRequest != nil {
2185
+ return nil, fmt.Errorf("could not create project list request: %w", errRequest)
2186
+ }
2187
+
2188
+ resp, errDo := httpClient.Do(req)
2189
+ if errDo != nil {
2190
+ return nil, fmt.Errorf("failed to execute project list request: %w", errDo)
2191
+ }
2192
+ defer func() {
2193
+ if errClose := resp.Body.Close(); errClose != nil {
2194
+ log.Errorf("response body close error: %v", errClose)
2195
+ }
2196
+ }()
2197
+
2198
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
2199
+ bodyBytes, _ := io.ReadAll(resp.Body)
2200
+ return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
2201
+ }
2202
+
2203
+ var projects interfaces.GCPProject
2204
+ if errDecode := json.NewDecoder(resp.Body).Decode(&projects); errDecode != nil {
2205
+ return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode)
2206
+ }
2207
+
2208
+ return projects.Projects, nil
2209
+ }
2210
+
2211
+ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) {
2212
+ serviceUsageURL := "https://serviceusage.googleapis.com"
2213
+ requiredServices := []string{
2214
+ "cloudaicompanion.googleapis.com",
2215
+ }
2216
+ for _, service := range requiredServices {
2217
+ checkURL := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service)
2218
+ req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, checkURL, nil)
2219
+ if errRequest != nil {
2220
+ return false, fmt.Errorf("failed to create request: %w", errRequest)
2221
+ }
2222
+ req.Header.Set("Content-Type", "application/json")
2223
+ req.Header.Set("User-Agent", geminiCLIUserAgent)
2224
+ resp, errDo := httpClient.Do(req)
2225
+ if errDo != nil {
2226
+ return false, fmt.Errorf("failed to execute request: %w", errDo)
2227
+ }
2228
+
2229
+ if resp.StatusCode == http.StatusOK {
2230
+ bodyBytes, _ := io.ReadAll(resp.Body)
2231
+ if gjson.GetBytes(bodyBytes, "state").String() == "ENABLED" {
2232
+ _ = resp.Body.Close()
2233
+ continue
2234
+ }
2235
+ }
2236
+ _ = resp.Body.Close()
2237
+
2238
+ enableURL := fmt.Sprintf("%s/v1/projects/%s/services/%s:enable", serviceUsageURL, projectID, service)
2239
+ req, errRequest = http.NewRequestWithContext(ctx, http.MethodPost, enableURL, strings.NewReader("{}"))
2240
+ if errRequest != nil {
2241
+ return false, fmt.Errorf("failed to create request: %w", errRequest)
2242
+ }
2243
+ req.Header.Set("Content-Type", "application/json")
2244
+ req.Header.Set("User-Agent", geminiCLIUserAgent)
2245
+ resp, errDo = httpClient.Do(req)
2246
+ if errDo != nil {
2247
+ return false, fmt.Errorf("failed to execute request: %w", errDo)
2248
+ }
2249
+
2250
+ bodyBytes, _ := io.ReadAll(resp.Body)
2251
+ errMessage := string(bodyBytes)
2252
+ errMessageResult := gjson.GetBytes(bodyBytes, "error.message")
2253
+ if errMessageResult.Exists() {
2254
+ errMessage = errMessageResult.String()
2255
+ }
2256
+ if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
2257
+ _ = resp.Body.Close()
2258
+ continue
2259
+ } else if resp.StatusCode == http.StatusBadRequest {
2260
+ _ = resp.Body.Close()
2261
+ if strings.Contains(strings.ToLower(errMessage), "already enabled") {
2262
+ continue
2263
+ }
2264
+ }
2265
+ _ = resp.Body.Close()
2266
+ return false, fmt.Errorf("project activation required: %s", errMessage)
2267
+ }
2268
+ return true, nil
2269
+ }
2270
+
2271
+ func (h *Handler) GetAuthStatus(c *gin.Context) {
2272
+ state := strings.TrimSpace(c.Query("state"))
2273
+ if state == "" {
2274
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
2275
+ return
2276
+ }
2277
+ if err := ValidateOAuthState(state); err != nil {
2278
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
2279
+ return
2280
+ }
2281
+
2282
+ _, status, ok := GetOAuthSession(state)
2283
+ if !ok {
2284
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
2285
+ return
2286
+ }
2287
+ if status != "" {
2288
+ if strings.HasPrefix(status, "device_code|") {
2289
+ parts := strings.SplitN(status, "|", 3)
2290
+ if len(parts) == 3 {
2291
+ c.JSON(http.StatusOK, gin.H{
2292
+ "status": "device_code",
2293
+ "verification_url": parts[1],
2294
+ "user_code": parts[2],
2295
+ })
2296
+ return
2297
+ }
2298
+ }
2299
+ if strings.HasPrefix(status, "auth_url|") {
2300
+ authURL := strings.TrimPrefix(status, "auth_url|")
2301
+ c.JSON(http.StatusOK, gin.H{
2302
+ "status": "auth_url",
2303
+ "url": authURL,
2304
+ })
2305
+ return
2306
+ }
2307
+ c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
2308
+ return
2309
+ }
2310
+ c.JSON(http.StatusOK, gin.H{"status": "wait"})
2311
+ }
2312
+
2313
+ const kiroCallbackPort = 9876
2314
+
2315
+ func (h *Handler) RequestKiroToken(c *gin.Context) {
2316
+ ctx := context.Background()
2317
+
2318
+ // Get the login method from query parameter (default: aws for device code flow)
2319
+ method := strings.ToLower(strings.TrimSpace(c.Query("method")))
2320
+ if method == "" {
2321
+ method = "aws"
2322
+ }
2323
+
2324
+ fmt.Println("Initializing Kiro authentication...")
2325
+
2326
+ state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
2327
+
2328
+ switch method {
2329
+ case "aws", "builder-id":
2330
+ RegisterOAuthSession(state, "kiro")
2331
+
2332
+ // AWS Builder ID uses device code flow (no callback needed)
2333
+ go func() {
2334
+ ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
2335
+
2336
+ // Step 1: Register client
2337
+ fmt.Println("Registering client...")
2338
+ regResp, errRegister := ssoClient.RegisterClient(ctx)
2339
+ if errRegister != nil {
2340
+ log.Errorf("Failed to register client: %v", errRegister)
2341
+ SetOAuthSessionError(state, "Failed to register client")
2342
+ return
2343
+ }
2344
+
2345
+ // Step 2: Start device authorization
2346
+ fmt.Println("Starting device authorization...")
2347
+ authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
2348
+ if errAuth != nil {
2349
+ log.Errorf("Failed to start device auth: %v", errAuth)
2350
+ SetOAuthSessionError(state, "Failed to start device authorization")
2351
+ return
2352
+ }
2353
+
2354
+ // Store the verification URL for the frontend to display.
2355
+ // Using "|" as separator because URLs contain ":".
2356
+ SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
2357
+
2358
+ // Step 3: Poll for token
2359
+ fmt.Println("Waiting for authorization...")
2360
+ interval := 5 * time.Second
2361
+ if authResp.Interval > 0 {
2362
+ interval = time.Duration(authResp.Interval) * time.Second
2363
+ }
2364
+ deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
2365
+
2366
+ for time.Now().Before(deadline) {
2367
+ select {
2368
+ case <-ctx.Done():
2369
+ SetOAuthSessionError(state, "Authorization cancelled")
2370
+ return
2371
+ case <-time.After(interval):
2372
+ tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
2373
+ if errToken != nil {
2374
+ errStr := errToken.Error()
2375
+ if strings.Contains(errStr, "authorization_pending") {
2376
+ continue
2377
+ }
2378
+ if strings.Contains(errStr, "slow_down") {
2379
+ interval += 5 * time.Second
2380
+ continue
2381
+ }
2382
+ log.Errorf("Token creation failed: %v", errToken)
2383
+ SetOAuthSessionError(state, "Token creation failed")
2384
+ return
2385
+ }
2386
+
2387
+ // Success! Save the token
2388
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
2389
+ email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
2390
+
2391
+ idPart := kiroauth.SanitizeEmailForFilename(email)
2392
+ if idPart == "" {
2393
+ idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
2394
+ }
2395
+
2396
+ now := time.Now()
2397
+ fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
2398
+
2399
+ record := &coreauth.Auth{
2400
+ ID: fileName,
2401
+ Provider: "kiro",
2402
+ FileName: fileName,
2403
+ Metadata: map[string]any{
2404
+ "type": "kiro",
2405
+ "access_token": tokenResp.AccessToken,
2406
+ "refresh_token": tokenResp.RefreshToken,
2407
+ "expires_at": expiresAt.Format(time.RFC3339),
2408
+ "auth_method": "builder-id",
2409
+ "provider": "AWS",
2410
+ "client_id": regResp.ClientID,
2411
+ "client_secret": regResp.ClientSecret,
2412
+ "email": email,
2413
+ "last_refresh": now.Format(time.RFC3339),
2414
+ },
2415
+ }
2416
+
2417
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
2418
+ if errSave != nil {
2419
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
2420
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
2421
+ return
2422
+ }
2423
+
2424
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
2425
+ if email != "" {
2426
+ fmt.Printf("Authenticated as: %s\n", email)
2427
+ }
2428
+ CompleteOAuthSession(state)
2429
+ return
2430
+ }
2431
+ }
2432
+
2433
+ SetOAuthSessionError(state, "Authorization timed out")
2434
+ }()
2435
+
2436
+ // Return immediately with the state for polling
2437
+ c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"})
2438
+
2439
+ case "google", "github":
2440
+ RegisterOAuthSession(state, "kiro")
2441
+
2442
+ // Social auth uses protocol handler - for WEB UI we use a callback forwarder
2443
+ provider := "Google"
2444
+ if method == "github" {
2445
+ provider = "Github"
2446
+ }
2447
+
2448
+ isWebUI := isWebUIRequest(c)
2449
+ if isWebUI {
2450
+ targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
2451
+ if errTarget != nil {
2452
+ log.WithError(errTarget).Error("failed to compute kiro callback target")
2453
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
2454
+ return
2455
+ }
2456
+ if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
2457
+ log.WithError(errStart).Error("failed to start kiro callback forwarder")
2458
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
2459
+ return
2460
+ }
2461
+ }
2462
+
2463
+ go func() {
2464
+ if isWebUI {
2465
+ defer stopCallbackForwarder(kiroCallbackPort)
2466
+ }
2467
+
2468
+ socialClient := kiroauth.NewSocialAuthClient(h.cfg)
2469
+
2470
+ // Generate PKCE codes
2471
+ codeVerifier, codeChallenge, errPKCE := generateKiroPKCE()
2472
+ if errPKCE != nil {
2473
+ log.Errorf("Failed to generate PKCE: %v", errPKCE)
2474
+ SetOAuthSessionError(state, "Failed to generate PKCE")
2475
+ return
2476
+ }
2477
+
2478
+ // Build login URL
2479
+ authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
2480
+ "https://prod.us-east-1.auth.desktop.kiro.dev",
2481
+ provider,
2482
+ url.QueryEscape(kiroauth.KiroRedirectURI),
2483
+ codeChallenge,
2484
+ state,
2485
+ )
2486
+
2487
+ // Store auth URL for frontend.
2488
+ // Using "|" as separator because URLs contain ":".
2489
+ SetOAuthSessionError(state, "auth_url|"+authURL)
2490
+
2491
+ // Wait for callback file
2492
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
2493
+ deadline := time.Now().Add(5 * time.Minute)
2494
+
2495
+ for {
2496
+ if time.Now().After(deadline) {
2497
+ log.Error("oauth flow timed out")
2498
+ SetOAuthSessionError(state, "OAuth flow timed out")
2499
+ return
2500
+ }
2501
+ if data, errRead := os.ReadFile(waitFile); errRead == nil {
2502
+ var m map[string]string
2503
+ _ = json.Unmarshal(data, &m)
2504
+ _ = os.Remove(waitFile)
2505
+ if errStr := m["error"]; errStr != "" {
2506
+ log.Errorf("Authentication failed: %s", errStr)
2507
+ SetOAuthSessionError(state, "Authentication failed")
2508
+ return
2509
+ }
2510
+ if m["state"] != state {
2511
+ log.Errorf("State mismatch")
2512
+ SetOAuthSessionError(state, "State mismatch")
2513
+ return
2514
+ }
2515
+ code := m["code"]
2516
+ if code == "" {
2517
+ log.Error("No authorization code received")
2518
+ SetOAuthSessionError(state, "No authorization code received")
2519
+ return
2520
+ }
2521
+
2522
+ // Exchange code for tokens
2523
+ tokenReq := &kiroauth.CreateTokenRequest{
2524
+ Code: code,
2525
+ CodeVerifier: codeVerifier,
2526
+ RedirectURI: kiroauth.KiroRedirectURI,
2527
+ }
2528
+
2529
+ tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
2530
+ if errToken != nil {
2531
+ log.Errorf("Failed to exchange code for tokens: %v", errToken)
2532
+ SetOAuthSessionError(state, "Failed to exchange code for tokens")
2533
+ return
2534
+ }
2535
+
2536
+ // Save the token
2537
+ expiresIn := tokenResp.ExpiresIn
2538
+ if expiresIn <= 0 {
2539
+ expiresIn = 3600
2540
+ }
2541
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
2542
+ email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
2543
+
2544
+ idPart := kiroauth.SanitizeEmailForFilename(email)
2545
+ if idPart == "" {
2546
+ idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
2547
+ }
2548
+
2549
+ now := time.Now()
2550
+ fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
2551
+
2552
+ record := &coreauth.Auth{
2553
+ ID: fileName,
2554
+ Provider: "kiro",
2555
+ FileName: fileName,
2556
+ Metadata: map[string]any{
2557
+ "type": "kiro",
2558
+ "access_token": tokenResp.AccessToken,
2559
+ "refresh_token": tokenResp.RefreshToken,
2560
+ "profile_arn": tokenResp.ProfileArn,
2561
+ "expires_at": expiresAt.Format(time.RFC3339),
2562
+ "auth_method": "social",
2563
+ "provider": provider,
2564
+ "email": email,
2565
+ "last_refresh": now.Format(time.RFC3339),
2566
+ },
2567
+ }
2568
+
2569
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
2570
+ if errSave != nil {
2571
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
2572
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
2573
+ return
2574
+ }
2575
+
2576
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
2577
+ if email != "" {
2578
+ fmt.Printf("Authenticated as: %s\n", email)
2579
+ }
2580
+ CompleteOAuthSession(state)
2581
+ return
2582
+ }
2583
+ time.Sleep(500 * time.Millisecond)
2584
+ }
2585
+ }()
2586
+
2587
+ c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"})
2588
+
2589
+ default:
2590
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
2591
+ }
2592
+ }
2593
+
2594
+ // generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
2595
+ func generateKiroPKCE() (verifier, challenge string, err error) {
2596
+ b := make([]byte, 32)
2597
+ if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil {
2598
+ return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead)
2599
+ }
2600
+ verifier = base64.RawURLEncoding.EncodeToString(b)
2601
+
2602
+ h := sha256.Sum256([]byte(verifier))
2603
+ challenge = base64.RawURLEncoding.EncodeToString(h[:])
2604
+
2605
+ return verifier, challenge, nil
2606
+ }
internal/api/handlers/management/config_basic.go ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "io"
7
+ "net/http"
8
+ "os"
9
+ "path/filepath"
10
+ "strings"
11
+ "time"
12
+
13
+ "github.com/gin-gonic/gin"
14
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
15
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
16
+ sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
17
+ log "github.com/sirupsen/logrus"
18
+ "gopkg.in/yaml.v3"
19
+ )
20
+
21
+ const (
22
+ latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
23
+ latestReleaseUserAgent = "CLIProxyAPIPlus"
24
+ )
25
+
26
+ func (h *Handler) GetConfig(c *gin.Context) {
27
+ if h == nil || h.cfg == nil {
28
+ c.JSON(200, gin.H{})
29
+ return
30
+ }
31
+ cfgCopy := *h.cfg
32
+ c.JSON(200, &cfgCopy)
33
+ }
34
+
35
+ type releaseInfo struct {
36
+ TagName string `json:"tag_name"`
37
+ Name string `json:"name"`
38
+ }
39
+
40
+ // GetLatestVersion returns the latest release version from GitHub without downloading assets.
41
+ func (h *Handler) GetLatestVersion(c *gin.Context) {
42
+ client := &http.Client{Timeout: 10 * time.Second}
43
+ proxyURL := ""
44
+ if h != nil && h.cfg != nil {
45
+ proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
46
+ }
47
+ if proxyURL != "" {
48
+ sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
49
+ util.SetProxy(sdkCfg, client)
50
+ }
51
+
52
+ req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
53
+ if err != nil {
54
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
55
+ return
56
+ }
57
+ req.Header.Set("Accept", "application/vnd.github+json")
58
+ req.Header.Set("User-Agent", latestReleaseUserAgent)
59
+
60
+ resp, err := client.Do(req)
61
+ if err != nil {
62
+ c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
63
+ return
64
+ }
65
+ defer func() {
66
+ if errClose := resp.Body.Close(); errClose != nil {
67
+ log.WithError(errClose).Debug("failed to close latest version response body")
68
+ }
69
+ }()
70
+
71
+ if resp.StatusCode != http.StatusOK {
72
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
73
+ c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
74
+ return
75
+ }
76
+
77
+ var info releaseInfo
78
+ if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
79
+ c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
80
+ return
81
+ }
82
+
83
+ version := strings.TrimSpace(info.TagName)
84
+ if version == "" {
85
+ version = strings.TrimSpace(info.Name)
86
+ }
87
+ if version == "" {
88
+ c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
89
+ return
90
+ }
91
+
92
+ c.JSON(http.StatusOK, gin.H{"latest-version": version})
93
+ }
94
+
95
+ func WriteConfig(path string, data []byte) error {
96
+ data = config.NormalizeCommentIndentation(data)
97
+ f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
98
+ if err != nil {
99
+ return err
100
+ }
101
+ if _, errWrite := f.Write(data); errWrite != nil {
102
+ _ = f.Close()
103
+ return errWrite
104
+ }
105
+ if errSync := f.Sync(); errSync != nil {
106
+ _ = f.Close()
107
+ return errSync
108
+ }
109
+ return f.Close()
110
+ }
111
+
112
+ func (h *Handler) PutConfigYAML(c *gin.Context) {
113
+ body, err := io.ReadAll(c.Request.Body)
114
+ if err != nil {
115
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": "cannot read request body"})
116
+ return
117
+ }
118
+ var cfg config.Config
119
+ if err = yaml.Unmarshal(body, &cfg); err != nil {
120
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
121
+ return
122
+ }
123
+ // Validate config using LoadConfigOptional with optional=false to enforce parsing
124
+ tmpDir := filepath.Dir(h.configFilePath)
125
+ tmpFile, err := os.CreateTemp(tmpDir, "config-validate-*.yaml")
126
+ if err != nil {
127
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
128
+ return
129
+ }
130
+ tempFile := tmpFile.Name()
131
+ if _, errWrite := tmpFile.Write(body); errWrite != nil {
132
+ _ = tmpFile.Close()
133
+ _ = os.Remove(tempFile)
134
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
135
+ return
136
+ }
137
+ if errClose := tmpFile.Close(); errClose != nil {
138
+ _ = os.Remove(tempFile)
139
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
140
+ return
141
+ }
142
+ defer func() {
143
+ _ = os.Remove(tempFile)
144
+ }()
145
+ _, err = config.LoadConfigOptional(tempFile, false)
146
+ if err != nil {
147
+ c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
148
+ return
149
+ }
150
+ h.mu.Lock()
151
+ defer h.mu.Unlock()
152
+ if WriteConfig(h.configFilePath, body) != nil {
153
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": "failed to write config"})
154
+ return
155
+ }
156
+ // Reload into handler to keep memory in sync
157
+ newCfg, err := config.LoadConfig(h.configFilePath)
158
+ if err != nil {
159
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": err.Error()})
160
+ return
161
+ }
162
+ h.cfg = newCfg
163
+ c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}})
164
+ }
165
+
166
+ // GetConfigYAML returns the raw config.yaml file bytes without re-encoding.
167
+ // It preserves comments and original formatting/styles.
168
+ func (h *Handler) GetConfigYAML(c *gin.Context) {
169
+ data, err := os.ReadFile(h.configFilePath)
170
+ if err != nil {
171
+ if os.IsNotExist(err) {
172
+ c.JSON(http.StatusNotFound, gin.H{"error": "not_found", "message": "config file not found"})
173
+ return
174
+ }
175
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": err.Error()})
176
+ return
177
+ }
178
+ c.Header("Content-Type", "application/yaml; charset=utf-8")
179
+ c.Header("Cache-Control", "no-store")
180
+ c.Header("X-Content-Type-Options", "nosniff")
181
+ // Write raw bytes as-is
182
+ _, _ = c.Writer.Write(data)
183
+ }
184
+
185
+ // Debug
186
+ func (h *Handler) GetDebug(c *gin.Context) { c.JSON(200, gin.H{"debug": h.cfg.Debug}) }
187
+ func (h *Handler) PutDebug(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.Debug = v }) }
188
+
189
+ // UsageStatisticsEnabled
190
+ func (h *Handler) GetUsageStatisticsEnabled(c *gin.Context) {
191
+ c.JSON(200, gin.H{"usage-statistics-enabled": h.cfg.UsageStatisticsEnabled})
192
+ }
193
+ func (h *Handler) PutUsageStatisticsEnabled(c *gin.Context) {
194
+ h.updateBoolField(c, func(v bool) { h.cfg.UsageStatisticsEnabled = v })
195
+ }
196
+
197
+ // UsageStatisticsEnabled
198
+ func (h *Handler) GetLoggingToFile(c *gin.Context) {
199
+ c.JSON(200, gin.H{"logging-to-file": h.cfg.LoggingToFile})
200
+ }
201
+ func (h *Handler) PutLoggingToFile(c *gin.Context) {
202
+ h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
203
+ }
204
+
205
+ // Request log
206
+ func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
207
+ func (h *Handler) PutRequestLog(c *gin.Context) {
208
+ h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
209
+ }
210
+
211
+ // Websocket auth
212
+ func (h *Handler) GetWebsocketAuth(c *gin.Context) {
213
+ c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
214
+ }
215
+ func (h *Handler) PutWebsocketAuth(c *gin.Context) {
216
+ h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
217
+ }
218
+
219
+ // Request retry
220
+ func (h *Handler) GetRequestRetry(c *gin.Context) {
221
+ c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
222
+ }
223
+ func (h *Handler) PutRequestRetry(c *gin.Context) {
224
+ h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
225
+ }
226
+
227
+ // Max retry interval
228
+ func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
229
+ c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
230
+ }
231
+ func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
232
+ h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
233
+ }
234
+
235
+ // Proxy URL
236
+ func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
237
+ func (h *Handler) PutProxyURL(c *gin.Context) {
238
+ h.updateStringField(c, func(v string) { h.cfg.ProxyURL = v })
239
+ }
240
+ func (h *Handler) DeleteProxyURL(c *gin.Context) {
241
+ h.cfg.ProxyURL = ""
242
+ h.persist(c)
243
+ }
internal/api/handlers/management/config_lists.go ADDED
@@ -0,0 +1,1090 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "strings"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
10
+ )
11
+
12
+ // Generic helpers for list[string]
13
+ func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) {
14
+ data, err := c.GetRawData()
15
+ if err != nil {
16
+ c.JSON(400, gin.H{"error": "failed to read body"})
17
+ return
18
+ }
19
+ var arr []string
20
+ if err = json.Unmarshal(data, &arr); err != nil {
21
+ var obj struct {
22
+ Items []string `json:"items"`
23
+ }
24
+ if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
25
+ c.JSON(400, gin.H{"error": "invalid body"})
26
+ return
27
+ }
28
+ arr = obj.Items
29
+ }
30
+ set(arr)
31
+ if after != nil {
32
+ after()
33
+ }
34
+ h.persist(c)
35
+ }
36
+
37
+ func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) {
38
+ var body struct {
39
+ Old *string `json:"old"`
40
+ New *string `json:"new"`
41
+ Index *int `json:"index"`
42
+ Value *string `json:"value"`
43
+ }
44
+ if err := c.ShouldBindJSON(&body); err != nil {
45
+ c.JSON(400, gin.H{"error": "invalid body"})
46
+ return
47
+ }
48
+ if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target) {
49
+ (*target)[*body.Index] = *body.Value
50
+ if after != nil {
51
+ after()
52
+ }
53
+ h.persist(c)
54
+ return
55
+ }
56
+ if body.Old != nil && body.New != nil {
57
+ for i := range *target {
58
+ if (*target)[i] == *body.Old {
59
+ (*target)[i] = *body.New
60
+ if after != nil {
61
+ after()
62
+ }
63
+ h.persist(c)
64
+ return
65
+ }
66
+ }
67
+ *target = append(*target, *body.New)
68
+ if after != nil {
69
+ after()
70
+ }
71
+ h.persist(c)
72
+ return
73
+ }
74
+ c.JSON(400, gin.H{"error": "missing fields"})
75
+ }
76
+
77
+ func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) {
78
+ if idxStr := c.Query("index"); idxStr != "" {
79
+ var idx int
80
+ _, err := fmt.Sscanf(idxStr, "%d", &idx)
81
+ if err == nil && idx >= 0 && idx < len(*target) {
82
+ *target = append((*target)[:idx], (*target)[idx+1:]...)
83
+ if after != nil {
84
+ after()
85
+ }
86
+ h.persist(c)
87
+ return
88
+ }
89
+ }
90
+ if val := strings.TrimSpace(c.Query("value")); val != "" {
91
+ out := make([]string, 0, len(*target))
92
+ for _, v := range *target {
93
+ if strings.TrimSpace(v) != val {
94
+ out = append(out, v)
95
+ }
96
+ }
97
+ *target = out
98
+ if after != nil {
99
+ after()
100
+ }
101
+ h.persist(c)
102
+ return
103
+ }
104
+ c.JSON(400, gin.H{"error": "missing index or value"})
105
+ }
106
+
107
+ // api-keys
108
+ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys}) }
109
+ func (h *Handler) PutAPIKeys(c *gin.Context) {
110
+ h.putStringList(c, func(v []string) {
111
+ h.cfg.APIKeys = append([]string(nil), v...)
112
+ h.cfg.Access.Providers = nil
113
+ }, nil)
114
+ }
115
+ func (h *Handler) PatchAPIKeys(c *gin.Context) {
116
+ h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
117
+ }
118
+ func (h *Handler) DeleteAPIKeys(c *gin.Context) {
119
+ h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
120
+ }
121
+
122
+ // gemini-api-key: []GeminiKey
123
+ func (h *Handler) GetGeminiKeys(c *gin.Context) {
124
+ c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey})
125
+ }
126
+ func (h *Handler) PutGeminiKeys(c *gin.Context) {
127
+ data, err := c.GetRawData()
128
+ if err != nil {
129
+ c.JSON(400, gin.H{"error": "failed to read body"})
130
+ return
131
+ }
132
+ var arr []config.GeminiKey
133
+ if err = json.Unmarshal(data, &arr); err != nil {
134
+ var obj struct {
135
+ Items []config.GeminiKey `json:"items"`
136
+ }
137
+ if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
138
+ c.JSON(400, gin.H{"error": "invalid body"})
139
+ return
140
+ }
141
+ arr = obj.Items
142
+ }
143
+ h.cfg.GeminiKey = append([]config.GeminiKey(nil), arr...)
144
+ h.cfg.SanitizeGeminiKeys()
145
+ h.persist(c)
146
+ }
147
+ func (h *Handler) PatchGeminiKey(c *gin.Context) {
148
+ type geminiKeyPatch struct {
149
+ APIKey *string `json:"api-key"`
150
+ Prefix *string `json:"prefix"`
151
+ BaseURL *string `json:"base-url"`
152
+ ProxyURL *string `json:"proxy-url"`
153
+ Headers *map[string]string `json:"headers"`
154
+ ExcludedModels *[]string `json:"excluded-models"`
155
+ }
156
+ var body struct {
157
+ Index *int `json:"index"`
158
+ Match *string `json:"match"`
159
+ Value *geminiKeyPatch `json:"value"`
160
+ }
161
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
162
+ c.JSON(400, gin.H{"error": "invalid body"})
163
+ return
164
+ }
165
+ targetIndex := -1
166
+ if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
167
+ targetIndex = *body.Index
168
+ }
169
+ if targetIndex == -1 && body.Match != nil {
170
+ match := strings.TrimSpace(*body.Match)
171
+ if match != "" {
172
+ for i := range h.cfg.GeminiKey {
173
+ if h.cfg.GeminiKey[i].APIKey == match {
174
+ targetIndex = i
175
+ break
176
+ }
177
+ }
178
+ }
179
+ }
180
+ if targetIndex == -1 {
181
+ c.JSON(404, gin.H{"error": "item not found"})
182
+ return
183
+ }
184
+
185
+ entry := h.cfg.GeminiKey[targetIndex]
186
+ if body.Value.APIKey != nil {
187
+ trimmed := strings.TrimSpace(*body.Value.APIKey)
188
+ if trimmed == "" {
189
+ h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
190
+ h.cfg.SanitizeGeminiKeys()
191
+ h.persist(c)
192
+ return
193
+ }
194
+ entry.APIKey = trimmed
195
+ }
196
+ if body.Value.Prefix != nil {
197
+ entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
198
+ }
199
+ if body.Value.BaseURL != nil {
200
+ entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
201
+ }
202
+ if body.Value.ProxyURL != nil {
203
+ entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
204
+ }
205
+ if body.Value.Headers != nil {
206
+ entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
207
+ }
208
+ if body.Value.ExcludedModels != nil {
209
+ entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
210
+ }
211
+ h.cfg.GeminiKey[targetIndex] = entry
212
+ h.cfg.SanitizeGeminiKeys()
213
+ h.persist(c)
214
+ }
215
+
216
+ func (h *Handler) DeleteGeminiKey(c *gin.Context) {
217
+ if val := strings.TrimSpace(c.Query("api-key")); val != "" {
218
+ out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
219
+ for _, v := range h.cfg.GeminiKey {
220
+ if v.APIKey != val {
221
+ out = append(out, v)
222
+ }
223
+ }
224
+ if len(out) != len(h.cfg.GeminiKey) {
225
+ h.cfg.GeminiKey = out
226
+ h.cfg.SanitizeGeminiKeys()
227
+ h.persist(c)
228
+ } else {
229
+ c.JSON(404, gin.H{"error": "item not found"})
230
+ }
231
+ return
232
+ }
233
+ if idxStr := c.Query("index"); idxStr != "" {
234
+ var idx int
235
+ if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) {
236
+ h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...)
237
+ h.cfg.SanitizeGeminiKeys()
238
+ h.persist(c)
239
+ return
240
+ }
241
+ }
242
+ c.JSON(400, gin.H{"error": "missing api-key or index"})
243
+ }
244
+
245
+ // claude-api-key: []ClaudeKey
246
+ func (h *Handler) GetClaudeKeys(c *gin.Context) {
247
+ c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey})
248
+ }
249
+ func (h *Handler) PutClaudeKeys(c *gin.Context) {
250
+ data, err := c.GetRawData()
251
+ if err != nil {
252
+ c.JSON(400, gin.H{"error": "failed to read body"})
253
+ return
254
+ }
255
+ var arr []config.ClaudeKey
256
+ if err = json.Unmarshal(data, &arr); err != nil {
257
+ var obj struct {
258
+ Items []config.ClaudeKey `json:"items"`
259
+ }
260
+ if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
261
+ c.JSON(400, gin.H{"error": "invalid body"})
262
+ return
263
+ }
264
+ arr = obj.Items
265
+ }
266
+ for i := range arr {
267
+ normalizeClaudeKey(&arr[i])
268
+ }
269
+ h.cfg.ClaudeKey = arr
270
+ h.cfg.SanitizeClaudeKeys()
271
+ h.persist(c)
272
+ }
273
+ func (h *Handler) PatchClaudeKey(c *gin.Context) {
274
+ type claudeKeyPatch struct {
275
+ APIKey *string `json:"api-key"`
276
+ Prefix *string `json:"prefix"`
277
+ BaseURL *string `json:"base-url"`
278
+ ProxyURL *string `json:"proxy-url"`
279
+ Models *[]config.ClaudeModel `json:"models"`
280
+ Headers *map[string]string `json:"headers"`
281
+ ExcludedModels *[]string `json:"excluded-models"`
282
+ }
283
+ var body struct {
284
+ Index *int `json:"index"`
285
+ Match *string `json:"match"`
286
+ Value *claudeKeyPatch `json:"value"`
287
+ }
288
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
289
+ c.JSON(400, gin.H{"error": "invalid body"})
290
+ return
291
+ }
292
+ targetIndex := -1
293
+ if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
294
+ targetIndex = *body.Index
295
+ }
296
+ if targetIndex == -1 && body.Match != nil {
297
+ match := strings.TrimSpace(*body.Match)
298
+ for i := range h.cfg.ClaudeKey {
299
+ if h.cfg.ClaudeKey[i].APIKey == match {
300
+ targetIndex = i
301
+ break
302
+ }
303
+ }
304
+ }
305
+ if targetIndex == -1 {
306
+ c.JSON(404, gin.H{"error": "item not found"})
307
+ return
308
+ }
309
+
310
+ entry := h.cfg.ClaudeKey[targetIndex]
311
+ if body.Value.APIKey != nil {
312
+ entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
313
+ }
314
+ if body.Value.Prefix != nil {
315
+ entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
316
+ }
317
+ if body.Value.BaseURL != nil {
318
+ entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
319
+ }
320
+ if body.Value.ProxyURL != nil {
321
+ entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
322
+ }
323
+ if body.Value.Models != nil {
324
+ entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
325
+ }
326
+ if body.Value.Headers != nil {
327
+ entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
328
+ }
329
+ if body.Value.ExcludedModels != nil {
330
+ entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
331
+ }
332
+ normalizeClaudeKey(&entry)
333
+ h.cfg.ClaudeKey[targetIndex] = entry
334
+ h.cfg.SanitizeClaudeKeys()
335
+ h.persist(c)
336
+ }
337
+
338
+ func (h *Handler) DeleteClaudeKey(c *gin.Context) {
339
+ if val := c.Query("api-key"); val != "" {
340
+ out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
341
+ for _, v := range h.cfg.ClaudeKey {
342
+ if v.APIKey != val {
343
+ out = append(out, v)
344
+ }
345
+ }
346
+ h.cfg.ClaudeKey = out
347
+ h.cfg.SanitizeClaudeKeys()
348
+ h.persist(c)
349
+ return
350
+ }
351
+ if idxStr := c.Query("index"); idxStr != "" {
352
+ var idx int
353
+ _, err := fmt.Sscanf(idxStr, "%d", &idx)
354
+ if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) {
355
+ h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...)
356
+ h.cfg.SanitizeClaudeKeys()
357
+ h.persist(c)
358
+ return
359
+ }
360
+ }
361
+ c.JSON(400, gin.H{"error": "missing api-key or index"})
362
+ }
363
+
364
+ // openai-compatibility: []OpenAICompatibility
365
+ func (h *Handler) GetOpenAICompat(c *gin.Context) {
366
+ c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)})
367
+ }
368
+ func (h *Handler) PutOpenAICompat(c *gin.Context) {
369
+ data, err := c.GetRawData()
370
+ if err != nil {
371
+ c.JSON(400, gin.H{"error": "failed to read body"})
372
+ return
373
+ }
374
+ var arr []config.OpenAICompatibility
375
+ if err = json.Unmarshal(data, &arr); err != nil {
376
+ var obj struct {
377
+ Items []config.OpenAICompatibility `json:"items"`
378
+ }
379
+ if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
380
+ c.JSON(400, gin.H{"error": "invalid body"})
381
+ return
382
+ }
383
+ arr = obj.Items
384
+ }
385
+ filtered := make([]config.OpenAICompatibility, 0, len(arr))
386
+ for i := range arr {
387
+ normalizeOpenAICompatibilityEntry(&arr[i])
388
+ if strings.TrimSpace(arr[i].BaseURL) != "" {
389
+ filtered = append(filtered, arr[i])
390
+ }
391
+ }
392
+ h.cfg.OpenAICompatibility = filtered
393
+ h.cfg.SanitizeOpenAICompatibility()
394
+ h.persist(c)
395
+ }
396
+ func (h *Handler) PatchOpenAICompat(c *gin.Context) {
397
+ type openAICompatPatch struct {
398
+ Name *string `json:"name"`
399
+ Prefix *string `json:"prefix"`
400
+ BaseURL *string `json:"base-url"`
401
+ APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
402
+ Models *[]config.OpenAICompatibilityModel `json:"models"`
403
+ Headers *map[string]string `json:"headers"`
404
+ }
405
+ var body struct {
406
+ Name *string `json:"name"`
407
+ Index *int `json:"index"`
408
+ Value *openAICompatPatch `json:"value"`
409
+ }
410
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
411
+ c.JSON(400, gin.H{"error": "invalid body"})
412
+ return
413
+ }
414
+ targetIndex := -1
415
+ if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
416
+ targetIndex = *body.Index
417
+ }
418
+ if targetIndex == -1 && body.Name != nil {
419
+ match := strings.TrimSpace(*body.Name)
420
+ for i := range h.cfg.OpenAICompatibility {
421
+ if h.cfg.OpenAICompatibility[i].Name == match {
422
+ targetIndex = i
423
+ break
424
+ }
425
+ }
426
+ }
427
+ if targetIndex == -1 {
428
+ c.JSON(404, gin.H{"error": "item not found"})
429
+ return
430
+ }
431
+
432
+ entry := h.cfg.OpenAICompatibility[targetIndex]
433
+ if body.Value.Name != nil {
434
+ entry.Name = strings.TrimSpace(*body.Value.Name)
435
+ }
436
+ if body.Value.Prefix != nil {
437
+ entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
438
+ }
439
+ if body.Value.BaseURL != nil {
440
+ trimmed := strings.TrimSpace(*body.Value.BaseURL)
441
+ if trimmed == "" {
442
+ h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
443
+ h.cfg.SanitizeOpenAICompatibility()
444
+ h.persist(c)
445
+ return
446
+ }
447
+ entry.BaseURL = trimmed
448
+ }
449
+ if body.Value.APIKeyEntries != nil {
450
+ entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
451
+ }
452
+ if body.Value.Models != nil {
453
+ entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
454
+ }
455
+ if body.Value.Headers != nil {
456
+ entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
457
+ }
458
+ normalizeOpenAICompatibilityEntry(&entry)
459
+ h.cfg.OpenAICompatibility[targetIndex] = entry
460
+ h.cfg.SanitizeOpenAICompatibility()
461
+ h.persist(c)
462
+ }
463
+
464
+ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
465
+ if name := c.Query("name"); name != "" {
466
+ out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
467
+ for _, v := range h.cfg.OpenAICompatibility {
468
+ if v.Name != name {
469
+ out = append(out, v)
470
+ }
471
+ }
472
+ h.cfg.OpenAICompatibility = out
473
+ h.cfg.SanitizeOpenAICompatibility()
474
+ h.persist(c)
475
+ return
476
+ }
477
+ if idxStr := c.Query("index"); idxStr != "" {
478
+ var idx int
479
+ _, err := fmt.Sscanf(idxStr, "%d", &idx)
480
+ if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) {
481
+ h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...)
482
+ h.cfg.SanitizeOpenAICompatibility()
483
+ h.persist(c)
484
+ return
485
+ }
486
+ }
487
+ c.JSON(400, gin.H{"error": "missing name or index"})
488
+ }
489
+
490
+ // oauth-excluded-models: map[string][]string
491
+ func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
492
+ c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
493
+ }
494
+
495
+ func (h *Handler) PutOAuthExcludedModels(c *gin.Context) {
496
+ data, err := c.GetRawData()
497
+ if err != nil {
498
+ c.JSON(400, gin.H{"error": "failed to read body"})
499
+ return
500
+ }
501
+ var entries map[string][]string
502
+ if err = json.Unmarshal(data, &entries); err != nil {
503
+ var wrapper struct {
504
+ Items map[string][]string `json:"items"`
505
+ }
506
+ if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
507
+ c.JSON(400, gin.H{"error": "invalid body"})
508
+ return
509
+ }
510
+ entries = wrapper.Items
511
+ }
512
+ h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries)
513
+ h.persist(c)
514
+ }
515
+
516
+ func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) {
517
+ var body struct {
518
+ Provider *string `json:"provider"`
519
+ Models []string `json:"models"`
520
+ }
521
+ if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil {
522
+ c.JSON(400, gin.H{"error": "invalid body"})
523
+ return
524
+ }
525
+ provider := strings.ToLower(strings.TrimSpace(*body.Provider))
526
+ if provider == "" {
527
+ c.JSON(400, gin.H{"error": "invalid provider"})
528
+ return
529
+ }
530
+ normalized := config.NormalizeExcludedModels(body.Models)
531
+ if len(normalized) == 0 {
532
+ if h.cfg.OAuthExcludedModels == nil {
533
+ c.JSON(404, gin.H{"error": "provider not found"})
534
+ return
535
+ }
536
+ if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
537
+ c.JSON(404, gin.H{"error": "provider not found"})
538
+ return
539
+ }
540
+ delete(h.cfg.OAuthExcludedModels, provider)
541
+ if len(h.cfg.OAuthExcludedModels) == 0 {
542
+ h.cfg.OAuthExcludedModels = nil
543
+ }
544
+ h.persist(c)
545
+ return
546
+ }
547
+ if h.cfg.OAuthExcludedModels == nil {
548
+ h.cfg.OAuthExcludedModels = make(map[string][]string)
549
+ }
550
+ h.cfg.OAuthExcludedModels[provider] = normalized
551
+ h.persist(c)
552
+ }
553
+
554
+ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
555
+ provider := strings.ToLower(strings.TrimSpace(c.Query("provider")))
556
+ if provider == "" {
557
+ c.JSON(400, gin.H{"error": "missing provider"})
558
+ return
559
+ }
560
+ if h.cfg.OAuthExcludedModels == nil {
561
+ c.JSON(404, gin.H{"error": "provider not found"})
562
+ return
563
+ }
564
+ if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
565
+ c.JSON(404, gin.H{"error": "provider not found"})
566
+ return
567
+ }
568
+ delete(h.cfg.OAuthExcludedModels, provider)
569
+ if len(h.cfg.OAuthExcludedModels) == 0 {
570
+ h.cfg.OAuthExcludedModels = nil
571
+ }
572
+ h.persist(c)
573
+ }
574
+
575
+ // codex-api-key: []CodexKey
576
+ func (h *Handler) GetCodexKeys(c *gin.Context) {
577
+ c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
578
+ }
579
+ func (h *Handler) PutCodexKeys(c *gin.Context) {
580
+ data, err := c.GetRawData()
581
+ if err != nil {
582
+ c.JSON(400, gin.H{"error": "failed to read body"})
583
+ return
584
+ }
585
+ var arr []config.CodexKey
586
+ if err = json.Unmarshal(data, &arr); err != nil {
587
+ var obj struct {
588
+ Items []config.CodexKey `json:"items"`
589
+ }
590
+ if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
591
+ c.JSON(400, gin.H{"error": "invalid body"})
592
+ return
593
+ }
594
+ arr = obj.Items
595
+ }
596
+ // Filter out codex entries with empty base-url (treat as removed)
597
+ filtered := make([]config.CodexKey, 0, len(arr))
598
+ for i := range arr {
599
+ entry := arr[i]
600
+ normalizeCodexKey(&entry)
601
+ if entry.BaseURL == "" {
602
+ continue
603
+ }
604
+ filtered = append(filtered, entry)
605
+ }
606
+ h.cfg.CodexKey = filtered
607
+ h.cfg.SanitizeCodexKeys()
608
+ h.persist(c)
609
+ }
610
+ func (h *Handler) PatchCodexKey(c *gin.Context) {
611
+ type codexKeyPatch struct {
612
+ APIKey *string `json:"api-key"`
613
+ Prefix *string `json:"prefix"`
614
+ BaseURL *string `json:"base-url"`
615
+ ProxyURL *string `json:"proxy-url"`
616
+ Models *[]config.CodexModel `json:"models"`
617
+ Headers *map[string]string `json:"headers"`
618
+ ExcludedModels *[]string `json:"excluded-models"`
619
+ }
620
+ var body struct {
621
+ Index *int `json:"index"`
622
+ Match *string `json:"match"`
623
+ Value *codexKeyPatch `json:"value"`
624
+ }
625
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
626
+ c.JSON(400, gin.H{"error": "invalid body"})
627
+ return
628
+ }
629
+ targetIndex := -1
630
+ if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
631
+ targetIndex = *body.Index
632
+ }
633
+ if targetIndex == -1 && body.Match != nil {
634
+ match := strings.TrimSpace(*body.Match)
635
+ for i := range h.cfg.CodexKey {
636
+ if h.cfg.CodexKey[i].APIKey == match {
637
+ targetIndex = i
638
+ break
639
+ }
640
+ }
641
+ }
642
+ if targetIndex == -1 {
643
+ c.JSON(404, gin.H{"error": "item not found"})
644
+ return
645
+ }
646
+
647
+ entry := h.cfg.CodexKey[targetIndex]
648
+ if body.Value.APIKey != nil {
649
+ entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
650
+ }
651
+ if body.Value.Prefix != nil {
652
+ entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
653
+ }
654
+ if body.Value.BaseURL != nil {
655
+ trimmed := strings.TrimSpace(*body.Value.BaseURL)
656
+ if trimmed == "" {
657
+ h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
658
+ h.cfg.SanitizeCodexKeys()
659
+ h.persist(c)
660
+ return
661
+ }
662
+ entry.BaseURL = trimmed
663
+ }
664
+ if body.Value.ProxyURL != nil {
665
+ entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
666
+ }
667
+ if body.Value.Models != nil {
668
+ entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
669
+ }
670
+ if body.Value.Headers != nil {
671
+ entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
672
+ }
673
+ if body.Value.ExcludedModels != nil {
674
+ entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
675
+ }
676
+ normalizeCodexKey(&entry)
677
+ h.cfg.CodexKey[targetIndex] = entry
678
+ h.cfg.SanitizeCodexKeys()
679
+ h.persist(c)
680
+ }
681
+
682
+ func (h *Handler) DeleteCodexKey(c *gin.Context) {
683
+ if val := c.Query("api-key"); val != "" {
684
+ out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
685
+ for _, v := range h.cfg.CodexKey {
686
+ if v.APIKey != val {
687
+ out = append(out, v)
688
+ }
689
+ }
690
+ h.cfg.CodexKey = out
691
+ h.cfg.SanitizeCodexKeys()
692
+ h.persist(c)
693
+ return
694
+ }
695
+ if idxStr := c.Query("index"); idxStr != "" {
696
+ var idx int
697
+ _, err := fmt.Sscanf(idxStr, "%d", &idx)
698
+ if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) {
699
+ h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...)
700
+ h.cfg.SanitizeCodexKeys()
701
+ h.persist(c)
702
+ return
703
+ }
704
+ }
705
+ c.JSON(400, gin.H{"error": "missing api-key or index"})
706
+ }
707
+
708
+ func normalizeOpenAICompatibilityEntry(entry *config.OpenAICompatibility) {
709
+ if entry == nil {
710
+ return
711
+ }
712
+ // Trim base-url; empty base-url indicates provider should be removed by sanitization
713
+ entry.BaseURL = strings.TrimSpace(entry.BaseURL)
714
+ entry.Headers = config.NormalizeHeaders(entry.Headers)
715
+ existing := make(map[string]struct{}, len(entry.APIKeyEntries))
716
+ for i := range entry.APIKeyEntries {
717
+ trimmed := strings.TrimSpace(entry.APIKeyEntries[i].APIKey)
718
+ entry.APIKeyEntries[i].APIKey = trimmed
719
+ if trimmed != "" {
720
+ existing[trimmed] = struct{}{}
721
+ }
722
+ }
723
+ }
724
+
725
+ func normalizedOpenAICompatibilityEntries(entries []config.OpenAICompatibility) []config.OpenAICompatibility {
726
+ if len(entries) == 0 {
727
+ return nil
728
+ }
729
+ out := make([]config.OpenAICompatibility, len(entries))
730
+ for i := range entries {
731
+ copyEntry := entries[i]
732
+ if len(copyEntry.APIKeyEntries) > 0 {
733
+ copyEntry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), copyEntry.APIKeyEntries...)
734
+ }
735
+ normalizeOpenAICompatibilityEntry(&copyEntry)
736
+ out[i] = copyEntry
737
+ }
738
+ return out
739
+ }
740
+
741
+ func normalizeClaudeKey(entry *config.ClaudeKey) {
742
+ if entry == nil {
743
+ return
744
+ }
745
+ entry.APIKey = strings.TrimSpace(entry.APIKey)
746
+ entry.BaseURL = strings.TrimSpace(entry.BaseURL)
747
+ entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
748
+ entry.Headers = config.NormalizeHeaders(entry.Headers)
749
+ entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
750
+ if len(entry.Models) == 0 {
751
+ return
752
+ }
753
+ normalized := make([]config.ClaudeModel, 0, len(entry.Models))
754
+ for i := range entry.Models {
755
+ model := entry.Models[i]
756
+ model.Name = strings.TrimSpace(model.Name)
757
+ model.Alias = strings.TrimSpace(model.Alias)
758
+ if model.Name == "" && model.Alias == "" {
759
+ continue
760
+ }
761
+ normalized = append(normalized, model)
762
+ }
763
+ entry.Models = normalized
764
+ }
765
+
766
+ func normalizeCodexKey(entry *config.CodexKey) {
767
+ if entry == nil {
768
+ return
769
+ }
770
+ entry.APIKey = strings.TrimSpace(entry.APIKey)
771
+ entry.Prefix = strings.TrimSpace(entry.Prefix)
772
+ entry.BaseURL = strings.TrimSpace(entry.BaseURL)
773
+ entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
774
+ entry.Headers = config.NormalizeHeaders(entry.Headers)
775
+ entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
776
+ if len(entry.Models) == 0 {
777
+ return
778
+ }
779
+ normalized := make([]config.CodexModel, 0, len(entry.Models))
780
+ for i := range entry.Models {
781
+ model := entry.Models[i]
782
+ model.Name = strings.TrimSpace(model.Name)
783
+ model.Alias = strings.TrimSpace(model.Alias)
784
+ if model.Name == "" && model.Alias == "" {
785
+ continue
786
+ }
787
+ normalized = append(normalized, model)
788
+ }
789
+ entry.Models = normalized
790
+ }
791
+
792
+ // GetAmpCode returns the complete ampcode configuration.
793
+ func (h *Handler) GetAmpCode(c *gin.Context) {
794
+ if h == nil || h.cfg == nil {
795
+ c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
796
+ return
797
+ }
798
+ c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
799
+ }
800
+
801
+ // GetAmpUpstreamURL returns the ampcode upstream URL.
802
+ func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
803
+ if h == nil || h.cfg == nil {
804
+ c.JSON(200, gin.H{"upstream-url": ""})
805
+ return
806
+ }
807
+ c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
808
+ }
809
+
810
+ // PutAmpUpstreamURL updates the ampcode upstream URL.
811
+ func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
812
+ h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
813
+ }
814
+
815
+ // DeleteAmpUpstreamURL clears the ampcode upstream URL.
816
+ func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
817
+ h.cfg.AmpCode.UpstreamURL = ""
818
+ h.persist(c)
819
+ }
820
+
821
+ // GetAmpUpstreamAPIKey returns the ampcode upstream API key.
822
+ func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
823
+ if h == nil || h.cfg == nil {
824
+ c.JSON(200, gin.H{"upstream-api-key": ""})
825
+ return
826
+ }
827
+ c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
828
+ }
829
+
830
+ // PutAmpUpstreamAPIKey updates the ampcode upstream API key.
831
+ func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
832
+ h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
833
+ }
834
+
835
+ // DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
836
+ func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
837
+ h.cfg.AmpCode.UpstreamAPIKey = ""
838
+ h.persist(c)
839
+ }
840
+
841
+ // GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
842
+ func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
843
+ if h == nil || h.cfg == nil {
844
+ c.JSON(200, gin.H{"restrict-management-to-localhost": true})
845
+ return
846
+ }
847
+ c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
848
+ }
849
+
850
+ // PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
851
+ func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
852
+ h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
853
+ }
854
+
855
+ // GetAmpModelMappings returns the ampcode model mappings.
856
+ func (h *Handler) GetAmpModelMappings(c *gin.Context) {
857
+ if h == nil || h.cfg == nil {
858
+ c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
859
+ return
860
+ }
861
+ c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
862
+ }
863
+
864
+ // PutAmpModelMappings replaces all ampcode model mappings.
865
+ func (h *Handler) PutAmpModelMappings(c *gin.Context) {
866
+ var body struct {
867
+ Value []config.AmpModelMapping `json:"value"`
868
+ }
869
+ if err := c.ShouldBindJSON(&body); err != nil {
870
+ c.JSON(400, gin.H{"error": "invalid body"})
871
+ return
872
+ }
873
+ h.cfg.AmpCode.ModelMappings = body.Value
874
+ h.persist(c)
875
+ }
876
+
877
+ // PatchAmpModelMappings adds or updates model mappings.
878
+ func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
879
+ var body struct {
880
+ Value []config.AmpModelMapping `json:"value"`
881
+ }
882
+ if err := c.ShouldBindJSON(&body); err != nil {
883
+ c.JSON(400, gin.H{"error": "invalid body"})
884
+ return
885
+ }
886
+
887
+ existing := make(map[string]int)
888
+ for i, m := range h.cfg.AmpCode.ModelMappings {
889
+ existing[strings.TrimSpace(m.From)] = i
890
+ }
891
+
892
+ for _, newMapping := range body.Value {
893
+ from := strings.TrimSpace(newMapping.From)
894
+ if idx, ok := existing[from]; ok {
895
+ h.cfg.AmpCode.ModelMappings[idx] = newMapping
896
+ } else {
897
+ h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
898
+ existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
899
+ }
900
+ }
901
+ h.persist(c)
902
+ }
903
+
904
+ // DeleteAmpModelMappings removes specified model mappings by "from" field.
905
+ func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
906
+ var body struct {
907
+ Value []string `json:"value"`
908
+ }
909
+ if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
910
+ h.cfg.AmpCode.ModelMappings = nil
911
+ h.persist(c)
912
+ return
913
+ }
914
+
915
+ toRemove := make(map[string]bool)
916
+ for _, from := range body.Value {
917
+ toRemove[strings.TrimSpace(from)] = true
918
+ }
919
+
920
+ newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
921
+ for _, m := range h.cfg.AmpCode.ModelMappings {
922
+ if !toRemove[strings.TrimSpace(m.From)] {
923
+ newMappings = append(newMappings, m)
924
+ }
925
+ }
926
+ h.cfg.AmpCode.ModelMappings = newMappings
927
+ h.persist(c)
928
+ }
929
+
930
+ // GetAmpForceModelMappings returns whether model mappings are forced.
931
+ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
932
+ if h == nil || h.cfg == nil {
933
+ c.JSON(200, gin.H{"force-model-mappings": false})
934
+ return
935
+ }
936
+ c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
937
+ }
938
+
939
+ // PutAmpForceModelMappings updates the force model mappings setting.
940
+ func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
941
+ h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
942
+ }
943
+
944
+ // GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
945
+ func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
946
+ if h == nil || h.cfg == nil {
947
+ c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
948
+ return
949
+ }
950
+ c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
951
+ }
952
+
953
+ // PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
954
+ func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
955
+ var body struct {
956
+ Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
957
+ }
958
+ if err := c.ShouldBindJSON(&body); err != nil {
959
+ c.JSON(400, gin.H{"error": "invalid body"})
960
+ return
961
+ }
962
+ // Normalize entries: trim whitespace, filter empty
963
+ normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
964
+ h.cfg.AmpCode.UpstreamAPIKeys = normalized
965
+ h.persist(c)
966
+ }
967
+
968
+ // PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
969
+ // Matching is done by upstream-api-key value.
970
+ func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
971
+ var body struct {
972
+ Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
973
+ }
974
+ if err := c.ShouldBindJSON(&body); err != nil {
975
+ c.JSON(400, gin.H{"error": "invalid body"})
976
+ return
977
+ }
978
+
979
+ existing := make(map[string]int)
980
+ for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
981
+ existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
982
+ }
983
+
984
+ for _, newEntry := range body.Value {
985
+ upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
986
+ if upstreamKey == "" {
987
+ continue
988
+ }
989
+ normalizedEntry := config.AmpUpstreamAPIKeyEntry{
990
+ UpstreamAPIKey: upstreamKey,
991
+ APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
992
+ }
993
+ if idx, ok := existing[upstreamKey]; ok {
994
+ h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
995
+ } else {
996
+ h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
997
+ existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
998
+ }
999
+ }
1000
+ h.persist(c)
1001
+ }
1002
+
1003
+ // DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
1004
+ // Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
1005
+ // If "value" is an empty array, clears all entries.
1006
+ // If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
1007
+ func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
1008
+ var body struct {
1009
+ Value []string `json:"value"`
1010
+ }
1011
+ if err := c.ShouldBindJSON(&body); err != nil {
1012
+ c.JSON(400, gin.H{"error": "invalid body"})
1013
+ return
1014
+ }
1015
+
1016
+ if body.Value == nil {
1017
+ c.JSON(400, gin.H{"error": "missing value"})
1018
+ return
1019
+ }
1020
+
1021
+ // Empty array means clear all
1022
+ if len(body.Value) == 0 {
1023
+ h.cfg.AmpCode.UpstreamAPIKeys = nil
1024
+ h.persist(c)
1025
+ return
1026
+ }
1027
+
1028
+ toRemove := make(map[string]bool)
1029
+ for _, key := range body.Value {
1030
+ trimmed := strings.TrimSpace(key)
1031
+ if trimmed == "" {
1032
+ continue
1033
+ }
1034
+ toRemove[trimmed] = true
1035
+ }
1036
+ if len(toRemove) == 0 {
1037
+ c.JSON(400, gin.H{"error": "empty value"})
1038
+ return
1039
+ }
1040
+
1041
+ newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
1042
+ for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
1043
+ if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
1044
+ newEntries = append(newEntries, entry)
1045
+ }
1046
+ }
1047
+ h.cfg.AmpCode.UpstreamAPIKeys = newEntries
1048
+ h.persist(c)
1049
+ }
1050
+
1051
+ // normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
1052
+ func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
1053
+ if len(entries) == 0 {
1054
+ return nil
1055
+ }
1056
+ out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
1057
+ for _, entry := range entries {
1058
+ upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
1059
+ if upstreamKey == "" {
1060
+ continue
1061
+ }
1062
+ apiKeys := normalizeAPIKeysList(entry.APIKeys)
1063
+ out = append(out, config.AmpUpstreamAPIKeyEntry{
1064
+ UpstreamAPIKey: upstreamKey,
1065
+ APIKeys: apiKeys,
1066
+ })
1067
+ }
1068
+ if len(out) == 0 {
1069
+ return nil
1070
+ }
1071
+ return out
1072
+ }
1073
+
1074
+ // normalizeAPIKeysList trims and filters empty strings from a list of API keys.
1075
+ func normalizeAPIKeysList(keys []string) []string {
1076
+ if len(keys) == 0 {
1077
+ return nil
1078
+ }
1079
+ out := make([]string, 0, len(keys))
1080
+ for _, k := range keys {
1081
+ trimmed := strings.TrimSpace(k)
1082
+ if trimmed != "" {
1083
+ out = append(out, trimmed)
1084
+ }
1085
+ }
1086
+ if len(out) == 0 {
1087
+ return nil
1088
+ }
1089
+ return out
1090
+ }
internal/api/handlers/management/handler.go ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package management provides the management API handlers and middleware
2
+ // for configuring the server and managing auth files.
3
+ package management
4
+
5
+ import (
6
+ "crypto/subtle"
7
+ "fmt"
8
+ "net/http"
9
+ "os"
10
+ "path/filepath"
11
+ "strings"
12
+ "sync"
13
+ "time"
14
+
15
+ "github.com/gin-gonic/gin"
16
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
17
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
18
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
19
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
20
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
21
+ "golang.org/x/crypto/bcrypt"
22
+ )
23
+
24
+ type attemptInfo struct {
25
+ count int
26
+ blockedUntil time.Time
27
+ }
28
+
29
+ // Handler aggregates config reference, persistence path and helpers.
30
+ type Handler struct {
31
+ cfg *config.Config
32
+ configFilePath string
33
+ mu sync.Mutex
34
+ attemptsMu sync.Mutex
35
+ failedAttempts map[string]*attemptInfo // keyed by client IP
36
+ authManager *coreauth.Manager
37
+ usageStats *usage.RequestStatistics
38
+ tokenStore coreauth.Store
39
+ localPassword string
40
+ allowRemoteOverride bool
41
+ envSecret string
42
+ logDir string
43
+ }
44
+
45
+ // NewHandler creates a new management handler instance.
46
+ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Manager) *Handler {
47
+ envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
48
+ envSecret = strings.TrimSpace(envSecret)
49
+
50
+ return &Handler{
51
+ cfg: cfg,
52
+ configFilePath: configFilePath,
53
+ failedAttempts: make(map[string]*attemptInfo),
54
+ authManager: manager,
55
+ usageStats: usage.GetRequestStatistics(),
56
+ tokenStore: sdkAuth.GetTokenStore(),
57
+ allowRemoteOverride: envSecret != "",
58
+ envSecret: envSecret,
59
+ }
60
+ }
61
+
62
+ // NewHandler creates a new management handler instance.
63
+ func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
64
+ return NewHandler(cfg, "", manager)
65
+ }
66
+
67
+ // SetConfig updates the in-memory config reference when the server hot-reloads.
68
+ func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
69
+
70
+ // SetAuthManager updates the auth manager reference used by management endpoints.
71
+ func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager }
72
+
73
+ // SetUsageStatistics allows replacing the usage statistics reference.
74
+ func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
75
+
76
+ // SetLocalPassword configures the runtime-local password accepted for localhost requests.
77
+ func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
78
+
79
+ // SetLogDirectory updates the directory where main.log should be looked up.
80
+ func (h *Handler) SetLogDirectory(dir string) {
81
+ if dir == "" {
82
+ return
83
+ }
84
+ if !filepath.IsAbs(dir) {
85
+ if abs, err := filepath.Abs(dir); err == nil {
86
+ dir = abs
87
+ }
88
+ }
89
+ h.logDir = dir
90
+ }
91
+
92
+ // Middleware enforces access control for management endpoints.
93
+ // All requests (local and remote) require a valid management key.
94
+ // Additionally, remote access requires allow-remote-management=true.
95
+ func (h *Handler) Middleware() gin.HandlerFunc {
96
+ const maxFailures = 5
97
+ const banDuration = 30 * time.Minute
98
+
99
+ return func(c *gin.Context) {
100
+ c.Header("X-CPA-VERSION", buildinfo.Version)
101
+ c.Header("X-CPA-COMMIT", buildinfo.Commit)
102
+ c.Header("X-CPA-BUILD-DATE", buildinfo.BuildDate)
103
+
104
+ clientIP := c.ClientIP()
105
+ localClient := clientIP == "127.0.0.1" || clientIP == "::1"
106
+ cfg := h.cfg
107
+ var (
108
+ allowRemote bool
109
+ secretHash string
110
+ )
111
+ if cfg != nil {
112
+ allowRemote = cfg.RemoteManagement.AllowRemote
113
+ secretHash = cfg.RemoteManagement.SecretKey
114
+ }
115
+ if h.allowRemoteOverride {
116
+ allowRemote = true
117
+ }
118
+ envSecret := h.envSecret
119
+
120
+ fail := func() {}
121
+ if !localClient {
122
+ h.attemptsMu.Lock()
123
+ ai := h.failedAttempts[clientIP]
124
+ if ai != nil {
125
+ if !ai.blockedUntil.IsZero() {
126
+ if time.Now().Before(ai.blockedUntil) {
127
+ remaining := time.Until(ai.blockedUntil).Round(time.Second)
128
+ h.attemptsMu.Unlock()
129
+ c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
130
+ return
131
+ }
132
+ // Ban expired, reset state
133
+ ai.blockedUntil = time.Time{}
134
+ ai.count = 0
135
+ }
136
+ }
137
+ h.attemptsMu.Unlock()
138
+
139
+ if !allowRemote {
140
+ c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
141
+ return
142
+ }
143
+
144
+ fail = func() {
145
+ h.attemptsMu.Lock()
146
+ aip := h.failedAttempts[clientIP]
147
+ if aip == nil {
148
+ aip = &attemptInfo{}
149
+ h.failedAttempts[clientIP] = aip
150
+ }
151
+ aip.count++
152
+ if aip.count >= maxFailures {
153
+ aip.blockedUntil = time.Now().Add(banDuration)
154
+ aip.count = 0
155
+ }
156
+ h.attemptsMu.Unlock()
157
+ }
158
+ }
159
+ if secretHash == "" && envSecret == "" {
160
+ c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
161
+ return
162
+ }
163
+
164
+ // Accept either Authorization: Bearer <key> or X-Management-Key
165
+ var provided string
166
+ if ah := c.GetHeader("Authorization"); ah != "" {
167
+ parts := strings.SplitN(ah, " ", 2)
168
+ if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
169
+ provided = parts[1]
170
+ } else {
171
+ provided = ah
172
+ }
173
+ }
174
+ if provided == "" {
175
+ provided = c.GetHeader("X-Management-Key")
176
+ }
177
+
178
+ if provided == "" {
179
+ if !localClient {
180
+ fail()
181
+ }
182
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
183
+ return
184
+ }
185
+
186
+ if localClient {
187
+ if lp := h.localPassword; lp != "" {
188
+ if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
189
+ c.Next()
190
+ return
191
+ }
192
+ }
193
+ }
194
+
195
+ if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
196
+ if !localClient {
197
+ h.attemptsMu.Lock()
198
+ if ai := h.failedAttempts[clientIP]; ai != nil {
199
+ ai.count = 0
200
+ ai.blockedUntil = time.Time{}
201
+ }
202
+ h.attemptsMu.Unlock()
203
+ }
204
+ c.Next()
205
+ return
206
+ }
207
+
208
+ if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
209
+ if !localClient {
210
+ fail()
211
+ }
212
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
213
+ return
214
+ }
215
+
216
+ if !localClient {
217
+ h.attemptsMu.Lock()
218
+ if ai := h.failedAttempts[clientIP]; ai != nil {
219
+ ai.count = 0
220
+ ai.blockedUntil = time.Time{}
221
+ }
222
+ h.attemptsMu.Unlock()
223
+ }
224
+
225
+ c.Next()
226
+ }
227
+ }
228
+
229
+ // persist saves the current in-memory config to disk.
230
+ func (h *Handler) persist(c *gin.Context) bool {
231
+ h.mu.Lock()
232
+ defer h.mu.Unlock()
233
+ // Preserve comments when writing
234
+ if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil {
235
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)})
236
+ return false
237
+ }
238
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
239
+ return true
240
+ }
241
+
242
+ // Helper methods for simple types
243
+ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
244
+ var body struct {
245
+ Value *bool `json:"value"`
246
+ }
247
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
248
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
249
+ return
250
+ }
251
+ set(*body.Value)
252
+ h.persist(c)
253
+ }
254
+
255
+ func (h *Handler) updateIntField(c *gin.Context, set func(int)) {
256
+ var body struct {
257
+ Value *int `json:"value"`
258
+ }
259
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
260
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
261
+ return
262
+ }
263
+ set(*body.Value)
264
+ h.persist(c)
265
+ }
266
+
267
+ func (h *Handler) updateStringField(c *gin.Context, set func(string)) {
268
+ var body struct {
269
+ Value *string `json:"value"`
270
+ }
271
+ if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
272
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
273
+ return
274
+ }
275
+ set(*body.Value)
276
+ h.persist(c)
277
+ }
internal/api/handlers/management/logs.go ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "bufio"
5
+ "fmt"
6
+ "math"
7
+ "net/http"
8
+ "os"
9
+ "path/filepath"
10
+ "sort"
11
+ "strconv"
12
+ "strings"
13
+ "time"
14
+
15
+ "github.com/gin-gonic/gin"
16
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
17
+ )
18
+
19
+ const (
20
+ defaultLogFileName = "main.log"
21
+ logScannerInitialBuffer = 64 * 1024
22
+ logScannerMaxBuffer = 8 * 1024 * 1024
23
+ )
24
+
25
+ // GetLogs returns log lines with optional incremental loading.
26
+ func (h *Handler) GetLogs(c *gin.Context) {
27
+ if h == nil {
28
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
29
+ return
30
+ }
31
+ if h.cfg == nil {
32
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
33
+ return
34
+ }
35
+ if !h.cfg.LoggingToFile {
36
+ c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
37
+ return
38
+ }
39
+
40
+ logDir := h.logDirectory()
41
+ if strings.TrimSpace(logDir) == "" {
42
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
43
+ return
44
+ }
45
+
46
+ files, err := h.collectLogFiles(logDir)
47
+ if err != nil {
48
+ if os.IsNotExist(err) {
49
+ cutoff := parseCutoff(c.Query("after"))
50
+ c.JSON(http.StatusOK, gin.H{
51
+ "lines": []string{},
52
+ "line-count": 0,
53
+ "latest-timestamp": cutoff,
54
+ })
55
+ return
56
+ }
57
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log files: %v", err)})
58
+ return
59
+ }
60
+
61
+ limit, errLimit := parseLimit(c.Query("limit"))
62
+ if errLimit != nil {
63
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
64
+ return
65
+ }
66
+
67
+ cutoff := parseCutoff(c.Query("after"))
68
+ acc := newLogAccumulator(cutoff, limit)
69
+ for i := range files {
70
+ if errProcess := acc.consumeFile(files[i]); errProcess != nil {
71
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
72
+ return
73
+ }
74
+ }
75
+
76
+ lines, total, latest := acc.result()
77
+ if latest == 0 || latest < cutoff {
78
+ latest = cutoff
79
+ }
80
+ c.JSON(http.StatusOK, gin.H{
81
+ "lines": lines,
82
+ "line-count": total,
83
+ "latest-timestamp": latest,
84
+ })
85
+ }
86
+
87
+ // DeleteLogs removes all rotated log files and truncates the active log.
88
+ func (h *Handler) DeleteLogs(c *gin.Context) {
89
+ if h == nil {
90
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
91
+ return
92
+ }
93
+ if h.cfg == nil {
94
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
95
+ return
96
+ }
97
+ if !h.cfg.LoggingToFile {
98
+ c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
99
+ return
100
+ }
101
+
102
+ dir := h.logDirectory()
103
+ if strings.TrimSpace(dir) == "" {
104
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
105
+ return
106
+ }
107
+
108
+ entries, err := os.ReadDir(dir)
109
+ if err != nil {
110
+ if os.IsNotExist(err) {
111
+ c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
112
+ return
113
+ }
114
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
115
+ return
116
+ }
117
+
118
+ removed := 0
119
+ for _, entry := range entries {
120
+ if entry.IsDir() {
121
+ continue
122
+ }
123
+ name := entry.Name()
124
+ fullPath := filepath.Join(dir, name)
125
+ if name == defaultLogFileName {
126
+ if errTrunc := os.Truncate(fullPath, 0); errTrunc != nil && !os.IsNotExist(errTrunc) {
127
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to truncate log file: %v", errTrunc)})
128
+ return
129
+ }
130
+ continue
131
+ }
132
+ if isRotatedLogFile(name) {
133
+ if errRemove := os.Remove(fullPath); errRemove != nil && !os.IsNotExist(errRemove) {
134
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to remove %s: %v", name, errRemove)})
135
+ return
136
+ }
137
+ removed++
138
+ }
139
+ }
140
+
141
+ c.JSON(http.StatusOK, gin.H{
142
+ "success": true,
143
+ "message": "Logs cleared successfully",
144
+ "removed": removed,
145
+ })
146
+ }
147
+
148
+ // GetRequestErrorLogs lists error request log files when RequestLog is disabled.
149
+ // It returns an empty list when RequestLog is enabled.
150
+ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
151
+ if h == nil {
152
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
153
+ return
154
+ }
155
+ if h.cfg == nil {
156
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
157
+ return
158
+ }
159
+ if h.cfg.RequestLog {
160
+ c.JSON(http.StatusOK, gin.H{"files": []any{}})
161
+ return
162
+ }
163
+
164
+ dir := h.logDirectory()
165
+ if strings.TrimSpace(dir) == "" {
166
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
167
+ return
168
+ }
169
+
170
+ entries, err := os.ReadDir(dir)
171
+ if err != nil {
172
+ if os.IsNotExist(err) {
173
+ c.JSON(http.StatusOK, gin.H{"files": []any{}})
174
+ return
175
+ }
176
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
177
+ return
178
+ }
179
+
180
+ type errorLog struct {
181
+ Name string `json:"name"`
182
+ Size int64 `json:"size"`
183
+ Modified int64 `json:"modified"`
184
+ }
185
+
186
+ files := make([]errorLog, 0, len(entries))
187
+ for _, entry := range entries {
188
+ if entry.IsDir() {
189
+ continue
190
+ }
191
+ name := entry.Name()
192
+ if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
193
+ continue
194
+ }
195
+ info, errInfo := entry.Info()
196
+ if errInfo != nil {
197
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
198
+ return
199
+ }
200
+ files = append(files, errorLog{
201
+ Name: name,
202
+ Size: info.Size(),
203
+ Modified: info.ModTime().Unix(),
204
+ })
205
+ }
206
+
207
+ sort.Slice(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
208
+
209
+ c.JSON(http.StatusOK, gin.H{"files": files})
210
+ }
211
+
212
+ // GetRequestLogByID finds and downloads a request log file by its request ID.
213
+ // The ID is matched against the suffix of log file names (format: *-{requestID}.log).
214
+ func (h *Handler) GetRequestLogByID(c *gin.Context) {
215
+ if h == nil {
216
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
217
+ return
218
+ }
219
+ if h.cfg == nil {
220
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
221
+ return
222
+ }
223
+
224
+ dir := h.logDirectory()
225
+ if strings.TrimSpace(dir) == "" {
226
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
227
+ return
228
+ }
229
+
230
+ requestID := strings.TrimSpace(c.Param("id"))
231
+ if requestID == "" {
232
+ requestID = strings.TrimSpace(c.Query("id"))
233
+ }
234
+ if requestID == "" {
235
+ c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
236
+ return
237
+ }
238
+ if strings.ContainsAny(requestID, "/\\") {
239
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
240
+ return
241
+ }
242
+
243
+ entries, err := os.ReadDir(dir)
244
+ if err != nil {
245
+ if os.IsNotExist(err) {
246
+ c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
247
+ return
248
+ }
249
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
250
+ return
251
+ }
252
+
253
+ suffix := "-" + requestID + ".log"
254
+ var matchedFile string
255
+ for _, entry := range entries {
256
+ if entry.IsDir() {
257
+ continue
258
+ }
259
+ name := entry.Name()
260
+ if strings.HasSuffix(name, suffix) {
261
+ matchedFile = name
262
+ break
263
+ }
264
+ }
265
+
266
+ if matchedFile == "" {
267
+ c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
268
+ return
269
+ }
270
+
271
+ dirAbs, errAbs := filepath.Abs(dir)
272
+ if errAbs != nil {
273
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
274
+ return
275
+ }
276
+ fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
277
+ prefix := dirAbs + string(os.PathSeparator)
278
+ if !strings.HasPrefix(fullPath, prefix) {
279
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
280
+ return
281
+ }
282
+
283
+ info, errStat := os.Stat(fullPath)
284
+ if errStat != nil {
285
+ if os.IsNotExist(errStat) {
286
+ c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
287
+ return
288
+ }
289
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
290
+ return
291
+ }
292
+ if info.IsDir() {
293
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
294
+ return
295
+ }
296
+
297
+ c.FileAttachment(fullPath, matchedFile)
298
+ }
299
+
300
+ // DownloadRequestErrorLog downloads a specific error request log file by name.
301
+ func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
302
+ if h == nil {
303
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
304
+ return
305
+ }
306
+ if h.cfg == nil {
307
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
308
+ return
309
+ }
310
+
311
+ dir := h.logDirectory()
312
+ if strings.TrimSpace(dir) == "" {
313
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
314
+ return
315
+ }
316
+
317
+ name := strings.TrimSpace(c.Param("name"))
318
+ if name == "" || strings.Contains(name, "/") || strings.Contains(name, "\\") {
319
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
320
+ return
321
+ }
322
+ if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
323
+ c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
324
+ return
325
+ }
326
+
327
+ dirAbs, errAbs := filepath.Abs(dir)
328
+ if errAbs != nil {
329
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
330
+ return
331
+ }
332
+ fullPath := filepath.Clean(filepath.Join(dirAbs, name))
333
+ prefix := dirAbs + string(os.PathSeparator)
334
+ if !strings.HasPrefix(fullPath, prefix) {
335
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
336
+ return
337
+ }
338
+
339
+ info, errStat := os.Stat(fullPath)
340
+ if errStat != nil {
341
+ if os.IsNotExist(errStat) {
342
+ c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
343
+ return
344
+ }
345
+ c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
346
+ return
347
+ }
348
+ if info.IsDir() {
349
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
350
+ return
351
+ }
352
+
353
+ c.FileAttachment(fullPath, name)
354
+ }
355
+
356
+ func (h *Handler) logDirectory() string {
357
+ if h == nil {
358
+ return ""
359
+ }
360
+ if h.logDir != "" {
361
+ return h.logDir
362
+ }
363
+ if base := util.WritablePath(); base != "" {
364
+ return filepath.Join(base, "logs")
365
+ }
366
+ if h.configFilePath != "" {
367
+ dir := filepath.Dir(h.configFilePath)
368
+ if dir != "" && dir != "." {
369
+ return filepath.Join(dir, "logs")
370
+ }
371
+ }
372
+ return "logs"
373
+ }
374
+
375
+ func (h *Handler) collectLogFiles(dir string) ([]string, error) {
376
+ entries, err := os.ReadDir(dir)
377
+ if err != nil {
378
+ return nil, err
379
+ }
380
+ type candidate struct {
381
+ path string
382
+ order int64
383
+ }
384
+ cands := make([]candidate, 0, len(entries))
385
+ for _, entry := range entries {
386
+ if entry.IsDir() {
387
+ continue
388
+ }
389
+ name := entry.Name()
390
+ if name == defaultLogFileName {
391
+ cands = append(cands, candidate{path: filepath.Join(dir, name), order: 0})
392
+ continue
393
+ }
394
+ if order, ok := rotationOrder(name); ok {
395
+ cands = append(cands, candidate{path: filepath.Join(dir, name), order: order})
396
+ }
397
+ }
398
+ if len(cands) == 0 {
399
+ return []string{}, nil
400
+ }
401
+ sort.Slice(cands, func(i, j int) bool { return cands[i].order < cands[j].order })
402
+ paths := make([]string, 0, len(cands))
403
+ for i := len(cands) - 1; i >= 0; i-- {
404
+ paths = append(paths, cands[i].path)
405
+ }
406
+ return paths, nil
407
+ }
408
+
409
+ type logAccumulator struct {
410
+ cutoff int64
411
+ limit int
412
+ lines []string
413
+ total int
414
+ latest int64
415
+ include bool
416
+ }
417
+
418
+ func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
419
+ capacity := 256
420
+ if limit > 0 && limit < capacity {
421
+ capacity = limit
422
+ }
423
+ return &logAccumulator{
424
+ cutoff: cutoff,
425
+ limit: limit,
426
+ lines: make([]string, 0, capacity),
427
+ }
428
+ }
429
+
430
+ func (acc *logAccumulator) consumeFile(path string) error {
431
+ file, err := os.Open(path)
432
+ if err != nil {
433
+ if os.IsNotExist(err) {
434
+ return nil
435
+ }
436
+ return err
437
+ }
438
+ defer func() {
439
+ _ = file.Close()
440
+ }()
441
+
442
+ scanner := bufio.NewScanner(file)
443
+ buf := make([]byte, 0, logScannerInitialBuffer)
444
+ scanner.Buffer(buf, logScannerMaxBuffer)
445
+ for scanner.Scan() {
446
+ acc.addLine(scanner.Text())
447
+ }
448
+ if errScan := scanner.Err(); errScan != nil {
449
+ return errScan
450
+ }
451
+ return nil
452
+ }
453
+
454
+ func (acc *logAccumulator) addLine(raw string) {
455
+ line := strings.TrimRight(raw, "\r")
456
+ acc.total++
457
+ ts := parseTimestamp(line)
458
+ if ts > acc.latest {
459
+ acc.latest = ts
460
+ }
461
+ if ts > 0 {
462
+ acc.include = acc.cutoff == 0 || ts > acc.cutoff
463
+ if acc.cutoff == 0 || acc.include {
464
+ acc.append(line)
465
+ }
466
+ return
467
+ }
468
+ if acc.cutoff == 0 || acc.include {
469
+ acc.append(line)
470
+ }
471
+ }
472
+
473
+ func (acc *logAccumulator) append(line string) {
474
+ acc.lines = append(acc.lines, line)
475
+ if acc.limit > 0 && len(acc.lines) > acc.limit {
476
+ acc.lines = acc.lines[len(acc.lines)-acc.limit:]
477
+ }
478
+ }
479
+
480
+ func (acc *logAccumulator) result() ([]string, int, int64) {
481
+ if acc.lines == nil {
482
+ acc.lines = []string{}
483
+ }
484
+ return acc.lines, acc.total, acc.latest
485
+ }
486
+
487
+ func parseCutoff(raw string) int64 {
488
+ value := strings.TrimSpace(raw)
489
+ if value == "" {
490
+ return 0
491
+ }
492
+ ts, err := strconv.ParseInt(value, 10, 64)
493
+ if err != nil || ts <= 0 {
494
+ return 0
495
+ }
496
+ return ts
497
+ }
498
+
499
+ func parseLimit(raw string) (int, error) {
500
+ value := strings.TrimSpace(raw)
501
+ if value == "" {
502
+ return 0, nil
503
+ }
504
+ limit, err := strconv.Atoi(value)
505
+ if err != nil {
506
+ return 0, fmt.Errorf("must be a positive integer")
507
+ }
508
+ if limit <= 0 {
509
+ return 0, fmt.Errorf("must be greater than zero")
510
+ }
511
+ return limit, nil
512
+ }
513
+
514
+ func parseTimestamp(line string) int64 {
515
+ if strings.HasPrefix(line, "[") {
516
+ line = line[1:]
517
+ }
518
+ if len(line) < 19 {
519
+ return 0
520
+ }
521
+ candidate := line[:19]
522
+ t, err := time.ParseInLocation("2006-01-02 15:04:05", candidate, time.Local)
523
+ if err != nil {
524
+ return 0
525
+ }
526
+ return t.Unix()
527
+ }
528
+
529
+ func isRotatedLogFile(name string) bool {
530
+ if _, ok := rotationOrder(name); ok {
531
+ return true
532
+ }
533
+ return false
534
+ }
535
+
536
+ func rotationOrder(name string) (int64, bool) {
537
+ if order, ok := numericRotationOrder(name); ok {
538
+ return order, true
539
+ }
540
+ if order, ok := timestampRotationOrder(name); ok {
541
+ return order, true
542
+ }
543
+ return 0, false
544
+ }
545
+
546
+ func numericRotationOrder(name string) (int64, bool) {
547
+ if !strings.HasPrefix(name, defaultLogFileName+".") {
548
+ return 0, false
549
+ }
550
+ suffix := strings.TrimPrefix(name, defaultLogFileName+".")
551
+ if suffix == "" {
552
+ return 0, false
553
+ }
554
+ n, err := strconv.Atoi(suffix)
555
+ if err != nil {
556
+ return 0, false
557
+ }
558
+ return int64(n), true
559
+ }
560
+
561
+ func timestampRotationOrder(name string) (int64, bool) {
562
+ ext := filepath.Ext(defaultLogFileName)
563
+ base := strings.TrimSuffix(defaultLogFileName, ext)
564
+ if base == "" {
565
+ return 0, false
566
+ }
567
+ prefix := base + "-"
568
+ if !strings.HasPrefix(name, prefix) {
569
+ return 0, false
570
+ }
571
+ clean := strings.TrimPrefix(name, prefix)
572
+ if strings.HasSuffix(clean, ".gz") {
573
+ clean = strings.TrimSuffix(clean, ".gz")
574
+ }
575
+ if ext != "" {
576
+ if !strings.HasSuffix(clean, ext) {
577
+ return 0, false
578
+ }
579
+ clean = strings.TrimSuffix(clean, ext)
580
+ }
581
+ if clean == "" {
582
+ return 0, false
583
+ }
584
+ if idx := strings.IndexByte(clean, '.'); idx != -1 {
585
+ clean = clean[:idx]
586
+ }
587
+ parsed, err := time.ParseInLocation("2006-01-02T15-04-05", clean, time.Local)
588
+ if err != nil {
589
+ return 0, false
590
+ }
591
+ return math.MaxInt64 - parsed.Unix(), true
592
+ }
internal/api/handlers/management/oauth_callback.go ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "errors"
5
+ "net/http"
6
+ "net/url"
7
+ "strings"
8
+
9
+ "github.com/gin-gonic/gin"
10
+ )
11
+
12
+ type oauthCallbackRequest struct {
13
+ Provider string `json:"provider"`
14
+ RedirectURL string `json:"redirect_url"`
15
+ Code string `json:"code"`
16
+ State string `json:"state"`
17
+ Error string `json:"error"`
18
+ }
19
+
20
+ func (h *Handler) PostOAuthCallback(c *gin.Context) {
21
+ if h == nil || h.cfg == nil {
22
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
23
+ return
24
+ }
25
+
26
+ var req oauthCallbackRequest
27
+ if err := c.ShouldBindJSON(&req); err != nil {
28
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
29
+ return
30
+ }
31
+
32
+ canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
33
+ if err != nil {
34
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
35
+ return
36
+ }
37
+
38
+ state := strings.TrimSpace(req.State)
39
+ code := strings.TrimSpace(req.Code)
40
+ errMsg := strings.TrimSpace(req.Error)
41
+
42
+ if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
43
+ u, errParse := url.Parse(rawRedirect)
44
+ if errParse != nil {
45
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
46
+ return
47
+ }
48
+ q := u.Query()
49
+ if state == "" {
50
+ state = strings.TrimSpace(q.Get("state"))
51
+ }
52
+ if code == "" {
53
+ code = strings.TrimSpace(q.Get("code"))
54
+ }
55
+ if errMsg == "" {
56
+ errMsg = strings.TrimSpace(q.Get("error"))
57
+ if errMsg == "" {
58
+ errMsg = strings.TrimSpace(q.Get("error_description"))
59
+ }
60
+ }
61
+ }
62
+
63
+ if state == "" {
64
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
65
+ return
66
+ }
67
+ if err := ValidateOAuthState(state); err != nil {
68
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
69
+ return
70
+ }
71
+ if code == "" && errMsg == "" {
72
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
73
+ return
74
+ }
75
+
76
+ sessionProvider, sessionStatus, ok := GetOAuthSession(state)
77
+ if !ok {
78
+ c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
79
+ return
80
+ }
81
+ if sessionStatus != "" {
82
+ c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
83
+ return
84
+ }
85
+ if !strings.EqualFold(sessionProvider, canonicalProvider) {
86
+ c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
87
+ return
88
+ }
89
+
90
+ if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
91
+ if errors.Is(errWrite, errOAuthSessionNotPending) {
92
+ c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
93
+ return
94
+ }
95
+ c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
96
+ return
97
+ }
98
+
99
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
100
+ }
internal/api/handlers/management/oauth_sessions.go ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "encoding/json"
5
+ "errors"
6
+ "fmt"
7
+ "os"
8
+ "path/filepath"
9
+ "strings"
10
+ "sync"
11
+ "time"
12
+ )
13
+
14
+ const (
15
+ oauthSessionTTL = 10 * time.Minute
16
+ maxOAuthStateLength = 128
17
+ )
18
+
19
+ var (
20
+ errInvalidOAuthState = errors.New("invalid oauth state")
21
+ errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
22
+ errOAuthSessionNotPending = errors.New("oauth session is not pending")
23
+ )
24
+
25
+ type oauthSession struct {
26
+ Provider string
27
+ Status string
28
+ CreatedAt time.Time
29
+ ExpiresAt time.Time
30
+ }
31
+
32
+ type oauthSessionStore struct {
33
+ mu sync.RWMutex
34
+ ttl time.Duration
35
+ sessions map[string]oauthSession
36
+ }
37
+
38
+ func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
39
+ if ttl <= 0 {
40
+ ttl = oauthSessionTTL
41
+ }
42
+ return &oauthSessionStore{
43
+ ttl: ttl,
44
+ sessions: make(map[string]oauthSession),
45
+ }
46
+ }
47
+
48
+ func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
49
+ for state, session := range s.sessions {
50
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
51
+ delete(s.sessions, state)
52
+ }
53
+ }
54
+ }
55
+
56
+ func (s *oauthSessionStore) Register(state, provider string) {
57
+ state = strings.TrimSpace(state)
58
+ provider = strings.ToLower(strings.TrimSpace(provider))
59
+ if state == "" || provider == "" {
60
+ return
61
+ }
62
+ now := time.Now()
63
+
64
+ s.mu.Lock()
65
+ defer s.mu.Unlock()
66
+
67
+ s.purgeExpiredLocked(now)
68
+ s.sessions[state] = oauthSession{
69
+ Provider: provider,
70
+ Status: "",
71
+ CreatedAt: now,
72
+ ExpiresAt: now.Add(s.ttl),
73
+ }
74
+ }
75
+
76
+ func (s *oauthSessionStore) SetError(state, message string) {
77
+ state = strings.TrimSpace(state)
78
+ message = strings.TrimSpace(message)
79
+ if state == "" {
80
+ return
81
+ }
82
+ if message == "" {
83
+ message = "Authentication failed"
84
+ }
85
+ now := time.Now()
86
+
87
+ s.mu.Lock()
88
+ defer s.mu.Unlock()
89
+
90
+ s.purgeExpiredLocked(now)
91
+ session, ok := s.sessions[state]
92
+ if !ok {
93
+ return
94
+ }
95
+ session.Status = message
96
+ session.ExpiresAt = now.Add(s.ttl)
97
+ s.sessions[state] = session
98
+ }
99
+
100
+ func (s *oauthSessionStore) Complete(state string) {
101
+ state = strings.TrimSpace(state)
102
+ if state == "" {
103
+ return
104
+ }
105
+ now := time.Now()
106
+
107
+ s.mu.Lock()
108
+ defer s.mu.Unlock()
109
+
110
+ s.purgeExpiredLocked(now)
111
+ delete(s.sessions, state)
112
+ }
113
+
114
+ func (s *oauthSessionStore) CompleteProvider(provider string) int {
115
+ provider = strings.ToLower(strings.TrimSpace(provider))
116
+ if provider == "" {
117
+ return 0
118
+ }
119
+ now := time.Now()
120
+
121
+ s.mu.Lock()
122
+ defer s.mu.Unlock()
123
+
124
+ s.purgeExpiredLocked(now)
125
+ removed := 0
126
+ for state, session := range s.sessions {
127
+ if strings.EqualFold(session.Provider, provider) {
128
+ delete(s.sessions, state)
129
+ removed++
130
+ }
131
+ }
132
+ return removed
133
+ }
134
+
135
+ func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
136
+ state = strings.TrimSpace(state)
137
+ now := time.Now()
138
+
139
+ s.mu.Lock()
140
+ defer s.mu.Unlock()
141
+
142
+ s.purgeExpiredLocked(now)
143
+ session, ok := s.sessions[state]
144
+ return session, ok
145
+ }
146
+
147
+ func (s *oauthSessionStore) IsPending(state, provider string) bool {
148
+ state = strings.TrimSpace(state)
149
+ provider = strings.ToLower(strings.TrimSpace(provider))
150
+ now := time.Now()
151
+
152
+ s.mu.Lock()
153
+ defer s.mu.Unlock()
154
+
155
+ s.purgeExpiredLocked(now)
156
+ session, ok := s.sessions[state]
157
+ if !ok {
158
+ return false
159
+ }
160
+ if session.Status != "" {
161
+ if !strings.EqualFold(session.Provider, "kiro") {
162
+ return false
163
+ }
164
+ if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
165
+ return false
166
+ }
167
+ }
168
+ if provider == "" {
169
+ return true
170
+ }
171
+ return strings.EqualFold(session.Provider, provider)
172
+ }
173
+
174
+ var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
175
+
176
+ func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
177
+
178
+ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
179
+
180
+ func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
181
+
182
+ func CompleteOAuthSessionsByProvider(provider string) int {
183
+ return oauthSessions.CompleteProvider(provider)
184
+ }
185
+
186
+ func GetOAuthSession(state string) (provider string, status string, ok bool) {
187
+ session, ok := oauthSessions.Get(state)
188
+ if !ok {
189
+ return "", "", false
190
+ }
191
+ return session.Provider, session.Status, true
192
+ }
193
+
194
+ func IsOAuthSessionPending(state, provider string) bool {
195
+ return oauthSessions.IsPending(state, provider)
196
+ }
197
+
198
+ func ValidateOAuthState(state string) error {
199
+ trimmed := strings.TrimSpace(state)
200
+ if trimmed == "" {
201
+ return fmt.Errorf("%w: empty", errInvalidOAuthState)
202
+ }
203
+ if len(trimmed) > maxOAuthStateLength {
204
+ return fmt.Errorf("%w: too long", errInvalidOAuthState)
205
+ }
206
+ if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
207
+ return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
208
+ }
209
+ if strings.Contains(trimmed, "..") {
210
+ return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
211
+ }
212
+ for _, r := range trimmed {
213
+ switch {
214
+ case r >= 'a' && r <= 'z':
215
+ case r >= 'A' && r <= 'Z':
216
+ case r >= '0' && r <= '9':
217
+ case r == '-' || r == '_' || r == '.':
218
+ default:
219
+ return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
220
+ }
221
+ }
222
+ return nil
223
+ }
224
+
225
+ func NormalizeOAuthProvider(provider string) (string, error) {
226
+ switch strings.ToLower(strings.TrimSpace(provider)) {
227
+ case "anthropic", "claude":
228
+ return "anthropic", nil
229
+ case "codex", "openai":
230
+ return "codex", nil
231
+ case "gemini", "google":
232
+ return "gemini", nil
233
+ case "iflow", "i-flow":
234
+ return "iflow", nil
235
+ case "antigravity", "anti-gravity":
236
+ return "antigravity", nil
237
+ case "qwen":
238
+ return "qwen", nil
239
+ case "kiro":
240
+ return "kiro", nil
241
+ default:
242
+ return "", errUnsupportedOAuthFlow
243
+ }
244
+ }
245
+
246
+ type oauthCallbackFilePayload struct {
247
+ Code string `json:"code"`
248
+ State string `json:"state"`
249
+ Error string `json:"error"`
250
+ }
251
+
252
+ func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
253
+ if strings.TrimSpace(authDir) == "" {
254
+ return "", fmt.Errorf("auth dir is empty")
255
+ }
256
+ canonicalProvider, err := NormalizeOAuthProvider(provider)
257
+ if err != nil {
258
+ return "", err
259
+ }
260
+ if err := ValidateOAuthState(state); err != nil {
261
+ return "", err
262
+ }
263
+
264
+ fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
265
+ filePath := filepath.Join(authDir, fileName)
266
+ payload := oauthCallbackFilePayload{
267
+ Code: strings.TrimSpace(code),
268
+ State: strings.TrimSpace(state),
269
+ Error: strings.TrimSpace(errorMessage),
270
+ }
271
+ data, err := json.Marshal(payload)
272
+ if err != nil {
273
+ return "", fmt.Errorf("marshal oauth callback payload: %w", err)
274
+ }
275
+ if err := os.WriteFile(filePath, data, 0o600); err != nil {
276
+ return "", fmt.Errorf("write oauth callback file: %w", err)
277
+ }
278
+ return filePath, nil
279
+ }
280
+
281
+ func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
282
+ canonicalProvider, err := NormalizeOAuthProvider(provider)
283
+ if err != nil {
284
+ return "", err
285
+ }
286
+ if !IsOAuthSessionPending(state, canonicalProvider) {
287
+ return "", errOAuthSessionNotPending
288
+ }
289
+ return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
290
+ }
internal/api/handlers/management/quota.go ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import "github.com/gin-gonic/gin"
4
+
5
+ // Quota exceeded toggles
6
+ func (h *Handler) GetSwitchProject(c *gin.Context) {
7
+ c.JSON(200, gin.H{"switch-project": h.cfg.QuotaExceeded.SwitchProject})
8
+ }
9
+ func (h *Handler) PutSwitchProject(c *gin.Context) {
10
+ h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchProject = v })
11
+ }
12
+
13
+ func (h *Handler) GetSwitchPreviewModel(c *gin.Context) {
14
+ c.JSON(200, gin.H{"switch-preview-model": h.cfg.QuotaExceeded.SwitchPreviewModel})
15
+ }
16
+ func (h *Handler) PutSwitchPreviewModel(c *gin.Context) {
17
+ h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v })
18
+ }
internal/api/handlers/management/usage.go ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "encoding/json"
5
+ "net/http"
6
+ "time"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
10
+ )
11
+
12
+ type usageExportPayload struct {
13
+ Version int `json:"version"`
14
+ ExportedAt time.Time `json:"exported_at"`
15
+ Usage usage.StatisticsSnapshot `json:"usage"`
16
+ }
17
+
18
+ type usageImportPayload struct {
19
+ Version int `json:"version"`
20
+ Usage usage.StatisticsSnapshot `json:"usage"`
21
+ }
22
+
23
+ // GetUsageStatistics returns the in-memory request statistics snapshot.
24
+ func (h *Handler) GetUsageStatistics(c *gin.Context) {
25
+ var snapshot usage.StatisticsSnapshot
26
+ if h != nil && h.usageStats != nil {
27
+ snapshot = h.usageStats.Snapshot()
28
+ }
29
+ c.JSON(http.StatusOK, gin.H{
30
+ "usage": snapshot,
31
+ "failed_requests": snapshot.FailureCount,
32
+ })
33
+ }
34
+
35
+ // ExportUsageStatistics returns a complete usage snapshot for backup/migration.
36
+ func (h *Handler) ExportUsageStatistics(c *gin.Context) {
37
+ var snapshot usage.StatisticsSnapshot
38
+ if h != nil && h.usageStats != nil {
39
+ snapshot = h.usageStats.Snapshot()
40
+ }
41
+ c.JSON(http.StatusOK, usageExportPayload{
42
+ Version: 1,
43
+ ExportedAt: time.Now().UTC(),
44
+ Usage: snapshot,
45
+ })
46
+ }
47
+
48
+ // ImportUsageStatistics merges a previously exported usage snapshot into memory.
49
+ func (h *Handler) ImportUsageStatistics(c *gin.Context) {
50
+ if h == nil || h.usageStats == nil {
51
+ c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
52
+ return
53
+ }
54
+
55
+ data, err := c.GetRawData()
56
+ if err != nil {
57
+ c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
58
+ return
59
+ }
60
+
61
+ var payload usageImportPayload
62
+ if err := json.Unmarshal(data, &payload); err != nil {
63
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
64
+ return
65
+ }
66
+ if payload.Version != 0 && payload.Version != 1 {
67
+ c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
68
+ return
69
+ }
70
+
71
+ result := h.usageStats.MergeSnapshot(payload.Usage)
72
+ snapshot := h.usageStats.Snapshot()
73
+ c.JSON(http.StatusOK, gin.H{
74
+ "added": result.Added,
75
+ "skipped": result.Skipped,
76
+ "total_requests": snapshot.TotalRequests,
77
+ "failed_requests": snapshot.FailureCount,
78
+ })
79
+ }
internal/api/handlers/management/vertex_import.go ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package management
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "fmt"
7
+ "io"
8
+ "net/http"
9
+ "strings"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
13
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
14
+ )
15
+
16
+ // ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
17
+ func (h *Handler) ImportVertexCredential(c *gin.Context) {
18
+ if h == nil || h.cfg == nil {
19
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
20
+ return
21
+ }
22
+ if h.cfg.AuthDir == "" {
23
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
24
+ return
25
+ }
26
+
27
+ fileHeader, err := c.FormFile("file")
28
+ if err != nil {
29
+ c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
30
+ return
31
+ }
32
+
33
+ file, err := fileHeader.Open()
34
+ if err != nil {
35
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
36
+ return
37
+ }
38
+ defer file.Close()
39
+
40
+ data, err := io.ReadAll(file)
41
+ if err != nil {
42
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
43
+ return
44
+ }
45
+
46
+ var serviceAccount map[string]any
47
+ if err := json.Unmarshal(data, &serviceAccount); err != nil {
48
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
49
+ return
50
+ }
51
+
52
+ normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
53
+ if err != nil {
54
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
55
+ return
56
+ }
57
+ serviceAccount = normalizedSA
58
+
59
+ projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
60
+ if projectID == "" {
61
+ c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
62
+ return
63
+ }
64
+ email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
65
+
66
+ location := strings.TrimSpace(c.PostForm("location"))
67
+ if location == "" {
68
+ location = strings.TrimSpace(c.Query("location"))
69
+ }
70
+ if location == "" {
71
+ location = "us-central1"
72
+ }
73
+
74
+ fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
75
+ label := labelForVertex(projectID, email)
76
+ storage := &vertex.VertexCredentialStorage{
77
+ ServiceAccount: serviceAccount,
78
+ ProjectID: projectID,
79
+ Email: email,
80
+ Location: location,
81
+ Type: "vertex",
82
+ }
83
+ metadata := map[string]any{
84
+ "service_account": serviceAccount,
85
+ "project_id": projectID,
86
+ "email": email,
87
+ "location": location,
88
+ "type": "vertex",
89
+ "label": label,
90
+ }
91
+ record := &coreauth.Auth{
92
+ ID: fileName,
93
+ Provider: "vertex",
94
+ FileName: fileName,
95
+ Storage: storage,
96
+ Label: label,
97
+ Metadata: metadata,
98
+ }
99
+
100
+ ctx := context.Background()
101
+ if reqCtx := c.Request.Context(); reqCtx != nil {
102
+ ctx = reqCtx
103
+ }
104
+ savedPath, err := h.saveTokenRecord(ctx, record)
105
+ if err != nil {
106
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
107
+ return
108
+ }
109
+
110
+ c.JSON(http.StatusOK, gin.H{
111
+ "status": "ok",
112
+ "auth-file": savedPath,
113
+ "project_id": projectID,
114
+ "email": email,
115
+ "location": location,
116
+ })
117
+ }
118
+
119
+ func valueAsString(v any) string {
120
+ if v == nil {
121
+ return ""
122
+ }
123
+ switch t := v.(type) {
124
+ case string:
125
+ return t
126
+ default:
127
+ return fmt.Sprint(t)
128
+ }
129
+ }
130
+
131
+ func sanitizeVertexFilePart(s string) string {
132
+ out := strings.TrimSpace(s)
133
+ replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
134
+ for i := 0; i < len(replacers); i += 2 {
135
+ out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
136
+ }
137
+ if out == "" {
138
+ return "vertex"
139
+ }
140
+ return out
141
+ }
142
+
143
+ func labelForVertex(projectID, email string) string {
144
+ p := strings.TrimSpace(projectID)
145
+ e := strings.TrimSpace(email)
146
+ if p != "" && e != "" {
147
+ return fmt.Sprintf("%s (%s)", p, e)
148
+ }
149
+ if p != "" {
150
+ return p
151
+ }
152
+ if e != "" {
153
+ return e
154
+ }
155
+ return "vertex"
156
+ }
internal/api/middleware/request_logging.go ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package middleware provides HTTP middleware components for the CLI Proxy API server.
2
+ // This file contains the request logging middleware that captures comprehensive
3
+ // request and response data when enabled through configuration.
4
+ package middleware
5
+
6
+ import (
7
+ "bytes"
8
+ "io"
9
+ "net/http"
10
+ "strings"
11
+
12
+ "github.com/gin-gonic/gin"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
14
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
15
+ )
16
+
17
+ // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
18
+ // It captures detailed information about the request and response, including headers and body,
19
+ // and uses the provided RequestLogger to record this data. When logging is disabled in the
20
+ // logger, it still captures data so that upstream errors can be persisted.
21
+ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
22
+ return func(c *gin.Context) {
23
+ if logger == nil {
24
+ c.Next()
25
+ return
26
+ }
27
+
28
+ if c.Request.Method == http.MethodGet {
29
+ c.Next()
30
+ return
31
+ }
32
+
33
+ path := c.Request.URL.Path
34
+ if !shouldLogRequest(path) {
35
+ c.Next()
36
+ return
37
+ }
38
+
39
+ // Capture request information
40
+ requestInfo, err := captureRequestInfo(c)
41
+ if err != nil {
42
+ // Log error but continue processing
43
+ // In a real implementation, you might want to use a proper logger here
44
+ c.Next()
45
+ return
46
+ }
47
+
48
+ // Create response writer wrapper
49
+ wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
50
+ if !logger.IsEnabled() {
51
+ wrapper.logOnErrorOnly = true
52
+ }
53
+ c.Writer = wrapper
54
+
55
+ // Process the request
56
+ c.Next()
57
+
58
+ // Finalize logging after request processing
59
+ if err = wrapper.Finalize(c); err != nil {
60
+ // Log error but don't interrupt the response
61
+ // In a real implementation, you might want to use a proper logger here
62
+ }
63
+ }
64
+ }
65
+
66
+ // captureRequestInfo extracts relevant information from the incoming HTTP request.
67
+ // It captures the URL, method, headers, and body. The request body is read and then
68
+ // restored so that it can be processed by subsequent handlers.
69
+ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
70
+ // Capture URL with sensitive query parameters masked
71
+ maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
72
+ url := c.Request.URL.Path
73
+ if maskedQuery != "" {
74
+ url += "?" + maskedQuery
75
+ }
76
+
77
+ // Capture method
78
+ method := c.Request.Method
79
+
80
+ // Capture headers
81
+ headers := make(map[string][]string)
82
+ for key, values := range c.Request.Header {
83
+ headers[key] = values
84
+ }
85
+
86
+ // Capture request body
87
+ var body []byte
88
+ if c.Request.Body != nil {
89
+ // Read the body
90
+ bodyBytes, err := io.ReadAll(c.Request.Body)
91
+ if err != nil {
92
+ return nil, err
93
+ }
94
+
95
+ // Restore the body for the actual request processing
96
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
97
+ body = bodyBytes
98
+ }
99
+
100
+ return &RequestInfo{
101
+ URL: url,
102
+ Method: method,
103
+ Headers: headers,
104
+ Body: body,
105
+ RequestID: logging.GetGinRequestID(c),
106
+ }, nil
107
+ }
108
+
109
+ // shouldLogRequest determines whether the request should be logged.
110
+ // It skips management endpoints to avoid leaking secrets but allows
111
+ // all other routes, including module-provided ones, to honor request-log.
112
+ func shouldLogRequest(path string) bool {
113
+ if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
114
+ return false
115
+ }
116
+
117
+ if strings.HasPrefix(path, "/api") {
118
+ return strings.HasPrefix(path, "/api/provider")
119
+ }
120
+
121
+ return true
122
+ }
internal/api/middleware/response_writer.go ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package middleware provides Gin HTTP middleware for the CLI Proxy API server.
2
+ // It includes a sophisticated response writer wrapper designed to capture and log request and response data,
3
+ // including support for streaming responses, without impacting latency.
4
+ package middleware
5
+
6
+ import (
7
+ "bytes"
8
+ "net/http"
9
+ "strings"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
14
+ )
15
+
16
+ // RequestInfo holds essential details of an incoming HTTP request for logging purposes.
17
+ type RequestInfo struct {
18
+ URL string // URL is the request URL.
19
+ Method string // Method is the HTTP method (e.g., GET, POST).
20
+ Headers map[string][]string // Headers contains the request headers.
21
+ Body []byte // Body is the raw request body.
22
+ RequestID string // RequestID is the unique identifier for the request.
23
+ }
24
+
25
+ // ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
26
+ // It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
27
+ type ResponseWriterWrapper struct {
28
+ gin.ResponseWriter
29
+ body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
30
+ isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
31
+ streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
32
+ chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
33
+ streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
34
+ logger logging.RequestLogger // logger is the instance of the request logger service.
35
+ requestInfo *RequestInfo // requestInfo holds the details of the original request.
36
+ statusCode int // statusCode stores the HTTP status code of the response.
37
+ headers map[string][]string // headers stores the response headers.
38
+ logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
39
+ }
40
+
41
+ // NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
42
+ // It takes the original gin.ResponseWriter, a logger instance, and request information.
43
+ //
44
+ // Parameters:
45
+ // - w: The original gin.ResponseWriter to wrap.
46
+ // - logger: The logging service to use for recording requests.
47
+ // - requestInfo: The pre-captured information about the incoming request.
48
+ //
49
+ // Returns:
50
+ // - A pointer to a new ResponseWriterWrapper.
51
+ func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
52
+ return &ResponseWriterWrapper{
53
+ ResponseWriter: w,
54
+ body: &bytes.Buffer{},
55
+ logger: logger,
56
+ requestInfo: requestInfo,
57
+ headers: make(map[string][]string),
58
+ }
59
+ }
60
+
61
+ // Write wraps the underlying ResponseWriter's Write method to capture response data.
62
+ // For non-streaming responses, it writes to an internal buffer. For streaming responses,
63
+ // it sends data chunks to a non-blocking channel for asynchronous logging.
64
+ // CRITICAL: This method prioritizes writing to the client to ensure zero latency,
65
+ // handling logging operations subsequently.
66
+ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
67
+ // Ensure headers are captured before first write
68
+ // This is critical because Write() may trigger WriteHeader() internally
69
+ w.ensureHeadersCaptured()
70
+
71
+ // CRITICAL: Write to client first (zero latency)
72
+ n, err := w.ResponseWriter.Write(data)
73
+
74
+ // THEN: Handle logging based on response type
75
+ if w.isStreaming && w.chunkChannel != nil {
76
+ // For streaming responses: Send to async logging channel (non-blocking)
77
+ select {
78
+ case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
79
+ default: // Channel full, skip logging to avoid blocking
80
+ }
81
+ return n, err
82
+ }
83
+
84
+ if w.shouldBufferResponseBody() {
85
+ w.body.Write(data)
86
+ }
87
+
88
+ return n, err
89
+ }
90
+
91
+ func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
92
+ if w.logger != nil && w.logger.IsEnabled() {
93
+ return true
94
+ }
95
+ if !w.logOnErrorOnly {
96
+ return false
97
+ }
98
+ status := w.statusCode
99
+ if status == 0 {
100
+ if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
101
+ status = statusWriter.Status()
102
+ } else {
103
+ status = http.StatusOK
104
+ }
105
+ }
106
+ return status >= http.StatusBadRequest
107
+ }
108
+
109
+ // WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
110
+ // Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
111
+ // bypass Write() and would be missing from request logs.
112
+ func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
113
+ w.ensureHeadersCaptured()
114
+
115
+ // CRITICAL: Write to client first (zero latency)
116
+ n, err := w.ResponseWriter.WriteString(data)
117
+
118
+ // THEN: Capture for logging
119
+ if w.isStreaming && w.chunkChannel != nil {
120
+ select {
121
+ case w.chunkChannel <- []byte(data):
122
+ default:
123
+ }
124
+ return n, err
125
+ }
126
+
127
+ if w.shouldBufferResponseBody() {
128
+ w.body.WriteString(data)
129
+ }
130
+ return n, err
131
+ }
132
+
133
+ // WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
134
+ // It captures the status code, detects if the response is streaming based on the Content-Type header,
135
+ // and initializes the appropriate logging mechanism (standard or streaming).
136
+ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
137
+ w.statusCode = statusCode
138
+
139
+ // Capture response headers using the new method
140
+ w.captureCurrentHeaders()
141
+
142
+ // Detect streaming based on Content-Type
143
+ contentType := w.ResponseWriter.Header().Get("Content-Type")
144
+ w.isStreaming = w.detectStreaming(contentType)
145
+
146
+ // If streaming, initialize streaming log writer
147
+ if w.isStreaming && w.logger.IsEnabled() {
148
+ streamWriter, err := w.logger.LogStreamingRequest(
149
+ w.requestInfo.URL,
150
+ w.requestInfo.Method,
151
+ w.requestInfo.Headers,
152
+ w.requestInfo.Body,
153
+ w.requestInfo.RequestID,
154
+ )
155
+ if err == nil {
156
+ w.streamWriter = streamWriter
157
+ w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes
158
+ doneChan := make(chan struct{})
159
+ w.streamDone = doneChan
160
+
161
+ // Start async chunk processor
162
+ go w.processStreamingChunks(doneChan)
163
+
164
+ // Write status immediately
165
+ _ = streamWriter.WriteStatus(statusCode, w.headers)
166
+ }
167
+ }
168
+
169
+ // Call original WriteHeader
170
+ w.ResponseWriter.WriteHeader(statusCode)
171
+ }
172
+
173
+ // ensureHeadersCaptured is a helper function to make sure response headers are captured.
174
+ // It is safe to call this method multiple times; it will always refresh the headers
175
+ // with the latest state from the underlying ResponseWriter.
176
+ func (w *ResponseWriterWrapper) ensureHeadersCaptured() {
177
+ // Always capture the current headers to ensure we have the latest state
178
+ w.captureCurrentHeaders()
179
+ }
180
+
181
+ // captureCurrentHeaders reads all headers from the underlying ResponseWriter and stores them
182
+ // in the wrapper's headers map. It creates copies of the header values to prevent race conditions.
183
+ func (w *ResponseWriterWrapper) captureCurrentHeaders() {
184
+ // Initialize headers map if needed
185
+ if w.headers == nil {
186
+ w.headers = make(map[string][]string)
187
+ }
188
+
189
+ // Capture all current headers from the underlying ResponseWriter
190
+ for key, values := range w.ResponseWriter.Header() {
191
+ // Make a copy of the values slice to avoid reference issues
192
+ headerValues := make([]string, len(values))
193
+ copy(headerValues, values)
194
+ w.headers[key] = headerValues
195
+ }
196
+ }
197
+
198
+ // detectStreaming determines if a response should be treated as a streaming response.
199
+ // It checks for a "text/event-stream" Content-Type or a '"stream": true'
200
+ // field in the original request body.
201
+ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
202
+ // Check Content-Type for Server-Sent Events
203
+ if strings.Contains(contentType, "text/event-stream") {
204
+ return true
205
+ }
206
+
207
+ // If a concrete Content-Type is already set (e.g., application/json for error responses),
208
+ // treat it as non-streaming instead of inferring from the request payload.
209
+ if strings.TrimSpace(contentType) != "" {
210
+ return false
211
+ }
212
+
213
+ // Only fall back to request payload hints when Content-Type is not set yet.
214
+ if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
215
+ bodyStr := string(w.requestInfo.Body)
216
+ return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
217
+ }
218
+
219
+ return false
220
+ }
221
+
222
+ // processStreamingChunks runs in a separate goroutine to process response chunks from the chunkChannel.
223
+ // It asynchronously writes each chunk to the streaming log writer.
224
+ func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
225
+ if done == nil {
226
+ return
227
+ }
228
+
229
+ defer close(done)
230
+
231
+ if w.streamWriter == nil || w.chunkChannel == nil {
232
+ return
233
+ }
234
+
235
+ for chunk := range w.chunkChannel {
236
+ w.streamWriter.WriteChunkAsync(chunk)
237
+ }
238
+ }
239
+
240
+ // Finalize completes the logging process for the request and response.
241
+ // For streaming responses, it closes the chunk channel and the stream writer.
242
+ // For non-streaming responses, it logs the complete request and response details,
243
+ // including any API-specific request/response data stored in the Gin context.
244
+ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
245
+ if w.logger == nil {
246
+ return nil
247
+ }
248
+
249
+ finalStatusCode := w.statusCode
250
+ if finalStatusCode == 0 {
251
+ if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
252
+ finalStatusCode = statusWriter.Status()
253
+ } else {
254
+ finalStatusCode = 200
255
+ }
256
+ }
257
+
258
+ var slicesAPIResponseError []*interfaces.ErrorMessage
259
+ apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
260
+ if isExist {
261
+ if apiErrors, ok := apiResponseError.([]*interfaces.ErrorMessage); ok {
262
+ slicesAPIResponseError = apiErrors
263
+ }
264
+ }
265
+
266
+ hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest
267
+ forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled()
268
+ if !w.logger.IsEnabled() && !forceLog {
269
+ return nil
270
+ }
271
+
272
+ if w.isStreaming && w.streamWriter != nil {
273
+ if w.chunkChannel != nil {
274
+ close(w.chunkChannel)
275
+ w.chunkChannel = nil
276
+ }
277
+
278
+ if w.streamDone != nil {
279
+ <-w.streamDone
280
+ w.streamDone = nil
281
+ }
282
+
283
+ // Write API Request and Response to the streaming log before closing
284
+ apiRequest := w.extractAPIRequest(c)
285
+ if len(apiRequest) > 0 {
286
+ _ = w.streamWriter.WriteAPIRequest(apiRequest)
287
+ }
288
+ apiResponse := w.extractAPIResponse(c)
289
+ if len(apiResponse) > 0 {
290
+ _ = w.streamWriter.WriteAPIResponse(apiResponse)
291
+ }
292
+ if err := w.streamWriter.Close(); err != nil {
293
+ w.streamWriter = nil
294
+ return err
295
+ }
296
+ w.streamWriter = nil
297
+ return nil
298
+ }
299
+
300
+ return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
301
+ }
302
+
303
+ func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
304
+ w.ensureHeadersCaptured()
305
+
306
+ finalHeaders := make(map[string][]string, len(w.headers))
307
+ for key, values := range w.headers {
308
+ headerValues := make([]string, len(values))
309
+ copy(headerValues, values)
310
+ finalHeaders[key] = headerValues
311
+ }
312
+
313
+ return finalHeaders
314
+ }
315
+
316
+ func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
317
+ apiRequest, isExist := c.Get("API_REQUEST")
318
+ if !isExist {
319
+ return nil
320
+ }
321
+ data, ok := apiRequest.([]byte)
322
+ if !ok || len(data) == 0 {
323
+ return nil
324
+ }
325
+ return data
326
+ }
327
+
328
+ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
329
+ apiResponse, isExist := c.Get("API_RESPONSE")
330
+ if !isExist {
331
+ return nil
332
+ }
333
+ data, ok := apiResponse.([]byte)
334
+ if !ok || len(data) == 0 {
335
+ return nil
336
+ }
337
+ return data
338
+ }
339
+
340
+ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
341
+ if w.requestInfo == nil {
342
+ return nil
343
+ }
344
+
345
+ var requestBody []byte
346
+ if len(w.requestInfo.Body) > 0 {
347
+ requestBody = w.requestInfo.Body
348
+ }
349
+
350
+ if loggerWithOptions, ok := w.logger.(interface {
351
+ LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
352
+ }); ok {
353
+ return loggerWithOptions.LogRequestWithOptions(
354
+ w.requestInfo.URL,
355
+ w.requestInfo.Method,
356
+ w.requestInfo.Headers,
357
+ requestBody,
358
+ statusCode,
359
+ headers,
360
+ body,
361
+ apiRequestBody,
362
+ apiResponseBody,
363
+ apiResponseErrors,
364
+ forceLog,
365
+ w.requestInfo.RequestID,
366
+ )
367
+ }
368
+
369
+ return w.logger.LogRequest(
370
+ w.requestInfo.URL,
371
+ w.requestInfo.Method,
372
+ w.requestInfo.Headers,
373
+ requestBody,
374
+ statusCode,
375
+ headers,
376
+ body,
377
+ apiRequestBody,
378
+ apiResponseBody,
379
+ apiResponseErrors,
380
+ w.requestInfo.RequestID,
381
+ )
382
+ }
internal/api/modules/amp/amp.go ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package amp implements the Amp CLI routing module, providing OAuth-based
2
+ // integration with Amp CLI for ChatGPT and Anthropic subscriptions.
3
+ package amp
4
+
5
+ import (
6
+ "fmt"
7
+ "net/http/httputil"
8
+ "strings"
9
+ "sync"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
14
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
15
+ log "github.com/sirupsen/logrus"
16
+ )
17
+
18
+ // Option configures the AmpModule.
19
+ type Option func(*AmpModule)
20
+
21
+ // AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
22
+ // It provides:
23
+ // - Reverse proxy to Amp control plane for OAuth/management
24
+ // - Provider-specific route aliases (/api/provider/{provider}/...)
25
+ // - Automatic gzip decompression for misconfigured upstreams
26
+ // - Model mapping for routing unavailable models to alternatives
27
+ type AmpModule struct {
28
+ secretSource SecretSource
29
+ proxy *httputil.ReverseProxy
30
+ proxyMu sync.RWMutex // protects proxy for hot-reload
31
+ accessManager *sdkaccess.Manager
32
+ authMiddleware_ gin.HandlerFunc
33
+ modelMapper *DefaultModelMapper
34
+ enabled bool
35
+ registerOnce sync.Once
36
+
37
+ // restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
38
+ restrictToLocalhost bool
39
+ restrictMu sync.RWMutex
40
+
41
+ // configMu protects lastConfig for partial reload comparison
42
+ configMu sync.RWMutex
43
+ lastConfig *config.AmpCode
44
+ }
45
+
46
+ // New creates a new Amp routing module with the given options.
47
+ // This is the preferred constructor using the Option pattern.
48
+ //
49
+ // Example:
50
+ //
51
+ // ampModule := amp.New(
52
+ // amp.WithAccessManager(accessManager),
53
+ // amp.WithAuthMiddleware(authMiddleware),
54
+ // amp.WithSecretSource(customSecret),
55
+ // )
56
+ func New(opts ...Option) *AmpModule {
57
+ m := &AmpModule{
58
+ secretSource: nil, // Will be created on demand if not provided
59
+ }
60
+ for _, opt := range opts {
61
+ opt(m)
62
+ }
63
+ return m
64
+ }
65
+
66
+ // NewLegacy creates a new Amp routing module using the legacy constructor signature.
67
+ // This is provided for backwards compatibility.
68
+ //
69
+ // DEPRECATED: Use New with options instead.
70
+ func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
71
+ return New(
72
+ WithAccessManager(accessManager),
73
+ WithAuthMiddleware(authMiddleware),
74
+ )
75
+ }
76
+
77
+ // WithSecretSource sets a custom secret source for the module.
78
+ func WithSecretSource(source SecretSource) Option {
79
+ return func(m *AmpModule) {
80
+ m.secretSource = source
81
+ }
82
+ }
83
+
84
+ // WithAccessManager sets the access manager for the module.
85
+ func WithAccessManager(am *sdkaccess.Manager) Option {
86
+ return func(m *AmpModule) {
87
+ m.accessManager = am
88
+ }
89
+ }
90
+
91
+ // WithAuthMiddleware sets the authentication middleware for provider routes.
92
+ func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
93
+ return func(m *AmpModule) {
94
+ m.authMiddleware_ = middleware
95
+ }
96
+ }
97
+
98
+ // Name returns the module identifier
99
+ func (m *AmpModule) Name() string {
100
+ return "amp-routing"
101
+ }
102
+
103
+ // forceModelMappings returns whether model mappings should take precedence over local API keys
104
+ func (m *AmpModule) forceModelMappings() bool {
105
+ m.configMu.RLock()
106
+ defer m.configMu.RUnlock()
107
+ if m.lastConfig == nil {
108
+ return false
109
+ }
110
+ return m.lastConfig.ForceModelMappings
111
+ }
112
+
113
+ // Register sets up Amp routes if configured.
114
+ // This implements the RouteModuleV2 interface with Context.
115
+ // Routes are registered only once via sync.Once for idempotent behavior.
116
+ func (m *AmpModule) Register(ctx modules.Context) error {
117
+ settings := ctx.Config.AmpCode
118
+ upstreamURL := strings.TrimSpace(settings.UpstreamURL)
119
+
120
+ // Determine auth middleware (from module or context)
121
+ auth := m.getAuthMiddleware(ctx)
122
+
123
+ // Use registerOnce to ensure routes are only registered once
124
+ var regErr error
125
+ m.registerOnce.Do(func() {
126
+ // Initialize model mapper from config (for routing unavailable models to alternatives)
127
+ m.modelMapper = NewModelMapper(settings.ModelMappings)
128
+
129
+ // Store initial config for partial reload comparison
130
+ settingsCopy := settings
131
+ m.lastConfig = &settingsCopy
132
+
133
+ // Initialize localhost restriction setting (hot-reloadable)
134
+ m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
135
+
136
+ // Always register provider aliases - these work without an upstream
137
+ m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
138
+
139
+ // Register management proxy routes once; middleware will gate access when upstream is unavailable.
140
+ // Pass auth middleware to require valid API key for all management routes.
141
+ m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
142
+
143
+ // If no upstream URL, skip proxy routes but provider aliases are still available
144
+ if upstreamURL == "" {
145
+ log.Debug("amp upstream proxy disabled (no upstream URL configured)")
146
+ log.Debug("amp provider alias routes registered")
147
+ m.enabled = false
148
+ return
149
+ }
150
+
151
+ if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
152
+ regErr = fmt.Errorf("failed to create amp proxy: %w", err)
153
+ return
154
+ }
155
+
156
+ log.Debug("amp provider alias routes registered")
157
+ })
158
+
159
+ return regErr
160
+ }
161
+
162
+ // getAuthMiddleware returns the authentication middleware, preferring the
163
+ // module's configured middleware, then the context middleware, then a fallback.
164
+ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
165
+ if m.authMiddleware_ != nil {
166
+ return m.authMiddleware_
167
+ }
168
+ if ctx.AuthMiddleware != nil {
169
+ return ctx.AuthMiddleware
170
+ }
171
+ // Fallback: no authentication (should not happen in production)
172
+ log.Warn("amp module: no auth middleware provided, allowing all requests")
173
+ return func(c *gin.Context) {
174
+ c.Next()
175
+ }
176
+ }
177
+
178
+ // OnConfigUpdated handles configuration updates with partial reload support.
179
+ // Only updates components that have actually changed to avoid unnecessary work.
180
+ // Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
181
+ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
182
+ newSettings := cfg.AmpCode
183
+
184
+ // Get previous config for comparison
185
+ m.configMu.RLock()
186
+ oldSettings := m.lastConfig
187
+ m.configMu.RUnlock()
188
+
189
+ if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
190
+ m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
191
+ }
192
+
193
+ newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
194
+ oldUpstreamURL := ""
195
+ if oldSettings != nil {
196
+ oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
197
+ }
198
+
199
+ if !m.enabled && newUpstreamURL != "" {
200
+ if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
201
+ log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
202
+ }
203
+ }
204
+
205
+ // Check model mappings change
206
+ modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
207
+ if modelMappingsChanged {
208
+ if m.modelMapper != nil {
209
+ m.modelMapper.UpdateMappings(newSettings.ModelMappings)
210
+ } else if m.enabled {
211
+ log.Warnf("amp model mapper not initialized, skipping model mapping update")
212
+ }
213
+ }
214
+
215
+ if m.enabled {
216
+ // Check upstream URL change - now supports hot-reload
217
+ if newUpstreamURL == "" && oldUpstreamURL != "" {
218
+ m.setProxy(nil)
219
+ m.enabled = false
220
+ } else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
221
+ // Recreate proxy with new URL
222
+ proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
223
+ if err != nil {
224
+ log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
225
+ } else {
226
+ m.setProxy(proxy)
227
+ }
228
+ }
229
+
230
+ // Check API key change (both default and per-client mappings)
231
+ apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
232
+ upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
233
+ if apiKeyChanged || upstreamAPIKeysChanged {
234
+ if m.secretSource != nil {
235
+ if ms, ok := m.secretSource.(*MappedSecretSource); ok {
236
+ if apiKeyChanged {
237
+ ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
238
+ ms.InvalidateCache()
239
+ }
240
+ if upstreamAPIKeysChanged {
241
+ ms.UpdateMappings(newSettings.UpstreamAPIKeys)
242
+ }
243
+ } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
244
+ ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
245
+ ms.InvalidateCache()
246
+ }
247
+ }
248
+ }
249
+
250
+ }
251
+
252
+ // Store current config for next comparison
253
+ m.configMu.Lock()
254
+ settingsCopy := newSettings // copy struct
255
+ m.lastConfig = &settingsCopy
256
+ m.configMu.Unlock()
257
+
258
+ return nil
259
+ }
260
+
261
+ func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
262
+ if m.secretSource == nil {
263
+ // Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
264
+ defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
265
+ mappedSource := NewMappedSecretSource(defaultSource)
266
+ mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
267
+ m.secretSource = mappedSource
268
+ } else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
269
+ ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
270
+ ms.InvalidateCache()
271
+ ms.UpdateMappings(settings.UpstreamAPIKeys)
272
+ } else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
273
+ // Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
274
+ ms.UpdateExplicitKey(settings.UpstreamAPIKey)
275
+ ms.InvalidateCache()
276
+ mappedSource := NewMappedSecretSource(ms)
277
+ mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
278
+ m.secretSource = mappedSource
279
+ }
280
+
281
+ proxy, err := createReverseProxy(upstreamURL, m.secretSource)
282
+ if err != nil {
283
+ return err
284
+ }
285
+
286
+ m.setProxy(proxy)
287
+ m.enabled = true
288
+
289
+ log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
290
+ return nil
291
+ }
292
+
293
+ // hasModelMappingsChanged compares old and new model mappings.
294
+ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
295
+ if old == nil {
296
+ return len(new.ModelMappings) > 0
297
+ }
298
+
299
+ if len(old.ModelMappings) != len(new.ModelMappings) {
300
+ return true
301
+ }
302
+
303
+ // Build map for efficient and robust comparison
304
+ type mappingInfo struct {
305
+ to string
306
+ regex bool
307
+ }
308
+ oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
309
+ for _, mapping := range old.ModelMappings {
310
+ oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
311
+ to: strings.TrimSpace(mapping.To),
312
+ regex: mapping.Regex,
313
+ }
314
+ }
315
+
316
+ for _, mapping := range new.ModelMappings {
317
+ from := strings.TrimSpace(mapping.From)
318
+ to := strings.TrimSpace(mapping.To)
319
+ if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
320
+ return true
321
+ }
322
+ }
323
+
324
+ return false
325
+ }
326
+
327
+ // hasAPIKeyChanged compares old and new API keys.
328
+ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
329
+ oldKey := ""
330
+ if old != nil {
331
+ oldKey = strings.TrimSpace(old.UpstreamAPIKey)
332
+ }
333
+ newKey := strings.TrimSpace(new.UpstreamAPIKey)
334
+ return oldKey != newKey
335
+ }
336
+
337
+ // hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
338
+ func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
339
+ if old == nil {
340
+ return len(new.UpstreamAPIKeys) > 0
341
+ }
342
+
343
+ if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
344
+ return true
345
+ }
346
+
347
+ // Build map for comparison: upstreamKey -> set of clientKeys
348
+ type entryInfo struct {
349
+ upstreamKey string
350
+ clientKeys map[string]struct{}
351
+ }
352
+ oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
353
+ for i, entry := range old.UpstreamAPIKeys {
354
+ clientKeys := make(map[string]struct{}, len(entry.APIKeys))
355
+ for _, k := range entry.APIKeys {
356
+ trimmed := strings.TrimSpace(k)
357
+ if trimmed == "" {
358
+ continue
359
+ }
360
+ clientKeys[trimmed] = struct{}{}
361
+ }
362
+ oldEntries[i] = entryInfo{
363
+ upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
364
+ clientKeys: clientKeys,
365
+ }
366
+ }
367
+
368
+ for i, newEntry := range new.UpstreamAPIKeys {
369
+ if i >= len(oldEntries) {
370
+ return true
371
+ }
372
+ oldE := oldEntries[i]
373
+ if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
374
+ return true
375
+ }
376
+ newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
377
+ for _, k := range newEntry.APIKeys {
378
+ trimmed := strings.TrimSpace(k)
379
+ if trimmed == "" {
380
+ continue
381
+ }
382
+ newKeys[trimmed] = struct{}{}
383
+ }
384
+ if len(newKeys) != len(oldE.clientKeys) {
385
+ return true
386
+ }
387
+ for k := range newKeys {
388
+ if _, ok := oldE.clientKeys[k]; !ok {
389
+ return true
390
+ }
391
+ }
392
+ }
393
+
394
+ return false
395
+ }
396
+
397
+ // GetModelMapper returns the model mapper instance (for testing/debugging).
398
+ func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
399
+ return m.modelMapper
400
+ }
401
+
402
+ // getProxy returns the current proxy instance (thread-safe for hot-reload).
403
+ func (m *AmpModule) getProxy() *httputil.ReverseProxy {
404
+ m.proxyMu.RLock()
405
+ defer m.proxyMu.RUnlock()
406
+ return m.proxy
407
+ }
408
+
409
+ // setProxy updates the proxy instance (thread-safe for hot-reload).
410
+ func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
411
+ m.proxyMu.Lock()
412
+ defer m.proxyMu.Unlock()
413
+ m.proxy = proxy
414
+ }
415
+
416
+ // IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
417
+ func (m *AmpModule) IsRestrictedToLocalhost() bool {
418
+ m.restrictMu.RLock()
419
+ defer m.restrictMu.RUnlock()
420
+ return m.restrictToLocalhost
421
+ }
422
+
423
+ // setRestrictToLocalhost updates the localhost restriction setting.
424
+ func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
425
+ m.restrictMu.Lock()
426
+ defer m.restrictMu.Unlock()
427
+ m.restrictToLocalhost = restrict
428
+ }
internal/api/modules/amp/amp_test.go ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "context"
5
+ "net/http/httptest"
6
+ "os"
7
+ "path/filepath"
8
+ "testing"
9
+ "time"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
14
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
15
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
16
+ )
17
+
18
+ func TestAmpModule_Name(t *testing.T) {
19
+ m := New()
20
+ if m.Name() != "amp-routing" {
21
+ t.Fatalf("want amp-routing, got %s", m.Name())
22
+ }
23
+ }
24
+
25
+ func TestAmpModule_New(t *testing.T) {
26
+ accessManager := sdkaccess.NewManager()
27
+ authMiddleware := func(c *gin.Context) { c.Next() }
28
+
29
+ m := NewLegacy(accessManager, authMiddleware)
30
+
31
+ if m.accessManager != accessManager {
32
+ t.Fatal("accessManager not set")
33
+ }
34
+ if m.authMiddleware_ == nil {
35
+ t.Fatal("authMiddleware not set")
36
+ }
37
+ if m.enabled {
38
+ t.Fatal("enabled should be false initially")
39
+ }
40
+ if m.proxy != nil {
41
+ t.Fatal("proxy should be nil initially")
42
+ }
43
+ }
44
+
45
+ func TestAmpModule_Register_WithUpstream(t *testing.T) {
46
+ gin.SetMode(gin.TestMode)
47
+ r := gin.New()
48
+
49
+ // Fake upstream to ensure URL is valid
50
+ upstream := httptest.NewServer(nil)
51
+ defer upstream.Close()
52
+
53
+ accessManager := sdkaccess.NewManager()
54
+ base := &handlers.BaseAPIHandler{}
55
+
56
+ m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
57
+
58
+ cfg := &config.Config{
59
+ AmpCode: config.AmpCode{
60
+ UpstreamURL: upstream.URL,
61
+ UpstreamAPIKey: "test-key",
62
+ },
63
+ }
64
+
65
+ ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
66
+ if err := m.Register(ctx); err != nil {
67
+ t.Fatalf("register error: %v", err)
68
+ }
69
+
70
+ if !m.enabled {
71
+ t.Fatal("module should be enabled with upstream URL")
72
+ }
73
+ if m.proxy == nil {
74
+ t.Fatal("proxy should be initialized")
75
+ }
76
+ if m.secretSource == nil {
77
+ t.Fatal("secretSource should be initialized")
78
+ }
79
+ }
80
+
81
+ func TestAmpModule_Register_WithoutUpstream(t *testing.T) {
82
+ gin.SetMode(gin.TestMode)
83
+ r := gin.New()
84
+
85
+ accessManager := sdkaccess.NewManager()
86
+ base := &handlers.BaseAPIHandler{}
87
+
88
+ m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
89
+
90
+ cfg := &config.Config{
91
+ AmpCode: config.AmpCode{
92
+ UpstreamURL: "", // No upstream
93
+ },
94
+ }
95
+
96
+ ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
97
+ if err := m.Register(ctx); err != nil {
98
+ t.Fatalf("register should not error without upstream: %v", err)
99
+ }
100
+
101
+ if m.enabled {
102
+ t.Fatal("module should be disabled without upstream URL")
103
+ }
104
+ if m.proxy != nil {
105
+ t.Fatal("proxy should not be initialized without upstream")
106
+ }
107
+
108
+ // But provider aliases should still be registered
109
+ req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
110
+ w := httptest.NewRecorder()
111
+ r.ServeHTTP(w, req)
112
+
113
+ if w.Code == 404 {
114
+ t.Fatal("provider aliases should be registered even without upstream")
115
+ }
116
+ }
117
+
118
+ func TestAmpModule_Register_InvalidUpstream(t *testing.T) {
119
+ gin.SetMode(gin.TestMode)
120
+ r := gin.New()
121
+
122
+ accessManager := sdkaccess.NewManager()
123
+ base := &handlers.BaseAPIHandler{}
124
+
125
+ m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
126
+
127
+ cfg := &config.Config{
128
+ AmpCode: config.AmpCode{
129
+ UpstreamURL: "://invalid-url",
130
+ },
131
+ }
132
+
133
+ ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
134
+ if err := m.Register(ctx); err == nil {
135
+ t.Fatal("expected error for invalid upstream URL")
136
+ }
137
+ }
138
+
139
+ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
140
+ tmpDir := t.TempDir()
141
+ p := filepath.Join(tmpDir, "secrets.json")
142
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
143
+ t.Fatal(err)
144
+ }
145
+
146
+ m := &AmpModule{enabled: true}
147
+ ms := NewMultiSourceSecretWithPath("", p, time.Minute)
148
+ m.secretSource = ms
149
+ m.lastConfig = &config.AmpCode{
150
+ UpstreamAPIKey: "old-key",
151
+ }
152
+
153
+ // Warm the cache
154
+ if _, err := ms.Get(context.Background()); err != nil {
155
+ t.Fatal(err)
156
+ }
157
+
158
+ if ms.cache == nil {
159
+ t.Fatal("expected cache to be set")
160
+ }
161
+
162
+ // Update config - should invalidate cache
163
+ if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
164
+ t.Fatal(err)
165
+ }
166
+
167
+ if ms.cache != nil {
168
+ t.Fatal("expected cache to be invalidated")
169
+ }
170
+ }
171
+
172
+ func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) {
173
+ m := &AmpModule{enabled: false}
174
+
175
+ // Should not error or panic when disabled
176
+ if err := m.OnConfigUpdated(&config.Config{}); err != nil {
177
+ t.Fatalf("unexpected error: %v", err)
178
+ }
179
+ }
180
+
181
+ func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) {
182
+ m := &AmpModule{enabled: true}
183
+ ms := NewMultiSourceSecret("", 0)
184
+ m.secretSource = ms
185
+
186
+ // Config update with empty URL - should log warning but not error
187
+ cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: ""}}
188
+
189
+ if err := m.OnConfigUpdated(cfg); err != nil {
190
+ t.Fatalf("unexpected error: %v", err)
191
+ }
192
+ }
193
+
194
+ func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) {
195
+ // Test that OnConfigUpdated doesn't panic with StaticSecretSource
196
+ m := &AmpModule{enabled: true}
197
+ m.secretSource = NewStaticSecretSource("static-key")
198
+
199
+ cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://example.com"}}
200
+
201
+ // Should not error or panic
202
+ if err := m.OnConfigUpdated(cfg); err != nil {
203
+ t.Fatalf("unexpected error: %v", err)
204
+ }
205
+ }
206
+
207
+ func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) {
208
+ gin.SetMode(gin.TestMode)
209
+ r := gin.New()
210
+
211
+ // Create module with no auth middleware
212
+ m := &AmpModule{authMiddleware_: nil}
213
+
214
+ // Get the fallback middleware via getAuthMiddleware
215
+ ctx := modules.Context{Engine: r, AuthMiddleware: nil}
216
+ middleware := m.getAuthMiddleware(ctx)
217
+
218
+ if middleware == nil {
219
+ t.Fatal("getAuthMiddleware should return a fallback, not nil")
220
+ }
221
+
222
+ // Test that it works
223
+ called := false
224
+ r.GET("/test", middleware, func(c *gin.Context) {
225
+ called = true
226
+ c.String(200, "ok")
227
+ })
228
+
229
+ req := httptest.NewRequest("GET", "/test", nil)
230
+ w := httptest.NewRecorder()
231
+ r.ServeHTTP(w, req)
232
+
233
+ if !called {
234
+ t.Fatal("fallback middleware should allow requests through")
235
+ }
236
+ }
237
+
238
+ func TestAmpModule_SecretSource_FromConfig(t *testing.T) {
239
+ gin.SetMode(gin.TestMode)
240
+ r := gin.New()
241
+
242
+ upstream := httptest.NewServer(nil)
243
+ defer upstream.Close()
244
+
245
+ accessManager := sdkaccess.NewManager()
246
+ base := &handlers.BaseAPIHandler{}
247
+
248
+ m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
249
+
250
+ // Config with explicit API key
251
+ cfg := &config.Config{
252
+ AmpCode: config.AmpCode{
253
+ UpstreamURL: upstream.URL,
254
+ UpstreamAPIKey: "config-key",
255
+ },
256
+ }
257
+
258
+ ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
259
+ if err := m.Register(ctx); err != nil {
260
+ t.Fatalf("register error: %v", err)
261
+ }
262
+
263
+ // Secret source should be MultiSourceSecret with config key
264
+ if m.secretSource == nil {
265
+ t.Fatal("secretSource should be set")
266
+ }
267
+
268
+ // Verify it returns the config key
269
+ key, err := m.secretSource.Get(context.Background())
270
+ if err != nil {
271
+ t.Fatalf("Get error: %v", err)
272
+ }
273
+ if key != "config-key" {
274
+ t.Fatalf("want config-key, got %s", key)
275
+ }
276
+ }
277
+
278
+ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
279
+ gin.SetMode(gin.TestMode)
280
+
281
+ scenarios := []struct {
282
+ name string
283
+ configURL string
284
+ }{
285
+ {"with_upstream", "http://example.com"},
286
+ {"without_upstream", ""},
287
+ }
288
+
289
+ for _, scenario := range scenarios {
290
+ t.Run(scenario.name, func(t *testing.T) {
291
+ r := gin.New()
292
+ accessManager := sdkaccess.NewManager()
293
+ base := &handlers.BaseAPIHandler{}
294
+
295
+ m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
296
+
297
+ cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: scenario.configURL}}
298
+
299
+ ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
300
+ if err := m.Register(ctx); err != nil && scenario.configURL != "" {
301
+ t.Fatalf("register error: %v", err)
302
+ }
303
+
304
+ // Provider aliases should always be available
305
+ req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
306
+ w := httptest.NewRecorder()
307
+ r.ServeHTTP(w, req)
308
+
309
+ if w.Code == 404 {
310
+ t.Fatal("provider aliases should be registered")
311
+ }
312
+ })
313
+ }
314
+ }
315
+
316
+ func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
317
+ m := &AmpModule{}
318
+
319
+ oldCfg := &config.AmpCode{
320
+ UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
321
+ {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
322
+ },
323
+ }
324
+ newCfg := &config.AmpCode{
325
+ UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
326
+ {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
327
+ },
328
+ }
329
+
330
+ if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
331
+ t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
332
+ }
333
+ }
334
+
335
+ func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
336
+ m := &AmpModule{}
337
+
338
+ oldCfg := &config.AmpCode{
339
+ UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
340
+ {UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
341
+ },
342
+ }
343
+ newCfg := &config.AmpCode{
344
+ UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
345
+ {UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
346
+ },
347
+ }
348
+
349
+ if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
350
+ t.Fatal("expected no change when only whitespace/empty entries differ")
351
+ }
352
+ }
internal/api/modules/amp/fallback_handlers.go ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "bytes"
5
+ "io"
6
+ "net/http/httputil"
7
+ "strings"
8
+ "time"
9
+
10
+ "github.com/gin-gonic/gin"
11
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
12
+ log "github.com/sirupsen/logrus"
13
+ "github.com/tidwall/gjson"
14
+ "github.com/tidwall/sjson"
15
+ )
16
+
17
+ // AmpRouteType represents the type of routing decision made for an Amp request
18
+ type AmpRouteType string
19
+
20
+ const (
21
+ // RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free)
22
+ RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
23
+ // RouteTypeModelMapping indicates the request was remapped to another available model (free)
24
+ RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
25
+ // RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits)
26
+ RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
27
+ // RouteTypeNoProvider indicates no provider or fallback available
28
+ RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
29
+ )
30
+
31
+ // MappedModelContextKey is the Gin context key for passing mapped model names.
32
+ const MappedModelContextKey = "mapped_model"
33
+
34
+ // logAmpRouting logs the routing decision for an Amp request with structured fields
35
+ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
36
+ fields := log.Fields{
37
+ "component": "amp-routing",
38
+ "route_type": string(routeType),
39
+ "requested_model": requestedModel,
40
+ "path": path,
41
+ "timestamp": time.Now().Format(time.RFC3339),
42
+ }
43
+
44
+ if resolvedModel != "" && resolvedModel != requestedModel {
45
+ fields["resolved_model"] = resolvedModel
46
+ }
47
+ if provider != "" {
48
+ fields["provider"] = provider
49
+ }
50
+
51
+ switch routeType {
52
+ case RouteTypeLocalProvider:
53
+ fields["cost"] = "free"
54
+ fields["source"] = "local_oauth"
55
+ log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
56
+
57
+ case RouteTypeModelMapping:
58
+ fields["cost"] = "free"
59
+ fields["source"] = "local_oauth"
60
+ fields["mapping"] = requestedModel + " -> " + resolvedModel
61
+ // model mapping already logged in mapper; avoid duplicate here
62
+
63
+ case RouteTypeAmpCredits:
64
+ fields["cost"] = "amp_credits"
65
+ fields["source"] = "ampcode.com"
66
+ fields["model_id"] = requestedModel // Explicit model_id for easy config reference
67
+ log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
68
+
69
+ case RouteTypeNoProvider:
70
+ fields["cost"] = "none"
71
+ fields["source"] = "error"
72
+ fields["model_id"] = requestedModel // Explicit model_id for easy config reference
73
+ log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
74
+ }
75
+ }
76
+
77
+ // FallbackHandler wraps a standard handler with fallback logic to ampcode.com
78
+ // when the model's provider is not available in CLIProxyAPI
79
+ type FallbackHandler struct {
80
+ getProxy func() *httputil.ReverseProxy
81
+ modelMapper ModelMapper
82
+ forceModelMappings func() bool
83
+ }
84
+
85
+ // NewFallbackHandler creates a new fallback handler wrapper
86
+ // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
87
+ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
88
+ return &FallbackHandler{
89
+ getProxy: getProxy,
90
+ forceModelMappings: func() bool { return false },
91
+ }
92
+ }
93
+
94
+ // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
95
+ func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
96
+ if forceModelMappings == nil {
97
+ forceModelMappings = func() bool { return false }
98
+ }
99
+ return &FallbackHandler{
100
+ getProxy: getProxy,
101
+ modelMapper: mapper,
102
+ forceModelMappings: forceModelMappings,
103
+ }
104
+ }
105
+
106
+ // SetModelMapper sets the model mapper for this handler (allows late binding)
107
+ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
108
+ fh.modelMapper = mapper
109
+ }
110
+
111
+ // WrapHandler wraps a gin.HandlerFunc with fallback logic
112
+ // If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
113
+ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
114
+ return func(c *gin.Context) {
115
+ requestPath := c.Request.URL.Path
116
+
117
+ // Read the request body to extract the model name
118
+ bodyBytes, err := io.ReadAll(c.Request.Body)
119
+ if err != nil {
120
+ log.Errorf("amp fallback: failed to read request body: %v", err)
121
+ handler(c)
122
+ return
123
+ }
124
+
125
+ // Restore the body for the handler to read
126
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
127
+
128
+ // Try to extract model from request body or URL path (for Gemini)
129
+ modelName := extractModelFromRequest(bodyBytes, c)
130
+ if modelName == "" {
131
+ // Can't determine model, proceed with normal handler
132
+ handler(c)
133
+ return
134
+ }
135
+
136
+ // Normalize model (handles dynamic thinking suffixes)
137
+ normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
138
+ thinkingSuffix := ""
139
+ if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
140
+ thinkingSuffix = modelName[len(normalizedModel):]
141
+ }
142
+
143
+ resolveMappedModel := func() (string, []string) {
144
+ if fh.modelMapper == nil {
145
+ return "", nil
146
+ }
147
+
148
+ mappedModel := fh.modelMapper.MapModel(modelName)
149
+ if mappedModel == "" {
150
+ mappedModel = fh.modelMapper.MapModel(normalizedModel)
151
+ }
152
+ mappedModel = strings.TrimSpace(mappedModel)
153
+ if mappedModel == "" {
154
+ return "", nil
155
+ }
156
+
157
+ // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
158
+ // already specifies its own thinking suffix.
159
+ if thinkingSuffix != "" {
160
+ _, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
161
+ if mappedThinkingMetadata == nil {
162
+ mappedModel += thinkingSuffix
163
+ }
164
+ }
165
+
166
+ mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
167
+ mappedProviders := util.GetProviderName(mappedBaseModel)
168
+ if len(mappedProviders) == 0 {
169
+ return "", nil
170
+ }
171
+
172
+ return mappedModel, mappedProviders
173
+ }
174
+
175
+ // Track resolved model for logging (may change if mapping is applied)
176
+ resolvedModel := normalizedModel
177
+ usedMapping := false
178
+ var providers []string
179
+
180
+ // Check if model mappings should be forced ahead of local API keys
181
+ forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
182
+
183
+ if forceMappings {
184
+ // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
185
+ // This allows users to route Amp requests to their preferred OAuth providers
186
+ if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
187
+ // Mapping found and provider available - rewrite the model in request body
188
+ bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
189
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
190
+ // Store mapped model in context for handlers that check it (like gemini bridge)
191
+ c.Set(MappedModelContextKey, mappedModel)
192
+ resolvedModel = mappedModel
193
+ usedMapping = true
194
+ providers = mappedProviders
195
+ }
196
+
197
+ // If no mapping applied, check for local providers
198
+ if !usedMapping {
199
+ providers = util.GetProviderName(normalizedModel)
200
+ }
201
+ } else {
202
+ // DEFAULT MODE: Check local providers first, then mappings as fallback
203
+ providers = util.GetProviderName(normalizedModel)
204
+
205
+ if len(providers) == 0 {
206
+ // No providers configured - check if we have a model mapping
207
+ if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
208
+ // Mapping found and provider available - rewrite the model in request body
209
+ bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
210
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
211
+ // Store mapped model in context for handlers that check it (like gemini bridge)
212
+ c.Set(MappedModelContextKey, mappedModel)
213
+ resolvedModel = mappedModel
214
+ usedMapping = true
215
+ providers = mappedProviders
216
+ }
217
+ }
218
+ }
219
+
220
+ // If no providers available, fallback to ampcode.com
221
+ if len(providers) == 0 {
222
+ proxy := fh.getProxy()
223
+ if proxy != nil {
224
+ // Log: Forwarding to ampcode.com (uses Amp credits)
225
+ logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
226
+
227
+ // Restore body again for the proxy
228
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
229
+
230
+ // Forward to ampcode.com
231
+ proxy.ServeHTTP(c.Writer, c.Request)
232
+ return
233
+ }
234
+
235
+ // No proxy available, let the normal handler return the error
236
+ logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
237
+ }
238
+
239
+ // Log the routing decision
240
+ providerName := ""
241
+ if len(providers) > 0 {
242
+ providerName = providers[0]
243
+ }
244
+
245
+ if usedMapping {
246
+ // Log: Model was mapped to another model
247
+ log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
248
+ logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
249
+ rewriter := NewResponseRewriter(c.Writer, modelName)
250
+ c.Writer = rewriter
251
+ // Filter Anthropic-Beta header only for local handling paths
252
+ filterAntropicBetaHeader(c)
253
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
254
+ handler(c)
255
+ rewriter.Flush()
256
+ log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
257
+ } else if len(providers) > 0 {
258
+ // Log: Using local provider (free)
259
+ logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
260
+ // Filter Anthropic-Beta header only for local handling paths
261
+ filterAntropicBetaHeader(c)
262
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
263
+ handler(c)
264
+ } else {
265
+ // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
266
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
267
+ handler(c)
268
+ }
269
+ }
270
+ }
271
+
272
+ // filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
273
+ // This is needed when using local providers (bypassing the Amp proxy)
274
+ func filterAntropicBetaHeader(c *gin.Context) {
275
+ if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
276
+ if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
277
+ c.Request.Header.Set("Anthropic-Beta", filtered)
278
+ } else {
279
+ c.Request.Header.Del("Anthropic-Beta")
280
+ }
281
+ }
282
+ }
283
+
284
+ // rewriteModelInRequest replaces the model name in a JSON request body
285
+ func rewriteModelInRequest(body []byte, newModel string) []byte {
286
+ if !gjson.GetBytes(body, "model").Exists() {
287
+ return body
288
+ }
289
+ result, err := sjson.SetBytes(body, "model", newModel)
290
+ if err != nil {
291
+ log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
292
+ return body
293
+ }
294
+ return result
295
+ }
296
+
297
+ // extractModelFromRequest attempts to extract the model name from various request formats
298
+ func extractModelFromRequest(body []byte, c *gin.Context) string {
299
+ // First try to parse from JSON body (OpenAI, Claude, etc.)
300
+ // Check common model field names
301
+ if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
302
+ return result.String()
303
+ }
304
+
305
+ // For Gemini requests, model is in the URL path
306
+ // Standard format: /models/{model}:generateContent -> :action parameter
307
+ if action := c.Param("action"); action != "" {
308
+ // Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
309
+ parts := strings.Split(action, ":")
310
+ if len(parts) > 0 && parts[0] != "" {
311
+ return parts[0]
312
+ }
313
+ }
314
+
315
+ // AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
316
+ // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
317
+ if path := c.Param("path"); path != "" {
318
+ // Look for /models/{model}:method pattern
319
+ if idx := strings.Index(path, "/models/"); idx >= 0 {
320
+ modelPart := path[idx+8:] // Skip "/models/"
321
+ // Split by colon to get model name
322
+ if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
323
+ return modelPart[:colonIdx]
324
+ }
325
+ }
326
+ }
327
+
328
+ return ""
329
+ }
internal/api/modules/amp/fallback_handlers_test.go ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "bytes"
5
+ "encoding/json"
6
+ "net/http"
7
+ "net/http/httptest"
8
+ "net/http/httputil"
9
+ "testing"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
14
+ )
15
+
16
+ func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
17
+ gin.SetMode(gin.TestMode)
18
+
19
+ reg := registry.GetGlobalRegistry()
20
+ reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
21
+ {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
22
+ })
23
+ defer reg.UnregisterClient("test-client-amp-fallback")
24
+
25
+ mapper := NewModelMapper([]config.AmpModelMapping{
26
+ {From: "gpt-5.2", To: "test/gpt-5.2"},
27
+ })
28
+
29
+ fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
30
+
31
+ handler := func(c *gin.Context) {
32
+ var req struct {
33
+ Model string `json:"model"`
34
+ }
35
+ if err := c.ShouldBindJSON(&req); err != nil {
36
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
37
+ return
38
+ }
39
+
40
+ c.JSON(http.StatusOK, gin.H{
41
+ "model": req.Model,
42
+ "seen_model": req.Model,
43
+ })
44
+ }
45
+
46
+ r := gin.New()
47
+ r.POST("/chat/completions", fallback.WrapHandler(handler))
48
+
49
+ reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
50
+ req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
51
+ req.Header.Set("Content-Type", "application/json")
52
+ w := httptest.NewRecorder()
53
+ r.ServeHTTP(w, req)
54
+
55
+ if w.Code != http.StatusOK {
56
+ t.Fatalf("Expected status 200, got %d", w.Code)
57
+ }
58
+
59
+ var resp struct {
60
+ Model string `json:"model"`
61
+ SeenModel string `json:"seen_model"`
62
+ }
63
+ if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
64
+ t.Fatalf("Failed to parse response JSON: %v", err)
65
+ }
66
+
67
+ if resp.Model != "gpt-5.2(xhigh)" {
68
+ t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
69
+ }
70
+ if resp.SeenModel != "test/gpt-5.2(xhigh)" {
71
+ t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
72
+ }
73
+ }
internal/api/modules/amp/gemini_bridge.go ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "strings"
5
+
6
+ "github.com/gin-gonic/gin"
7
+ )
8
+
9
+ // createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
10
+ // to our standard Gemini handler by rewriting the request context.
11
+ //
12
+ // AMP CLI format: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
13
+ // Standard format: /models/gemini-3-pro-preview:streamGenerateContent
14
+ //
15
+ // This extracts the model+method from the AMP path and sets it as the :action parameter
16
+ // so the standard Gemini handler can process it.
17
+ //
18
+ // The handler parameter should be a Gemini-compatible handler that expects the :action param.
19
+ func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
20
+ return func(c *gin.Context) {
21
+ // Get the full path from the catch-all parameter
22
+ path := c.Param("path")
23
+
24
+ // Extract model:method from AMP CLI path format
25
+ // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
26
+ const modelsPrefix = "/models/"
27
+ if idx := strings.Index(path, modelsPrefix); idx >= 0 {
28
+ // Extract everything after modelsPrefix
29
+ actionPart := path[idx+len(modelsPrefix):]
30
+
31
+ // Check if model was mapped by FallbackHandler
32
+ if mappedModel, exists := c.Get(MappedModelContextKey); exists {
33
+ if strModel, ok := mappedModel.(string); ok && strModel != "" {
34
+ // Replace the model part in the action
35
+ // actionPart is like "model-name:method"
36
+ if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
37
+ method := actionPart[colonIdx:] // ":method"
38
+ actionPart = strModel + method
39
+ }
40
+ }
41
+ }
42
+
43
+ // Set this as the :action parameter that the Gemini handler expects
44
+ c.Params = append(c.Params, gin.Param{
45
+ Key: "action",
46
+ Value: actionPart,
47
+ })
48
+
49
+ // Call the handler
50
+ handler(c)
51
+ return
52
+ }
53
+
54
+ // If we can't parse the path, return 400
55
+ c.JSON(400, gin.H{
56
+ "error": "Invalid Gemini API path format",
57
+ })
58
+ }
59
+ }
internal/api/modules/amp/gemini_bridge_test.go ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "net/http"
5
+ "net/http/httptest"
6
+ "testing"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ )
10
+
11
+ func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
12
+ gin.SetMode(gin.TestMode)
13
+
14
+ tests := []struct {
15
+ name string
16
+ path string
17
+ mappedModel string // empty string means no mapping
18
+ expectedAction string
19
+ }{
20
+ {
21
+ name: "no_mapping_uses_url_model",
22
+ path: "/publishers/google/models/gemini-pro:generateContent",
23
+ mappedModel: "",
24
+ expectedAction: "gemini-pro:generateContent",
25
+ },
26
+ {
27
+ name: "mapped_model_replaces_url_model",
28
+ path: "/publishers/google/models/gemini-exp:generateContent",
29
+ mappedModel: "gemini-2.0-flash",
30
+ expectedAction: "gemini-2.0-flash:generateContent",
31
+ },
32
+ {
33
+ name: "mapping_preserves_method",
34
+ path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
35
+ mappedModel: "gemini-flash",
36
+ expectedAction: "gemini-flash:streamGenerateContent",
37
+ },
38
+ }
39
+
40
+ for _, tt := range tests {
41
+ t.Run(tt.name, func(t *testing.T) {
42
+ var capturedAction string
43
+
44
+ mockGeminiHandler := func(c *gin.Context) {
45
+ capturedAction = c.Param("action")
46
+ c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
47
+ }
48
+
49
+ // Use the actual createGeminiBridgeHandler function
50
+ bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
51
+
52
+ r := gin.New()
53
+ if tt.mappedModel != "" {
54
+ r.Use(func(c *gin.Context) {
55
+ c.Set(MappedModelContextKey, tt.mappedModel)
56
+ c.Next()
57
+ })
58
+ }
59
+ r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
60
+
61
+ req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
62
+ w := httptest.NewRecorder()
63
+ r.ServeHTTP(w, req)
64
+
65
+ if w.Code != http.StatusOK {
66
+ t.Fatalf("Expected status 200, got %d", w.Code)
67
+ }
68
+ if capturedAction != tt.expectedAction {
69
+ t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
70
+ }
71
+ })
72
+ }
73
+ }
74
+
75
+ func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
76
+ gin.SetMode(gin.TestMode)
77
+
78
+ mockHandler := func(c *gin.Context) {
79
+ c.JSON(http.StatusOK, gin.H{"ok": true})
80
+ }
81
+ bridgeHandler := createGeminiBridgeHandler(mockHandler)
82
+
83
+ r := gin.New()
84
+ r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
85
+
86
+ req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
87
+ w := httptest.NewRecorder()
88
+ r.ServeHTTP(w, req)
89
+
90
+ if w.Code != http.StatusBadRequest {
91
+ t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
92
+ }
93
+ }
internal/api/modules/amp/model_mapping.go ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package amp provides model mapping functionality for routing Amp CLI requests
2
+ // to alternative models when the requested model is not available locally.
3
+ package amp
4
+
5
+ import (
6
+ "regexp"
7
+ "strings"
8
+ "sync"
9
+
10
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
11
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
12
+ log "github.com/sirupsen/logrus"
13
+ )
14
+
15
+ // ModelMapper provides model name mapping/aliasing for Amp CLI requests.
16
+ // When an Amp request comes in for a model that isn't available locally,
17
+ // this mapper can redirect it to an alternative model that IS available.
18
+ type ModelMapper interface {
19
+ // MapModel returns the target model name if a mapping exists and the target
20
+ // model has available providers. Returns empty string if no mapping applies.
21
+ MapModel(requestedModel string) string
22
+
23
+ // UpdateMappings refreshes the mapping configuration (for hot-reload).
24
+ UpdateMappings(mappings []config.AmpModelMapping)
25
+ }
26
+
27
+ // DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
28
+ type DefaultModelMapper struct {
29
+ mu sync.RWMutex
30
+ mappings map[string]string // exact: from -> to (normalized lowercase keys)
31
+ regexps []regexMapping // regex rules evaluated in order
32
+ }
33
+
34
+ // NewModelMapper creates a new model mapper with the given initial mappings.
35
+ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
36
+ m := &DefaultModelMapper{
37
+ mappings: make(map[string]string),
38
+ regexps: nil,
39
+ }
40
+ m.UpdateMappings(mappings)
41
+ return m
42
+ }
43
+
44
+ // MapModel checks if a mapping exists for the requested model and if the
45
+ // target model has available local providers. Returns the mapped model name
46
+ // or empty string if no valid mapping exists.
47
+ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
48
+ if requestedModel == "" {
49
+ return ""
50
+ }
51
+
52
+ m.mu.RLock()
53
+ defer m.mu.RUnlock()
54
+
55
+ // Normalize the requested model for lookup
56
+ normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
57
+
58
+ // Check for direct mapping
59
+ targetModel, exists := m.mappings[normalizedRequest]
60
+ if !exists {
61
+ // Try regex mappings in order
62
+ base, _ := util.NormalizeThinkingModel(requestedModel)
63
+ for _, rm := range m.regexps {
64
+ if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
65
+ targetModel = rm.to
66
+ exists = true
67
+ break
68
+ }
69
+ }
70
+ if !exists {
71
+ return ""
72
+ }
73
+ }
74
+
75
+ // Verify target model has available providers
76
+ normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
77
+ providers := util.GetProviderName(normalizedTarget)
78
+ if len(providers) == 0 {
79
+ log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
80
+ return ""
81
+ }
82
+
83
+ // Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
84
+ return targetModel
85
+ }
86
+
87
+ // UpdateMappings refreshes the mapping configuration from config.
88
+ // This is called during initialization and on config hot-reload.
89
+ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
90
+ m.mu.Lock()
91
+ defer m.mu.Unlock()
92
+
93
+ // Clear and rebuild mappings
94
+ m.mappings = make(map[string]string, len(mappings))
95
+ m.regexps = make([]regexMapping, 0, len(mappings))
96
+
97
+ for _, mapping := range mappings {
98
+ from := strings.TrimSpace(mapping.From)
99
+ to := strings.TrimSpace(mapping.To)
100
+
101
+ if from == "" || to == "" {
102
+ log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
103
+ continue
104
+ }
105
+
106
+ if mapping.Regex {
107
+ // Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
108
+ pattern := "(?i)" + from
109
+ re, err := regexp.Compile(pattern)
110
+ if err != nil {
111
+ log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
112
+ continue
113
+ }
114
+ m.regexps = append(m.regexps, regexMapping{re: re, to: to})
115
+ log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
116
+ } else {
117
+ // Store with normalized lowercase key for case-insensitive lookup
118
+ normalizedFrom := strings.ToLower(from)
119
+ m.mappings[normalizedFrom] = to
120
+ log.Debugf("amp model mapping registered: %s -> %s", from, to)
121
+ }
122
+ }
123
+
124
+ if len(m.mappings) > 0 {
125
+ log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
126
+ }
127
+ if n := len(m.regexps); n > 0 {
128
+ log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
129
+ }
130
+ }
131
+
132
+ // GetMappings returns a copy of current mappings (for debugging/status).
133
+ func (m *DefaultModelMapper) GetMappings() map[string]string {
134
+ m.mu.RLock()
135
+ defer m.mu.RUnlock()
136
+
137
+ result := make(map[string]string, len(m.mappings))
138
+ for k, v := range m.mappings {
139
+ result[k] = v
140
+ }
141
+ return result
142
+ }
143
+
144
+ type regexMapping struct {
145
+ re *regexp.Regexp
146
+ to string
147
+ }
internal/api/modules/amp/model_mapping_test.go ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "testing"
5
+
6
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
7
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
8
+ )
9
+
10
+ func TestNewModelMapper(t *testing.T) {
11
+ mappings := []config.AmpModelMapping{
12
+ {From: "claude-opus-4.5", To: "claude-sonnet-4"},
13
+ {From: "gpt-5", To: "gemini-2.5-pro"},
14
+ }
15
+
16
+ mapper := NewModelMapper(mappings)
17
+ if mapper == nil {
18
+ t.Fatal("Expected non-nil mapper")
19
+ }
20
+
21
+ result := mapper.GetMappings()
22
+ if len(result) != 2 {
23
+ t.Errorf("Expected 2 mappings, got %d", len(result))
24
+ }
25
+ }
26
+
27
+ func TestNewModelMapper_Empty(t *testing.T) {
28
+ mapper := NewModelMapper(nil)
29
+ if mapper == nil {
30
+ t.Fatal("Expected non-nil mapper")
31
+ }
32
+
33
+ result := mapper.GetMappings()
34
+ if len(result) != 0 {
35
+ t.Errorf("Expected 0 mappings, got %d", len(result))
36
+ }
37
+ }
38
+
39
+ func TestModelMapper_MapModel_NoProvider(t *testing.T) {
40
+ mappings := []config.AmpModelMapping{
41
+ {From: "claude-opus-4.5", To: "claude-sonnet-4"},
42
+ }
43
+
44
+ mapper := NewModelMapper(mappings)
45
+
46
+ // Without a registered provider for the target, mapping should return empty
47
+ result := mapper.MapModel("claude-opus-4.5")
48
+ if result != "" {
49
+ t.Errorf("Expected empty result when target has no provider, got %s", result)
50
+ }
51
+ }
52
+
53
+ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
54
+ // Register a mock provider for the target model
55
+ reg := registry.GetGlobalRegistry()
56
+ reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
57
+ {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
58
+ })
59
+ defer reg.UnregisterClient("test-client")
60
+
61
+ mappings := []config.AmpModelMapping{
62
+ {From: "claude-opus-4.5", To: "claude-sonnet-4"},
63
+ }
64
+
65
+ mapper := NewModelMapper(mappings)
66
+
67
+ // With a registered provider, mapping should work
68
+ result := mapper.MapModel("claude-opus-4.5")
69
+ if result != "claude-sonnet-4" {
70
+ t.Errorf("Expected claude-sonnet-4, got %s", result)
71
+ }
72
+ }
73
+
74
+ func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
75
+ reg := registry.GetGlobalRegistry()
76
+ reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
77
+ {ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
78
+ })
79
+ defer reg.UnregisterClient("test-client-thinking")
80
+
81
+ mappings := []config.AmpModelMapping{
82
+ {From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
83
+ }
84
+
85
+ mapper := NewModelMapper(mappings)
86
+
87
+ result := mapper.MapModel("gpt-5.2-alias")
88
+ if result != "gpt-5.2(xhigh)" {
89
+ t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
90
+ }
91
+ }
92
+
93
+ func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
94
+ reg := registry.GetGlobalRegistry()
95
+ reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
96
+ {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
97
+ })
98
+ defer reg.UnregisterClient("test-client2")
99
+
100
+ mappings := []config.AmpModelMapping{
101
+ {From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
102
+ }
103
+
104
+ mapper := NewModelMapper(mappings)
105
+
106
+ // Should match case-insensitively
107
+ result := mapper.MapModel("claude-opus-4.5")
108
+ if result != "claude-sonnet-4" {
109
+ t.Errorf("Expected claude-sonnet-4, got %s", result)
110
+ }
111
+ }
112
+
113
+ func TestModelMapper_MapModel_NotFound(t *testing.T) {
114
+ mappings := []config.AmpModelMapping{
115
+ {From: "claude-opus-4.5", To: "claude-sonnet-4"},
116
+ }
117
+
118
+ mapper := NewModelMapper(mappings)
119
+
120
+ // Unknown model should return empty
121
+ result := mapper.MapModel("unknown-model")
122
+ if result != "" {
123
+ t.Errorf("Expected empty for unknown model, got %s", result)
124
+ }
125
+ }
126
+
127
+ func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
128
+ mappings := []config.AmpModelMapping{
129
+ {From: "claude-opus-4.5", To: "claude-sonnet-4"},
130
+ }
131
+
132
+ mapper := NewModelMapper(mappings)
133
+
134
+ result := mapper.MapModel("")
135
+ if result != "" {
136
+ t.Errorf("Expected empty for empty input, got %s", result)
137
+ }
138
+ }
139
+
140
+ func TestModelMapper_UpdateMappings(t *testing.T) {
141
+ mapper := NewModelMapper(nil)
142
+
143
+ // Initially empty
144
+ if len(mapper.GetMappings()) != 0 {
145
+ t.Error("Expected 0 initial mappings")
146
+ }
147
+
148
+ // Update with new mappings
149
+ mapper.UpdateMappings([]config.AmpModelMapping{
150
+ {From: "model-a", To: "model-b"},
151
+ {From: "model-c", To: "model-d"},
152
+ })
153
+
154
+ result := mapper.GetMappings()
155
+ if len(result) != 2 {
156
+ t.Errorf("Expected 2 mappings after update, got %d", len(result))
157
+ }
158
+
159
+ // Update again should replace, not append
160
+ mapper.UpdateMappings([]config.AmpModelMapping{
161
+ {From: "model-x", To: "model-y"},
162
+ })
163
+
164
+ result = mapper.GetMappings()
165
+ if len(result) != 1 {
166
+ t.Errorf("Expected 1 mapping after second update, got %d", len(result))
167
+ }
168
+ }
169
+
170
+ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
171
+ mapper := NewModelMapper(nil)
172
+
173
+ mapper.UpdateMappings([]config.AmpModelMapping{
174
+ {From: "", To: "model-b"}, // Invalid: empty from
175
+ {From: "model-a", To: ""}, // Invalid: empty to
176
+ {From: " ", To: "model-b"}, // Invalid: whitespace from
177
+ {From: "model-c", To: "model-d"}, // Valid
178
+ })
179
+
180
+ result := mapper.GetMappings()
181
+ if len(result) != 1 {
182
+ t.Errorf("Expected 1 valid mapping, got %d", len(result))
183
+ }
184
+ }
185
+
186
+ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
187
+ mappings := []config.AmpModelMapping{
188
+ {From: "model-a", To: "model-b"},
189
+ }
190
+
191
+ mapper := NewModelMapper(mappings)
192
+
193
+ // Get mappings and modify the returned map
194
+ result := mapper.GetMappings()
195
+ result["new-key"] = "new-value"
196
+
197
+ // Original should be unchanged
198
+ original := mapper.GetMappings()
199
+ if len(original) != 1 {
200
+ t.Errorf("Expected original to have 1 mapping, got %d", len(original))
201
+ }
202
+ if _, exists := original["new-key"]; exists {
203
+ t.Error("Original map was modified")
204
+ }
205
+ }
206
+
207
+ func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
208
+ reg := registry.GetGlobalRegistry()
209
+ reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
210
+ {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
211
+ })
212
+ defer reg.UnregisterClient("test-client-regex-1")
213
+
214
+ mappings := []config.AmpModelMapping{
215
+ {From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
216
+ }
217
+
218
+ mapper := NewModelMapper(mappings)
219
+
220
+ // Incoming model has reasoning suffix but should match base via regex
221
+ result := mapper.MapModel("gpt-5(high)")
222
+ if result != "gemini-2.5-pro" {
223
+ t.Errorf("Expected gemini-2.5-pro, got %s", result)
224
+ }
225
+ }
226
+
227
+ func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
228
+ reg := registry.GetGlobalRegistry()
229
+ reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
230
+ {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
231
+ })
232
+ reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
233
+ {ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
234
+ })
235
+ defer reg.UnregisterClient("test-client-regex-2")
236
+ defer reg.UnregisterClient("test-client-regex-3")
237
+
238
+ mappings := []config.AmpModelMapping{
239
+ {From: "gpt-5", To: "claude-sonnet-4"}, // exact
240
+ {From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
241
+ }
242
+
243
+ mapper := NewModelMapper(mappings)
244
+
245
+ // Exact match should win over regex
246
+ result := mapper.MapModel("gpt-5")
247
+ if result != "claude-sonnet-4" {
248
+ t.Errorf("Expected claude-sonnet-4, got %s", result)
249
+ }
250
+ }
251
+
252
+ func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
253
+ // Invalid regex should be skipped and not cause panic
254
+ mappings := []config.AmpModelMapping{
255
+ {From: "(", To: "target", Regex: true},
256
+ }
257
+
258
+ mapper := NewModelMapper(mappings)
259
+
260
+ result := mapper.MapModel("anything")
261
+ if result != "" {
262
+ t.Errorf("Expected empty result due to invalid regex, got %s", result)
263
+ }
264
+ }
265
+
266
+ func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
267
+ reg := registry.GetGlobalRegistry()
268
+ reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
269
+ {ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
270
+ })
271
+ defer reg.UnregisterClient("test-client-regex-4")
272
+
273
+ mappings := []config.AmpModelMapping{
274
+ {From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
275
+ }
276
+
277
+ mapper := NewModelMapper(mappings)
278
+
279
+ result := mapper.MapModel("claude-opus-4.5")
280
+ if result != "claude-sonnet-4" {
281
+ t.Errorf("Expected claude-sonnet-4, got %s", result)
282
+ }
283
+ }
internal/api/modules/amp/proxy.go ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "bytes"
5
+ "compress/gzip"
6
+ "context"
7
+ "errors"
8
+ "fmt"
9
+ "io"
10
+ "net"
11
+ "net/http"
12
+ "net/http/httputil"
13
+ "net/url"
14
+ "strconv"
15
+ "strings"
16
+
17
+ "github.com/gin-gonic/gin"
18
+ log "github.com/sirupsen/logrus"
19
+ )
20
+
21
+ func removeQueryValuesMatching(req *http.Request, key string, match string) {
22
+ if req == nil || req.URL == nil || match == "" {
23
+ return
24
+ }
25
+
26
+ q := req.URL.Query()
27
+ values, ok := q[key]
28
+ if !ok || len(values) == 0 {
29
+ return
30
+ }
31
+
32
+ kept := make([]string, 0, len(values))
33
+ for _, v := range values {
34
+ if v == match {
35
+ continue
36
+ }
37
+ kept = append(kept, v)
38
+ }
39
+
40
+ if len(kept) == 0 {
41
+ q.Del(key)
42
+ } else {
43
+ q[key] = kept
44
+ }
45
+ req.URL.RawQuery = q.Encode()
46
+ }
47
+
48
+ // readCloser wraps a reader and forwards Close to a separate closer.
49
+ // Used to restore peeked bytes while preserving upstream body Close behavior.
50
+ type readCloser struct {
51
+ r io.Reader
52
+ c io.Closer
53
+ }
54
+
55
+ func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
56
+ func (rc *readCloser) Close() error { return rc.c.Close() }
57
+
58
+ // createReverseProxy creates a reverse proxy handler for Amp upstream
59
+ // with automatic gzip decompression via ModifyResponse
60
+ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) {
61
+ parsed, err := url.Parse(upstreamURL)
62
+ if err != nil {
63
+ return nil, fmt.Errorf("invalid amp upstream url: %w", err)
64
+ }
65
+
66
+ proxy := httputil.NewSingleHostReverseProxy(parsed)
67
+ originalDirector := proxy.Director
68
+
69
+ // Modify outgoing requests to inject API key and fix routing
70
+ proxy.Director = func(req *http.Request) {
71
+ originalDirector(req)
72
+ req.Host = parsed.Host
73
+
74
+ // Remove client's Authorization header - it was only used for CLI Proxy API authentication
75
+ // We will set our own Authorization using the configured upstream-api-key
76
+ req.Header.Del("Authorization")
77
+ req.Header.Del("X-Api-Key")
78
+ req.Header.Del("X-Goog-Api-Key")
79
+
80
+ // Remove query-based credentials if they match the authenticated client API key.
81
+ // This prevents leaking client auth material to the Amp upstream while avoiding
82
+ // breaking unrelated upstream query parameters.
83
+ clientKey := getClientAPIKeyFromContext(req.Context())
84
+ removeQueryValuesMatching(req, "key", clientKey)
85
+ removeQueryValuesMatching(req, "auth_token", clientKey)
86
+
87
+ // Preserve correlation headers for debugging
88
+ if req.Header.Get("X-Request-ID") == "" {
89
+ // Could generate one here if needed
90
+ }
91
+
92
+ // Note: We do NOT filter Anthropic-Beta headers in the proxy path
93
+ // Users going through ampcode.com proxy are paying for the service and should get all features
94
+ // including 1M context window (context-1m-2025-08-07)
95
+
96
+ // Inject API key from secret source (only uses upstream-api-key from config)
97
+ if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
98
+ req.Header.Set("X-Api-Key", key)
99
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
100
+ } else if err != nil {
101
+ log.Warnf("amp secret source error (continuing without auth): %v", err)
102
+ }
103
+ }
104
+
105
+ // Modify incoming responses to handle gzip without Content-Encoding
106
+ // This addresses the same issue as inline handler gzip handling, but at the proxy level
107
+ proxy.ModifyResponse = func(resp *http.Response) error {
108
+ // Log upstream error responses for diagnostics (502, 503, etc.)
109
+ // These are NOT proxy connection errors - the upstream responded with an error status
110
+ if resp.StatusCode >= 500 {
111
+ log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
112
+ } else if resp.StatusCode >= 400 {
113
+ log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
114
+ }
115
+
116
+ // Only process successful responses for gzip decompression
117
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
118
+ return nil
119
+ }
120
+
121
+ // Skip if already marked as gzip (Content-Encoding set)
122
+ if resp.Header.Get("Content-Encoding") != "" {
123
+ return nil
124
+ }
125
+
126
+ // Skip streaming responses (SSE, chunked)
127
+ if isStreamingResponse(resp) {
128
+ return nil
129
+ }
130
+
131
+ // Save reference to original upstream body for proper cleanup
132
+ originalBody := resp.Body
133
+
134
+ // Peek at first 2 bytes to detect gzip magic bytes
135
+ header := make([]byte, 2)
136
+ n, _ := io.ReadFull(originalBody, header)
137
+
138
+ // Check for gzip magic bytes (0x1f 0x8b)
139
+ // If n < 2, we didn't get enough bytes, so it's not gzip
140
+ if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
141
+ // It's gzip - read the rest of the body
142
+ rest, err := io.ReadAll(originalBody)
143
+ if err != nil {
144
+ // Restore what we read and return original body (preserve Close behavior)
145
+ resp.Body = &readCloser{
146
+ r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
147
+ c: originalBody,
148
+ }
149
+ return nil
150
+ }
151
+
152
+ // Reconstruct complete gzipped data
153
+ gzippedData := append(header[:n], rest...)
154
+
155
+ // Decompress
156
+ gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData))
157
+ if err != nil {
158
+ log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err)
159
+ // Close original body and return in-memory copy
160
+ _ = originalBody.Close()
161
+ resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
162
+ return nil
163
+ }
164
+
165
+ decompressed, err := io.ReadAll(gzipReader)
166
+ _ = gzipReader.Close()
167
+ if err != nil {
168
+ log.Warnf("amp proxy: gzip decompress error: %v", err)
169
+ // Close original body and return in-memory copy
170
+ _ = originalBody.Close()
171
+ resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
172
+ return nil
173
+ }
174
+
175
+ // Close original body since we're replacing with in-memory decompressed content
176
+ _ = originalBody.Close()
177
+
178
+ // Replace body with decompressed content
179
+ resp.Body = io.NopCloser(bytes.NewReader(decompressed))
180
+ resp.ContentLength = int64(len(decompressed))
181
+
182
+ // Update headers to reflect decompressed state
183
+ resp.Header.Del("Content-Encoding") // No longer compressed
184
+ resp.Header.Del("Content-Length") // Remove stale compressed length
185
+ resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
186
+
187
+ log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
188
+ } else {
189
+ // Not gzip - restore peeked bytes while preserving Close behavior
190
+ // Handle edge cases: n might be 0, 1, or 2 depending on EOF
191
+ resp.Body = &readCloser{
192
+ r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
193
+ c: originalBody,
194
+ }
195
+ }
196
+
197
+ return nil
198
+ }
199
+
200
+ // Error handler for proxy failures with detailed error classification for diagnostics
201
+ proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
202
+ // Classify the error type for better diagnostics
203
+ var errType string
204
+ if errors.Is(err, context.DeadlineExceeded) {
205
+ errType = "timeout"
206
+ } else if errors.Is(err, context.Canceled) {
207
+ errType = "canceled"
208
+ } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
209
+ errType = "dial_timeout"
210
+ } else if _, ok := err.(net.Error); ok {
211
+ errType = "network_error"
212
+ } else {
213
+ errType = "connection_error"
214
+ }
215
+
216
+ // Don't log as error for context canceled - it's usually client closing connection
217
+ if errors.Is(err, context.Canceled) {
218
+ log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
219
+ } else {
220
+ log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
221
+ }
222
+
223
+ rw.Header().Set("Content-Type", "application/json")
224
+ rw.WriteHeader(http.StatusBadGateway)
225
+ _, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
226
+ }
227
+
228
+ return proxy, nil
229
+ }
230
+
231
+ // isStreamingResponse detects if the response is streaming (SSE only)
232
+ // Note: We only treat text/event-stream as streaming. Chunked transfer encoding
233
+ // is a transport-level detail and doesn't mean we can't decompress the full response.
234
+ // Many JSON APIs use chunked encoding for normal responses.
235
+ func isStreamingResponse(resp *http.Response) bool {
236
+ contentType := resp.Header.Get("Content-Type")
237
+
238
+ // Only Server-Sent Events are true streaming responses
239
+ if strings.Contains(contentType, "text/event-stream") {
240
+ return true
241
+ }
242
+
243
+ return false
244
+ }
245
+
246
+ // proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc
247
+ func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc {
248
+ return func(c *gin.Context) {
249
+ proxy.ServeHTTP(c.Writer, c.Request)
250
+ }
251
+ }
252
+
253
+ // filterBetaFeatures removes a specific beta feature from comma-separated list
254
+ func filterBetaFeatures(header, featureToRemove string) string {
255
+ features := strings.Split(header, ",")
256
+ filtered := make([]string, 0, len(features))
257
+
258
+ for _, feature := range features {
259
+ trimmed := strings.TrimSpace(feature)
260
+ if trimmed != "" && trimmed != featureToRemove {
261
+ filtered = append(filtered, trimmed)
262
+ }
263
+ }
264
+
265
+ return strings.Join(filtered, ",")
266
+ }
internal/api/modules/amp/proxy_test.go ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "bytes"
5
+ "compress/gzip"
6
+ "context"
7
+ "fmt"
8
+ "io"
9
+ "net/http"
10
+ "net/http/httptest"
11
+ "strings"
12
+ "testing"
13
+
14
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
15
+ )
16
+
17
+ // Helper: compress data with gzip
18
+ func gzipBytes(b []byte) []byte {
19
+ var buf bytes.Buffer
20
+ zw := gzip.NewWriter(&buf)
21
+ zw.Write(b)
22
+ zw.Close()
23
+ return buf.Bytes()
24
+ }
25
+
26
+ // Helper: create a mock http.Response
27
+ func mkResp(status int, hdr http.Header, body []byte) *http.Response {
28
+ if hdr == nil {
29
+ hdr = http.Header{}
30
+ }
31
+ return &http.Response{
32
+ StatusCode: status,
33
+ Header: hdr,
34
+ Body: io.NopCloser(bytes.NewReader(body)),
35
+ ContentLength: int64(len(body)),
36
+ }
37
+ }
38
+
39
+ func TestCreateReverseProxy_ValidURL(t *testing.T) {
40
+ proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
41
+ if err != nil {
42
+ t.Fatalf("expected no error, got: %v", err)
43
+ }
44
+ if proxy == nil {
45
+ t.Fatal("expected proxy to be created")
46
+ }
47
+ }
48
+
49
+ func TestCreateReverseProxy_InvalidURL(t *testing.T) {
50
+ _, err := createReverseProxy("://invalid", NewStaticSecretSource("key"))
51
+ if err == nil {
52
+ t.Fatal("expected error for invalid URL")
53
+ }
54
+ }
55
+
56
+ func TestModifyResponse_GzipScenarios(t *testing.T) {
57
+ proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
58
+ if err != nil {
59
+ t.Fatal(err)
60
+ }
61
+
62
+ goodJSON := []byte(`{"ok":true}`)
63
+ good := gzipBytes(goodJSON)
64
+ truncated := good[:10]
65
+ corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
66
+
67
+ cases := []struct {
68
+ name string
69
+ header http.Header
70
+ body []byte
71
+ status int
72
+ wantBody []byte
73
+ wantCE string
74
+ }{
75
+ {
76
+ name: "decompresses_valid_gzip_no_header",
77
+ header: http.Header{},
78
+ body: good,
79
+ status: 200,
80
+ wantBody: goodJSON,
81
+ wantCE: "",
82
+ },
83
+ {
84
+ name: "skips_when_ce_present",
85
+ header: http.Header{"Content-Encoding": []string{"gzip"}},
86
+ body: good,
87
+ status: 200,
88
+ wantBody: good,
89
+ wantCE: "gzip",
90
+ },
91
+ {
92
+ name: "passes_truncated_unchanged",
93
+ header: http.Header{},
94
+ body: truncated,
95
+ status: 200,
96
+ wantBody: truncated,
97
+ wantCE: "",
98
+ },
99
+ {
100
+ name: "passes_corrupted_unchanged",
101
+ header: http.Header{},
102
+ body: corrupted,
103
+ status: 200,
104
+ wantBody: corrupted,
105
+ wantCE: "",
106
+ },
107
+ {
108
+ name: "non_gzip_unchanged",
109
+ header: http.Header{},
110
+ body: []byte("plain"),
111
+ status: 200,
112
+ wantBody: []byte("plain"),
113
+ wantCE: "",
114
+ },
115
+ {
116
+ name: "empty_body",
117
+ header: http.Header{},
118
+ body: []byte{},
119
+ status: 200,
120
+ wantBody: []byte{},
121
+ wantCE: "",
122
+ },
123
+ {
124
+ name: "single_byte_body",
125
+ header: http.Header{},
126
+ body: []byte{0x1f},
127
+ status: 200,
128
+ wantBody: []byte{0x1f},
129
+ wantCE: "",
130
+ },
131
+ {
132
+ name: "skips_non_2xx_status",
133
+ header: http.Header{},
134
+ body: good,
135
+ status: 404,
136
+ wantBody: good,
137
+ wantCE: "",
138
+ },
139
+ }
140
+
141
+ for _, tc := range cases {
142
+ t.Run(tc.name, func(t *testing.T) {
143
+ resp := mkResp(tc.status, tc.header, tc.body)
144
+ if err := proxy.ModifyResponse(resp); err != nil {
145
+ t.Fatalf("ModifyResponse error: %v", err)
146
+ }
147
+ got, err := io.ReadAll(resp.Body)
148
+ if err != nil {
149
+ t.Fatalf("ReadAll error: %v", err)
150
+ }
151
+ if !bytes.Equal(got, tc.wantBody) {
152
+ t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
153
+ }
154
+ if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
155
+ t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
156
+ }
157
+ })
158
+ }
159
+ }
160
+
161
+ func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
162
+ proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
163
+ if err != nil {
164
+ t.Fatal(err)
165
+ }
166
+
167
+ goodJSON := []byte(`{"message":"test response"}`)
168
+ gzipped := gzipBytes(goodJSON)
169
+
170
+ // Simulate upstream response with gzip body AND Content-Length header
171
+ // (this is the scenario the bot flagged - stale Content-Length after decompression)
172
+ resp := mkResp(200, http.Header{
173
+ "Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
174
+ }, gzipped)
175
+
176
+ if err := proxy.ModifyResponse(resp); err != nil {
177
+ t.Fatalf("ModifyResponse error: %v", err)
178
+ }
179
+
180
+ // Verify body is decompressed
181
+ got, _ := io.ReadAll(resp.Body)
182
+ if !bytes.Equal(got, goodJSON) {
183
+ t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
184
+ }
185
+
186
+ // Verify Content-Length header is updated to decompressed size
187
+ wantCL := fmt.Sprintf("%d", len(goodJSON))
188
+ gotCL := resp.Header.Get("Content-Length")
189
+ if gotCL != wantCL {
190
+ t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
191
+ }
192
+
193
+ // Verify struct field also matches
194
+ if resp.ContentLength != int64(len(goodJSON)) {
195
+ t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
196
+ }
197
+ }
198
+
199
+ func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
200
+ proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
201
+ if err != nil {
202
+ t.Fatal(err)
203
+ }
204
+
205
+ goodJSON := []byte(`{"ok":true}`)
206
+ gzipped := gzipBytes(goodJSON)
207
+
208
+ t.Run("sse_skips_decompression", func(t *testing.T) {
209
+ resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
210
+ if err := proxy.ModifyResponse(resp); err != nil {
211
+ t.Fatalf("ModifyResponse error: %v", err)
212
+ }
213
+ // SSE should NOT be decompressed
214
+ got, _ := io.ReadAll(resp.Body)
215
+ if !bytes.Equal(got, gzipped) {
216
+ t.Fatal("SSE response should not be decompressed")
217
+ }
218
+ })
219
+ }
220
+
221
+ func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
222
+ proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
223
+ if err != nil {
224
+ t.Fatal(err)
225
+ }
226
+
227
+ goodJSON := []byte(`{"ok":true}`)
228
+ gzipped := gzipBytes(goodJSON)
229
+
230
+ t.Run("chunked_json_decompresses", func(t *testing.T) {
231
+ // Chunked JSON responses (like thread APIs) should be decompressed
232
+ resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
233
+ if err := proxy.ModifyResponse(resp); err != nil {
234
+ t.Fatalf("ModifyResponse error: %v", err)
235
+ }
236
+ // Should decompress because it's not SSE
237
+ got, _ := io.ReadAll(resp.Body)
238
+ if !bytes.Equal(got, goodJSON) {
239
+ t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
240
+ }
241
+ })
242
+ }
243
+
244
+ func TestReverseProxy_InjectsHeaders(t *testing.T) {
245
+ gotHeaders := make(chan http.Header, 1)
246
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
247
+ gotHeaders <- r.Header.Clone()
248
+ w.WriteHeader(200)
249
+ w.Write([]byte(`ok`))
250
+ }))
251
+ defer upstream.Close()
252
+
253
+ proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
254
+ if err != nil {
255
+ t.Fatal(err)
256
+ }
257
+
258
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
259
+ proxy.ServeHTTP(w, r)
260
+ }))
261
+ defer srv.Close()
262
+
263
+ res, err := http.Get(srv.URL + "/test")
264
+ if err != nil {
265
+ t.Fatal(err)
266
+ }
267
+ res.Body.Close()
268
+
269
+ hdr := <-gotHeaders
270
+ if hdr.Get("X-Api-Key") != "secret" {
271
+ t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
272
+ }
273
+ if hdr.Get("Authorization") != "Bearer secret" {
274
+ t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
275
+ }
276
+ }
277
+
278
+ func TestReverseProxy_EmptySecret(t *testing.T) {
279
+ gotHeaders := make(chan http.Header, 1)
280
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281
+ gotHeaders <- r.Header.Clone()
282
+ w.WriteHeader(200)
283
+ w.Write([]byte(`ok`))
284
+ }))
285
+ defer upstream.Close()
286
+
287
+ proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
288
+ if err != nil {
289
+ t.Fatal(err)
290
+ }
291
+
292
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293
+ proxy.ServeHTTP(w, r)
294
+ }))
295
+ defer srv.Close()
296
+
297
+ res, err := http.Get(srv.URL + "/test")
298
+ if err != nil {
299
+ t.Fatal(err)
300
+ }
301
+ res.Body.Close()
302
+
303
+ hdr := <-gotHeaders
304
+ // Should NOT inject headers when secret is empty
305
+ if hdr.Get("X-Api-Key") != "" {
306
+ t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
307
+ }
308
+ if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
309
+ t.Fatalf("Authorization should not be set, got: %q", authVal)
310
+ }
311
+ }
312
+
313
+ func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
314
+ type captured struct {
315
+ headers http.Header
316
+ query string
317
+ }
318
+ got := make(chan captured, 1)
319
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
320
+ got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
321
+ w.WriteHeader(200)
322
+ w.Write([]byte(`ok`))
323
+ }))
324
+ defer upstream.Close()
325
+
326
+ proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
327
+ if err != nil {
328
+ t.Fatal(err)
329
+ }
330
+
331
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
332
+ // Simulate clientAPIKeyMiddleware injection (per-request)
333
+ ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
334
+ proxy.ServeHTTP(w, r.WithContext(ctx))
335
+ }))
336
+ defer srv.Close()
337
+
338
+ req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
339
+ if err != nil {
340
+ t.Fatal(err)
341
+ }
342
+ req.Header.Set("Authorization", "Bearer client-key")
343
+ req.Header.Set("X-Api-Key", "client-key")
344
+ req.Header.Set("X-Goog-Api-Key", "client-key")
345
+
346
+ res, err := http.DefaultClient.Do(req)
347
+ if err != nil {
348
+ t.Fatal(err)
349
+ }
350
+ res.Body.Close()
351
+
352
+ c := <-got
353
+
354
+ // These are client-provided credentials and must not reach the upstream.
355
+ if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
356
+ t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
357
+ }
358
+
359
+ // We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
360
+ if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
361
+ t.Fatalf("Authorization should be upstream-injected, got: %q", v)
362
+ }
363
+ if v := c.headers.Get("X-Api-Key"); v != "upstream" {
364
+ t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
365
+ }
366
+
367
+ // Query-based credentials should be stripped only when they match the authenticated client key.
368
+ // Should keep unrelated values and parameters.
369
+ if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
370
+ t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
371
+ }
372
+ if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
373
+ t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
374
+ }
375
+ }
376
+
377
+ func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
378
+ gotHeaders := make(chan http.Header, 1)
379
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
380
+ gotHeaders <- r.Header.Clone()
381
+ w.WriteHeader(200)
382
+ w.Write([]byte(`ok`))
383
+ }))
384
+ defer upstream.Close()
385
+
386
+ defaultSource := NewStaticSecretSource("default")
387
+ mapped := NewMappedSecretSource(defaultSource)
388
+ mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
389
+ {
390
+ UpstreamAPIKey: "u1",
391
+ APIKeys: []string{"k1"},
392
+ },
393
+ })
394
+
395
+ proxy, err := createReverseProxy(upstream.URL, mapped)
396
+ if err != nil {
397
+ t.Fatal(err)
398
+ }
399
+
400
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
401
+ // Simulate clientAPIKeyMiddleware injection (per-request)
402
+ ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
403
+ proxy.ServeHTTP(w, r.WithContext(ctx))
404
+ }))
405
+ defer srv.Close()
406
+
407
+ res, err := http.Get(srv.URL + "/test")
408
+ if err != nil {
409
+ t.Fatal(err)
410
+ }
411
+ res.Body.Close()
412
+
413
+ hdr := <-gotHeaders
414
+ if hdr.Get("X-Api-Key") != "u1" {
415
+ t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
416
+ }
417
+ if hdr.Get("Authorization") != "Bearer u1" {
418
+ t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
419
+ }
420
+ }
421
+
422
+ func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
423
+ gotHeaders := make(chan http.Header, 1)
424
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
425
+ gotHeaders <- r.Header.Clone()
426
+ w.WriteHeader(200)
427
+ w.Write([]byte(`ok`))
428
+ }))
429
+ defer upstream.Close()
430
+
431
+ defaultSource := NewStaticSecretSource("default")
432
+ mapped := NewMappedSecretSource(defaultSource)
433
+ mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
434
+ {
435
+ UpstreamAPIKey: "u1",
436
+ APIKeys: []string{"k1"},
437
+ },
438
+ })
439
+
440
+ proxy, err := createReverseProxy(upstream.URL, mapped)
441
+ if err != nil {
442
+ t.Fatal(err)
443
+ }
444
+
445
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
446
+ ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
447
+ proxy.ServeHTTP(w, r.WithContext(ctx))
448
+ }))
449
+ defer srv.Close()
450
+
451
+ res, err := http.Get(srv.URL + "/test")
452
+ if err != nil {
453
+ t.Fatal(err)
454
+ }
455
+ res.Body.Close()
456
+
457
+ hdr := <-gotHeaders
458
+ if hdr.Get("X-Api-Key") != "default" {
459
+ t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
460
+ }
461
+ if hdr.Get("Authorization") != "Bearer default" {
462
+ t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
463
+ }
464
+ }
465
+
466
+ func TestReverseProxy_ErrorHandler(t *testing.T) {
467
+ // Point proxy to a non-routable address to trigger error
468
+ proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
469
+ if err != nil {
470
+ t.Fatal(err)
471
+ }
472
+
473
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
474
+ proxy.ServeHTTP(w, r)
475
+ }))
476
+ defer srv.Close()
477
+
478
+ res, err := http.Get(srv.URL + "/any")
479
+ if err != nil {
480
+ t.Fatal(err)
481
+ }
482
+ body, _ := io.ReadAll(res.Body)
483
+ res.Body.Close()
484
+
485
+ if res.StatusCode != http.StatusBadGateway {
486
+ t.Fatalf("want 502, got %d", res.StatusCode)
487
+ }
488
+ if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
489
+ t.Fatalf("unexpected body: %s", body)
490
+ }
491
+ if ct := res.Header.Get("Content-Type"); ct != "application/json" {
492
+ t.Fatalf("content-type: want application/json, got %s", ct)
493
+ }
494
+ }
495
+
496
+ func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
497
+ // Upstream returns gzipped JSON without Content-Encoding header
498
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
499
+ w.WriteHeader(200)
500
+ w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
501
+ }))
502
+ defer upstream.Close()
503
+
504
+ proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
505
+ if err != nil {
506
+ t.Fatal(err)
507
+ }
508
+
509
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
510
+ proxy.ServeHTTP(w, r)
511
+ }))
512
+ defer srv.Close()
513
+
514
+ res, err := http.Get(srv.URL + "/test")
515
+ if err != nil {
516
+ t.Fatal(err)
517
+ }
518
+ body, _ := io.ReadAll(res.Body)
519
+ res.Body.Close()
520
+
521
+ expected := []byte(`{"upstream":"ok"}`)
522
+ if !bytes.Equal(body, expected) {
523
+ t.Fatalf("want decompressed JSON, got: %s", body)
524
+ }
525
+ }
526
+
527
+ func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
528
+ // Upstream returns plain JSON
529
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
530
+ w.Header().Set("Content-Type", "application/json")
531
+ w.WriteHeader(200)
532
+ w.Write([]byte(`{"plain":"json"}`))
533
+ }))
534
+ defer upstream.Close()
535
+
536
+ proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
537
+ if err != nil {
538
+ t.Fatal(err)
539
+ }
540
+
541
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
542
+ proxy.ServeHTTP(w, r)
543
+ }))
544
+ defer srv.Close()
545
+
546
+ res, err := http.Get(srv.URL + "/test")
547
+ if err != nil {
548
+ t.Fatal(err)
549
+ }
550
+ body, _ := io.ReadAll(res.Body)
551
+ res.Body.Close()
552
+
553
+ expected := []byte(`{"plain":"json"}`)
554
+ if !bytes.Equal(body, expected) {
555
+ t.Fatalf("want plain JSON unchanged, got: %s", body)
556
+ }
557
+ }
558
+
559
+ func TestIsStreamingResponse(t *testing.T) {
560
+ cases := []struct {
561
+ name string
562
+ header http.Header
563
+ want bool
564
+ }{
565
+ {
566
+ name: "sse",
567
+ header: http.Header{"Content-Type": []string{"text/event-stream"}},
568
+ want: true,
569
+ },
570
+ {
571
+ name: "chunked_not_streaming",
572
+ header: http.Header{"Transfer-Encoding": []string{"chunked"}},
573
+ want: false, // Chunked is transport-level, not streaming
574
+ },
575
+ {
576
+ name: "normal_json",
577
+ header: http.Header{"Content-Type": []string{"application/json"}},
578
+ want: false,
579
+ },
580
+ {
581
+ name: "empty",
582
+ header: http.Header{},
583
+ want: false,
584
+ },
585
+ }
586
+
587
+ for _, tc := range cases {
588
+ t.Run(tc.name, func(t *testing.T) {
589
+ resp := &http.Response{Header: tc.header}
590
+ got := isStreamingResponse(resp)
591
+ if got != tc.want {
592
+ t.Fatalf("want %v, got %v", tc.want, got)
593
+ }
594
+ })
595
+ }
596
+ }
597
+
598
+ func TestFilterBetaFeatures(t *testing.T) {
599
+ tests := []struct {
600
+ name string
601
+ header string
602
+ featureToRemove string
603
+ expected string
604
+ }{
605
+ {
606
+ name: "Remove context-1m from middle",
607
+ header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
608
+ featureToRemove: "context-1m-2025-08-07",
609
+ expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
610
+ },
611
+ {
612
+ name: "Remove context-1m from start",
613
+ header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
614
+ featureToRemove: "context-1m-2025-08-07",
615
+ expected: "fine-grained-tool-streaming-2025-05-14",
616
+ },
617
+ {
618
+ name: "Remove context-1m from end",
619
+ header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
620
+ featureToRemove: "context-1m-2025-08-07",
621
+ expected: "fine-grained-tool-streaming-2025-05-14",
622
+ },
623
+ {
624
+ name: "Feature not present",
625
+ header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
626
+ featureToRemove: "context-1m-2025-08-07",
627
+ expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
628
+ },
629
+ {
630
+ name: "Only feature to remove",
631
+ header: "context-1m-2025-08-07",
632
+ featureToRemove: "context-1m-2025-08-07",
633
+ expected: "",
634
+ },
635
+ {
636
+ name: "Empty header",
637
+ header: "",
638
+ featureToRemove: "context-1m-2025-08-07",
639
+ expected: "",
640
+ },
641
+ {
642
+ name: "Header with spaces",
643
+ header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
644
+ featureToRemove: "context-1m-2025-08-07",
645
+ expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
646
+ },
647
+ }
648
+
649
+ for _, tt := range tests {
650
+ t.Run(tt.name, func(t *testing.T) {
651
+ result := filterBetaFeatures(tt.header, tt.featureToRemove)
652
+ if result != tt.expected {
653
+ t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected)
654
+ }
655
+ })
656
+ }
657
+ }
internal/api/modules/amp/response_rewriter.go ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "bytes"
5
+ "net/http"
6
+ "strings"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ log "github.com/sirupsen/logrus"
10
+ "github.com/tidwall/gjson"
11
+ "github.com/tidwall/sjson"
12
+ )
13
+
14
+ // ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
15
+ // It's used to rewrite model names in responses when model mapping is used
16
+ type ResponseRewriter struct {
17
+ gin.ResponseWriter
18
+ body *bytes.Buffer
19
+ originalModel string
20
+ isStreaming bool
21
+ }
22
+
23
+ // NewResponseRewriter creates a new response rewriter for model name substitution
24
+ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
25
+ return &ResponseRewriter{
26
+ ResponseWriter: w,
27
+ body: &bytes.Buffer{},
28
+ originalModel: originalModel,
29
+ }
30
+ }
31
+
32
+ const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
33
+
34
+ func looksLikeSSEChunk(data []byte) bool {
35
+ // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
36
+ // Heuristics are intentionally simple and cheap.
37
+ return bytes.Contains(data, []byte("data:")) ||
38
+ bytes.Contains(data, []byte("event:")) ||
39
+ bytes.Contains(data, []byte("message_start")) ||
40
+ bytes.Contains(data, []byte("message_delta")) ||
41
+ bytes.Contains(data, []byte("content_block_start")) ||
42
+ bytes.Contains(data, []byte("content_block_delta")) ||
43
+ bytes.Contains(data, []byte("content_block_stop")) ||
44
+ bytes.Contains(data, []byte("\n\n"))
45
+ }
46
+
47
+ func (rw *ResponseRewriter) enableStreaming(reason string) error {
48
+ if rw.isStreaming {
49
+ return nil
50
+ }
51
+ rw.isStreaming = true
52
+
53
+ // Flush any previously buffered data to avoid reordering or data loss.
54
+ if rw.body != nil && rw.body.Len() > 0 {
55
+ buf := rw.body.Bytes()
56
+ // Copy before Reset() to keep bytes stable.
57
+ toFlush := make([]byte, len(buf))
58
+ copy(toFlush, buf)
59
+ rw.body.Reset()
60
+
61
+ if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
62
+ return err
63
+ }
64
+ if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
65
+ flusher.Flush()
66
+ }
67
+ }
68
+
69
+ log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
70
+ return nil
71
+ }
72
+
73
+ // Write intercepts response writes and buffers them for model name replacement
74
+ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
75
+ // Detect streaming on first write (header-based)
76
+ if !rw.isStreaming && rw.body.Len() == 0 {
77
+ contentType := rw.Header().Get("Content-Type")
78
+ rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
79
+ strings.Contains(contentType, "stream")
80
+ }
81
+
82
+ if !rw.isStreaming {
83
+ // Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
84
+ if looksLikeSSEChunk(data) {
85
+ if err := rw.enableStreaming("sse heuristic"); err != nil {
86
+ return 0, err
87
+ }
88
+ } else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
89
+ // Safety cap: avoid unbounded buffering on large responses.
90
+ log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
91
+ if err := rw.enableStreaming("buffer limit"); err != nil {
92
+ return 0, err
93
+ }
94
+ }
95
+ }
96
+
97
+ if rw.isStreaming {
98
+ n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
99
+ if err == nil {
100
+ if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
101
+ flusher.Flush()
102
+ }
103
+ }
104
+ return n, err
105
+ }
106
+ return rw.body.Write(data)
107
+ }
108
+
109
+ // Flush writes the buffered response with model names rewritten
110
+ func (rw *ResponseRewriter) Flush() {
111
+ if rw.isStreaming {
112
+ if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
113
+ flusher.Flush()
114
+ }
115
+ return
116
+ }
117
+ if rw.body.Len() > 0 {
118
+ if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
119
+ log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
120
+ }
121
+ }
122
+ }
123
+
124
+ // modelFieldPaths lists all JSON paths where model name may appear
125
+ var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
126
+
127
+ // rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
128
+ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
129
+ if rw.originalModel == "" {
130
+ return data
131
+ }
132
+ for _, path := range modelFieldPaths {
133
+ if gjson.GetBytes(data, path).Exists() {
134
+ data, _ = sjson.SetBytes(data, path, rw.originalModel)
135
+ }
136
+ }
137
+ return data
138
+ }
139
+
140
+ // rewriteStreamChunk rewrites model names in SSE stream chunks
141
+ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
142
+ if rw.originalModel == "" {
143
+ return chunk
144
+ }
145
+
146
+ // SSE format: "data: {json}\n\n"
147
+ lines := bytes.Split(chunk, []byte("\n"))
148
+ for i, line := range lines {
149
+ if bytes.HasPrefix(line, []byte("data: ")) {
150
+ jsonData := bytes.TrimPrefix(line, []byte("data: "))
151
+ if len(jsonData) > 0 && jsonData[0] == '{' {
152
+ // Rewrite JSON in the data line
153
+ rewritten := rw.rewriteModelInResponse(jsonData)
154
+ lines[i] = append([]byte("data: "), rewritten...)
155
+ }
156
+ }
157
+ }
158
+
159
+ return bytes.Join(lines, []byte("\n"))
160
+ }
internal/api/modules/amp/routes.go ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "context"
5
+ "errors"
6
+ "net"
7
+ "net/http"
8
+ "net/http/httputil"
9
+ "strings"
10
+
11
+ "github.com/gin-gonic/gin"
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
13
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
14
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
15
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
16
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
17
+ log "github.com/sirupsen/logrus"
18
+ )
19
+
20
+ // clientAPIKeyContextKey is the context key used to pass the client API key
21
+ // from gin.Context to the request context for SecretSource lookup.
22
+ type clientAPIKeyContextKey struct{}
23
+
24
+ // clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
25
+ // into the request context so that SecretSource can look it up for per-client upstream routing.
26
+ func clientAPIKeyMiddleware() gin.HandlerFunc {
27
+ return func(c *gin.Context) {
28
+ // Extract the client API key from gin context (set by AuthMiddleware)
29
+ if apiKey, exists := c.Get("apiKey"); exists {
30
+ if keyStr, ok := apiKey.(string); ok && keyStr != "" {
31
+ // Inject into request context for SecretSource.Get(ctx) to read
32
+ ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
33
+ c.Request = c.Request.WithContext(ctx)
34
+ }
35
+ }
36
+ c.Next()
37
+ }
38
+ }
39
+
40
+ // getClientAPIKeyFromContext retrieves the client API key from request context.
41
+ // Returns empty string if not present.
42
+ func getClientAPIKeyFromContext(ctx context.Context) string {
43
+ if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
44
+ if keyStr, ok := val.(string); ok {
45
+ return keyStr
46
+ }
47
+ }
48
+ return ""
49
+ }
50
+
51
+ // localhostOnlyMiddleware returns a middleware that dynamically checks the module's
52
+ // localhost restriction setting. This allows hot-reload of the restriction without restarting.
53
+ func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
54
+ return func(c *gin.Context) {
55
+ // Check current setting (hot-reloadable)
56
+ if !m.IsRestrictedToLocalhost() {
57
+ c.Next()
58
+ return
59
+ }
60
+
61
+ // Use actual TCP connection address (RemoteAddr) to prevent header spoofing
62
+ // This cannot be forged by X-Forwarded-For or other client-controlled headers
63
+ remoteAddr := c.Request.RemoteAddr
64
+
65
+ // RemoteAddr format is "IP:port" or "[IPv6]:port", extract just the IP
66
+ host, _, err := net.SplitHostPort(remoteAddr)
67
+ if err != nil {
68
+ // Try parsing as raw IP (shouldn't happen with standard HTTP, but be defensive)
69
+ host = remoteAddr
70
+ }
71
+
72
+ // Parse the IP to handle both IPv4 and IPv6
73
+ ip := net.ParseIP(host)
74
+ if ip == nil {
75
+ log.Warnf("amp management: invalid RemoteAddr %s, denying access", remoteAddr)
76
+ c.AbortWithStatusJSON(403, gin.H{
77
+ "error": "Access denied: management routes restricted to localhost",
78
+ })
79
+ return
80
+ }
81
+
82
+ // Check if IP is loopback (127.0.0.1 or ::1)
83
+ if !ip.IsLoopback() {
84
+ log.Warnf("amp management: non-localhost connection from %s attempted access, denying", remoteAddr)
85
+ c.AbortWithStatusJSON(403, gin.H{
86
+ "error": "Access denied: management routes restricted to localhost",
87
+ })
88
+ return
89
+ }
90
+
91
+ c.Next()
92
+ }
93
+ }
94
+
95
+ // noCORSMiddleware disables CORS for management routes to prevent browser-based attacks.
96
+ // This overwrites any global CORS headers set by the server.
97
+ func noCORSMiddleware() gin.HandlerFunc {
98
+ return func(c *gin.Context) {
99
+ // Remove CORS headers to prevent cross-origin access from browsers
100
+ c.Header("Access-Control-Allow-Origin", "")
101
+ c.Header("Access-Control-Allow-Methods", "")
102
+ c.Header("Access-Control-Allow-Headers", "")
103
+ c.Header("Access-Control-Allow-Credentials", "")
104
+
105
+ // For OPTIONS preflight, deny with 403
106
+ if c.Request.Method == "OPTIONS" {
107
+ c.AbortWithStatus(403)
108
+ return
109
+ }
110
+
111
+ c.Next()
112
+ }
113
+ }
114
+
115
+ // managementAvailabilityMiddleware short-circuits management routes when the upstream
116
+ // proxy is disabled, preventing noisy localhost warnings and accidental exposure.
117
+ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
118
+ return func(c *gin.Context) {
119
+ if m.getProxy() == nil {
120
+ logging.SkipGinRequestLogging(c)
121
+ c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
122
+ "error": "amp upstream proxy not available",
123
+ })
124
+ return
125
+ }
126
+ c.Next()
127
+ }
128
+ }
129
+
130
+ // wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
131
+ func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
132
+ return func(c *gin.Context) {
133
+ path := c.Request.URL.Path
134
+ for _, prefix := range prefixes {
135
+ if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
136
+ c.Next()
137
+ return
138
+ }
139
+ }
140
+ auth(c)
141
+ }
142
+ }
143
+
144
+ // registerManagementRoutes registers Amp management proxy routes
145
+ // These routes proxy through to the Amp control plane for OAuth, user management, etc.
146
+ // Uses dynamic middleware and proxy getter for hot-reload support.
147
+ // The auth middleware validates Authorization header against configured API keys.
148
+ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
149
+ ampAPI := engine.Group("/api")
150
+
151
+ // Always disable CORS for management routes to prevent browser-based attacks
152
+ ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
153
+
154
+ // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
155
+ ampAPI.Use(m.localhostOnlyMiddleware())
156
+
157
+ // Apply authentication middleware - requires valid API key in Authorization header
158
+ var authWithBypass gin.HandlerFunc
159
+ if auth != nil {
160
+ ampAPI.Use(auth)
161
+ authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
162
+ }
163
+
164
+ // Inject client API key into request context for per-client upstream routing
165
+ ampAPI.Use(clientAPIKeyMiddleware())
166
+
167
+ // Dynamic proxy handler that uses m.getProxy() for hot-reload support
168
+ proxyHandler := func(c *gin.Context) {
169
+ // Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
170
+ defer func() {
171
+ if rec := recover(); rec != nil {
172
+ if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
173
+ // Upstream already wrote the status (often 404) before the client/stream ended.
174
+ return
175
+ }
176
+ panic(rec)
177
+ }
178
+ }()
179
+
180
+ proxy := m.getProxy()
181
+ if proxy == nil {
182
+ c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
183
+ return
184
+ }
185
+ proxy.ServeHTTP(c.Writer, c.Request)
186
+ }
187
+
188
+ // Management routes - these are proxied directly to Amp upstream
189
+ ampAPI.Any("/internal", proxyHandler)
190
+ ampAPI.Any("/internal/*path", proxyHandler)
191
+ ampAPI.Any("/user", proxyHandler)
192
+ ampAPI.Any("/user/*path", proxyHandler)
193
+ ampAPI.Any("/auth", proxyHandler)
194
+ ampAPI.Any("/auth/*path", proxyHandler)
195
+ ampAPI.Any("/meta", proxyHandler)
196
+ ampAPI.Any("/meta/*path", proxyHandler)
197
+ ampAPI.Any("/ads", proxyHandler)
198
+ ampAPI.Any("/telemetry", proxyHandler)
199
+ ampAPI.Any("/telemetry/*path", proxyHandler)
200
+ ampAPI.Any("/threads", proxyHandler)
201
+ ampAPI.Any("/threads/*path", proxyHandler)
202
+ ampAPI.Any("/otel", proxyHandler)
203
+ ampAPI.Any("/otel/*path", proxyHandler)
204
+ ampAPI.Any("/tab", proxyHandler)
205
+ ampAPI.Any("/tab/*path", proxyHandler)
206
+
207
+ // Root-level routes that AMP CLI expects without /api prefix
208
+ // These need the same security middleware as the /api/* routes (dynamic for hot-reload)
209
+ rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
210
+ if authWithBypass != nil {
211
+ rootMiddleware = append(rootMiddleware, authWithBypass)
212
+ }
213
+ // Add clientAPIKeyMiddleware after auth for per-client upstream routing
214
+ rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
215
+ engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
216
+ engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
217
+ engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
218
+ engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
219
+ engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
220
+ engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
221
+
222
+ engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
223
+ engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
224
+
225
+ // Root-level auth routes for CLI login flow
226
+ // Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
227
+ // We proxy all /auth/* to support the complete OAuth flow
228
+ engine.Any("/auth", append(rootMiddleware, proxyHandler)...)
229
+ engine.Any("/auth/*path", append(rootMiddleware, proxyHandler)...)
230
+
231
+ // Google v1beta1 passthrough with OAuth fallback
232
+ // AMP CLI uses non-standard paths like /publishers/google/models/...
233
+ // We bridge these to our standard Gemini handler to enable local OAuth.
234
+ // If no local OAuth is available, falls back to ampcode.com proxy.
235
+ geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
236
+ geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
237
+ geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
238
+ return m.getProxy()
239
+ }, m.modelMapper, m.forceModelMappings)
240
+ geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
241
+
242
+ // Route POST model calls through Gemini bridge with FallbackHandler.
243
+ // FallbackHandler checks provider -> mapping -> proxy fallback automatically.
244
+ // All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
245
+ ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
246
+ if c.Request.Method == "POST" {
247
+ if path := c.Param("path"); strings.Contains(path, "/models/") {
248
+ // POST with /models/ path -> use Gemini bridge with fallback handler
249
+ // FallbackHandler will check provider/mapping and proxy if needed
250
+ geminiV1Beta1Handler(c)
251
+ return
252
+ }
253
+ }
254
+ // Non-POST or no local provider available -> proxy upstream
255
+ proxyHandler(c)
256
+ })
257
+ }
258
+
259
+ // registerProviderAliases registers /api/provider/{provider}/... routes
260
+ // These allow Amp CLI to route requests like:
261
+ //
262
+ // /api/provider/openai/v1/chat/completions
263
+ // /api/provider/anthropic/v1/messages
264
+ // /api/provider/google/v1beta/models
265
+ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
266
+ // Create handler instances for different providers
267
+ openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler)
268
+ geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
269
+ claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
270
+ openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
271
+
272
+ // Create fallback handler wrapper that forwards to ampcode.com when provider not found
273
+ // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
274
+ // Also includes model mapping support for routing unavailable models to alternatives
275
+ fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
276
+ return m.getProxy()
277
+ }, m.modelMapper, m.forceModelMappings)
278
+
279
+ // Provider-specific routes under /api/provider/:provider
280
+ ampProviders := engine.Group("/api/provider")
281
+ if auth != nil {
282
+ ampProviders.Use(auth)
283
+ }
284
+ // Inject client API key into request context for per-client upstream routing
285
+ ampProviders.Use(clientAPIKeyMiddleware())
286
+
287
+ provider := ampProviders.Group("/:provider")
288
+
289
+ // Dynamic models handler - routes to appropriate provider based on path parameter
290
+ ampModelsHandler := func(c *gin.Context) {
291
+ providerName := strings.ToLower(c.Param("provider"))
292
+
293
+ switch providerName {
294
+ case "anthropic":
295
+ claudeCodeHandlers.ClaudeModels(c)
296
+ case "google":
297
+ geminiHandlers.GeminiModels(c)
298
+ default:
299
+ // Default to OpenAI-compatible (works for openai, groq, cerebras, etc.)
300
+ openaiHandlers.OpenAIModels(c)
301
+ }
302
+ }
303
+
304
+ // Root-level routes (for providers that omit /v1, like groq/cerebras)
305
+ // Wrap handlers with fallback logic to forward to ampcode.com when provider not found
306
+ provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
307
+ provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
308
+ provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
309
+ provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
310
+
311
+ // /v1 routes (OpenAI/Claude-compatible endpoints)
312
+ v1Amp := provider.Group("/v1")
313
+ {
314
+ v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
315
+
316
+ // OpenAI-compatible endpoints with fallback
317
+ v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
318
+ v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
319
+ v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
320
+
321
+ // Claude/Anthropic-compatible endpoints with fallback
322
+ v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
323
+ v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
324
+ }
325
+
326
+ // /v1beta routes (Gemini native API)
327
+ // Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
328
+ v1betaAmp := provider.Group("/v1beta")
329
+ {
330
+ v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
331
+ v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
332
+ v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
333
+ }
334
+ }
internal/api/modules/amp/routes_test.go ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "net/http"
5
+ "net/http/httptest"
6
+ "testing"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
10
+ )
11
+
12
+ func TestRegisterManagementRoutes(t *testing.T) {
13
+ gin.SetMode(gin.TestMode)
14
+ r := gin.New()
15
+
16
+ // Create module with proxy for testing
17
+ m := &AmpModule{
18
+ restrictToLocalhost: false, // disable localhost restriction for tests
19
+ }
20
+
21
+ // Create a mock proxy that tracks calls
22
+ proxyCalled := false
23
+ mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24
+ proxyCalled = true
25
+ w.WriteHeader(200)
26
+ w.Write([]byte("proxied"))
27
+ }))
28
+ defer mockProxy.Close()
29
+
30
+ // Create real proxy to mock server
31
+ proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
32
+ m.setProxy(proxy)
33
+
34
+ base := &handlers.BaseAPIHandler{}
35
+ m.registerManagementRoutes(r, base, nil)
36
+ srv := httptest.NewServer(r)
37
+ defer srv.Close()
38
+
39
+ managementPaths := []struct {
40
+ path string
41
+ method string
42
+ }{
43
+ {"/api/internal", http.MethodGet},
44
+ {"/api/internal/some/path", http.MethodGet},
45
+ {"/api/user", http.MethodGet},
46
+ {"/api/user/profile", http.MethodGet},
47
+ {"/api/auth", http.MethodGet},
48
+ {"/api/auth/login", http.MethodGet},
49
+ {"/api/meta", http.MethodGet},
50
+ {"/api/telemetry", http.MethodGet},
51
+ {"/api/threads", http.MethodGet},
52
+ {"/threads/", http.MethodGet},
53
+ {"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
54
+ {"/api/otel", http.MethodGet},
55
+ {"/api/tab", http.MethodGet},
56
+ {"/api/tab/some/path", http.MethodGet},
57
+ {"/auth", http.MethodGet}, // Root-level auth route
58
+ {"/auth/cli-login", http.MethodGet}, // CLI login flow
59
+ {"/auth/callback", http.MethodGet}, // OAuth callback
60
+ // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
61
+ {"/api/provider/google/v1beta1/models", http.MethodGet},
62
+ {"/api/provider/google/v1beta1/models", http.MethodPost},
63
+ }
64
+
65
+ for _, path := range managementPaths {
66
+ t.Run(path.path, func(t *testing.T) {
67
+ proxyCalled = false
68
+ req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
69
+ if err != nil {
70
+ t.Fatalf("failed to build request: %v", err)
71
+ }
72
+ resp, err := http.DefaultClient.Do(req)
73
+ if err != nil {
74
+ t.Fatalf("request failed: %v", err)
75
+ }
76
+ defer resp.Body.Close()
77
+
78
+ if resp.StatusCode == http.StatusNotFound {
79
+ t.Fatalf("route %s not registered", path.path)
80
+ }
81
+ if !proxyCalled {
82
+ t.Fatalf("proxy handler not called for %s", path.path)
83
+ }
84
+ })
85
+ }
86
+ }
87
+
88
+ func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
89
+ gin.SetMode(gin.TestMode)
90
+ r := gin.New()
91
+
92
+ // Minimal base handler setup (no need to initialize, just check routing)
93
+ base := &handlers.BaseAPIHandler{}
94
+
95
+ // Track if auth middleware was called
96
+ authCalled := false
97
+ authMiddleware := func(c *gin.Context) {
98
+ authCalled = true
99
+ c.Header("X-Auth", "ok")
100
+ // Abort with success to avoid calling the actual handler (which needs full setup)
101
+ c.AbortWithStatus(http.StatusOK)
102
+ }
103
+
104
+ m := &AmpModule{authMiddleware_: authMiddleware}
105
+ m.registerProviderAliases(r, base, authMiddleware)
106
+
107
+ paths := []struct {
108
+ path string
109
+ method string
110
+ }{
111
+ {"/api/provider/openai/models", http.MethodGet},
112
+ {"/api/provider/anthropic/models", http.MethodGet},
113
+ {"/api/provider/google/models", http.MethodGet},
114
+ {"/api/provider/groq/models", http.MethodGet},
115
+ {"/api/provider/openai/chat/completions", http.MethodPost},
116
+ {"/api/provider/anthropic/v1/messages", http.MethodPost},
117
+ {"/api/provider/google/v1beta/models", http.MethodGet},
118
+ }
119
+
120
+ for _, tc := range paths {
121
+ t.Run(tc.path, func(t *testing.T) {
122
+ authCalled = false
123
+ req := httptest.NewRequest(tc.method, tc.path, nil)
124
+ w := httptest.NewRecorder()
125
+ r.ServeHTTP(w, req)
126
+
127
+ if w.Code == http.StatusNotFound {
128
+ t.Fatalf("route %s %s not registered", tc.method, tc.path)
129
+ }
130
+ if !authCalled {
131
+ t.Fatalf("auth middleware not executed for %s", tc.path)
132
+ }
133
+ if w.Header().Get("X-Auth") != "ok" {
134
+ t.Fatalf("auth middleware header not set for %s", tc.path)
135
+ }
136
+ })
137
+ }
138
+ }
139
+
140
+ func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
141
+ gin.SetMode(gin.TestMode)
142
+ r := gin.New()
143
+
144
+ base := &handlers.BaseAPIHandler{}
145
+
146
+ m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
147
+ m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
148
+
149
+ providers := []string{"openai", "anthropic", "google", "groq", "cerebras"}
150
+
151
+ for _, provider := range providers {
152
+ t.Run(provider, func(t *testing.T) {
153
+ path := "/api/provider/" + provider + "/models"
154
+ req := httptest.NewRequest(http.MethodGet, path, nil)
155
+ w := httptest.NewRecorder()
156
+ r.ServeHTTP(w, req)
157
+
158
+ // Should not 404
159
+ if w.Code == http.StatusNotFound {
160
+ t.Fatalf("models route not found for provider: %s", provider)
161
+ }
162
+ })
163
+ }
164
+ }
165
+
166
+ func TestRegisterProviderAliases_V1Routes(t *testing.T) {
167
+ gin.SetMode(gin.TestMode)
168
+ r := gin.New()
169
+
170
+ base := &handlers.BaseAPIHandler{}
171
+
172
+ m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
173
+ m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
174
+
175
+ v1Paths := []struct {
176
+ path string
177
+ method string
178
+ }{
179
+ {"/api/provider/openai/v1/models", http.MethodGet},
180
+ {"/api/provider/openai/v1/chat/completions", http.MethodPost},
181
+ {"/api/provider/openai/v1/completions", http.MethodPost},
182
+ {"/api/provider/anthropic/v1/messages", http.MethodPost},
183
+ {"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
184
+ }
185
+
186
+ for _, tc := range v1Paths {
187
+ t.Run(tc.path, func(t *testing.T) {
188
+ req := httptest.NewRequest(tc.method, tc.path, nil)
189
+ w := httptest.NewRecorder()
190
+ r.ServeHTTP(w, req)
191
+
192
+ if w.Code == http.StatusNotFound {
193
+ t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
194
+ }
195
+ })
196
+ }
197
+ }
198
+
199
+ func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
200
+ gin.SetMode(gin.TestMode)
201
+ r := gin.New()
202
+
203
+ base := &handlers.BaseAPIHandler{}
204
+
205
+ m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
206
+ m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
207
+
208
+ v1betaPaths := []struct {
209
+ path string
210
+ method string
211
+ }{
212
+ {"/api/provider/google/v1beta/models", http.MethodGet},
213
+ {"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
214
+ }
215
+
216
+ for _, tc := range v1betaPaths {
217
+ t.Run(tc.path, func(t *testing.T) {
218
+ req := httptest.NewRequest(tc.method, tc.path, nil)
219
+ w := httptest.NewRecorder()
220
+ r.ServeHTTP(w, req)
221
+
222
+ if w.Code == http.StatusNotFound {
223
+ t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
224
+ }
225
+ })
226
+ }
227
+ }
228
+
229
+ func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
230
+ // Test that routes still register even if auth middleware is nil (fallback behavior)
231
+ gin.SetMode(gin.TestMode)
232
+ r := gin.New()
233
+
234
+ base := &handlers.BaseAPIHandler{}
235
+
236
+ m := &AmpModule{authMiddleware_: nil} // No auth middleware
237
+ m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
238
+
239
+ req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil)
240
+ w := httptest.NewRecorder()
241
+ r.ServeHTTP(w, req)
242
+
243
+ // Should still work (with fallback no-op auth)
244
+ if w.Code == http.StatusNotFound {
245
+ t.Fatal("routes should register even without auth middleware")
246
+ }
247
+ }
248
+
249
+ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
250
+ gin.SetMode(gin.TestMode)
251
+ r := gin.New()
252
+
253
+ // Create module with localhost restriction enabled
254
+ m := &AmpModule{
255
+ restrictToLocalhost: true,
256
+ }
257
+
258
+ // Apply dynamic localhost-only middleware
259
+ r.Use(m.localhostOnlyMiddleware())
260
+ r.GET("/test", func(c *gin.Context) {
261
+ c.String(http.StatusOK, "ok")
262
+ })
263
+
264
+ tests := []struct {
265
+ name string
266
+ remoteAddr string
267
+ forwardedFor string
268
+ expectedStatus int
269
+ description string
270
+ }{
271
+ {
272
+ name: "spoofed_header_remote_connection",
273
+ remoteAddr: "192.168.1.100:12345",
274
+ forwardedFor: "127.0.0.1",
275
+ expectedStatus: http.StatusForbidden,
276
+ description: "Spoofed X-Forwarded-For header should be ignored",
277
+ },
278
+ {
279
+ name: "real_localhost_ipv4",
280
+ remoteAddr: "127.0.0.1:54321",
281
+ forwardedFor: "",
282
+ expectedStatus: http.StatusOK,
283
+ description: "Real localhost IPv4 connection should work",
284
+ },
285
+ {
286
+ name: "real_localhost_ipv6",
287
+ remoteAddr: "[::1]:54321",
288
+ forwardedFor: "",
289
+ expectedStatus: http.StatusOK,
290
+ description: "Real localhost IPv6 connection should work",
291
+ },
292
+ {
293
+ name: "remote_ipv4",
294
+ remoteAddr: "203.0.113.42:8080",
295
+ forwardedFor: "",
296
+ expectedStatus: http.StatusForbidden,
297
+ description: "Remote IPv4 connection should be blocked",
298
+ },
299
+ {
300
+ name: "remote_ipv6",
301
+ remoteAddr: "[2001:db8::1]:9090",
302
+ forwardedFor: "",
303
+ expectedStatus: http.StatusForbidden,
304
+ description: "Remote IPv6 connection should be blocked",
305
+ },
306
+ {
307
+ name: "spoofed_localhost_ipv6",
308
+ remoteAddr: "203.0.113.42:8080",
309
+ forwardedFor: "::1",
310
+ expectedStatus: http.StatusForbidden,
311
+ description: "Spoofed X-Forwarded-For with IPv6 localhost should be ignored",
312
+ },
313
+ }
314
+
315
+ for _, tt := range tests {
316
+ t.Run(tt.name, func(t *testing.T) {
317
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
318
+ req.RemoteAddr = tt.remoteAddr
319
+ if tt.forwardedFor != "" {
320
+ req.Header.Set("X-Forwarded-For", tt.forwardedFor)
321
+ }
322
+
323
+ w := httptest.NewRecorder()
324
+ r.ServeHTTP(w, req)
325
+
326
+ if w.Code != tt.expectedStatus {
327
+ t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code)
328
+ }
329
+ })
330
+ }
331
+ }
332
+
333
+ func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
334
+ gin.SetMode(gin.TestMode)
335
+ r := gin.New()
336
+
337
+ // Create module with localhost restriction initially enabled
338
+ m := &AmpModule{
339
+ restrictToLocalhost: true,
340
+ }
341
+
342
+ // Apply dynamic localhost-only middleware
343
+ r.Use(m.localhostOnlyMiddleware())
344
+ r.GET("/test", func(c *gin.Context) {
345
+ c.String(http.StatusOK, "ok")
346
+ })
347
+
348
+ // Test 1: Remote IP should be blocked when restriction is enabled
349
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
350
+ req.RemoteAddr = "192.168.1.100:12345"
351
+ w := httptest.NewRecorder()
352
+ r.ServeHTTP(w, req)
353
+
354
+ if w.Code != http.StatusForbidden {
355
+ t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
356
+ }
357
+
358
+ // Test 2: Hot-reload - disable restriction
359
+ m.setRestrictToLocalhost(false)
360
+
361
+ req = httptest.NewRequest(http.MethodGet, "/test", nil)
362
+ req.RemoteAddr = "192.168.1.100:12345"
363
+ w = httptest.NewRecorder()
364
+ r.ServeHTTP(w, req)
365
+
366
+ if w.Code != http.StatusOK {
367
+ t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
368
+ }
369
+
370
+ // Test 3: Hot-reload - re-enable restriction
371
+ m.setRestrictToLocalhost(true)
372
+
373
+ req = httptest.NewRequest(http.MethodGet, "/test", nil)
374
+ req.RemoteAddr = "192.168.1.100:12345"
375
+ w = httptest.NewRecorder()
376
+ r.ServeHTTP(w, req)
377
+
378
+ if w.Code != http.StatusForbidden {
379
+ t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
380
+ }
381
+ }
internal/api/modules/amp/secret.go ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "fmt"
7
+ "os"
8
+ "path/filepath"
9
+ "strings"
10
+ "sync"
11
+ "time"
12
+
13
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
14
+ log "github.com/sirupsen/logrus"
15
+ )
16
+
17
+ // SecretSource provides Amp API keys with configurable precedence and caching
18
+ type SecretSource interface {
19
+ Get(ctx context.Context) (string, error)
20
+ }
21
+
22
+ // cachedSecret holds a secret value with expiration
23
+ type cachedSecret struct {
24
+ value string
25
+ expiresAt time.Time
26
+ }
27
+
28
+ // MultiSourceSecret implements precedence-based secret lookup:
29
+ // 1. Explicit config value (highest priority)
30
+ // 2. Environment variable AMP_API_KEY
31
+ // 3. File-based secret (lowest priority)
32
+ type MultiSourceSecret struct {
33
+ explicitKey string
34
+ envKey string
35
+ filePath string
36
+ cacheTTL time.Duration
37
+
38
+ mu sync.RWMutex
39
+ cache *cachedSecret
40
+ }
41
+
42
+ // NewMultiSourceSecret creates a secret source with precedence and caching
43
+ func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret {
44
+ if cacheTTL == 0 {
45
+ cacheTTL = 5 * time.Minute // Default 5 minute cache
46
+ }
47
+
48
+ home, _ := os.UserHomeDir()
49
+ filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json")
50
+
51
+ return &MultiSourceSecret{
52
+ explicitKey: strings.TrimSpace(explicitKey),
53
+ envKey: "AMP_API_KEY",
54
+ filePath: filePath,
55
+ cacheTTL: cacheTTL,
56
+ }
57
+ }
58
+
59
+ // NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing)
60
+ func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret {
61
+ if cacheTTL == 0 {
62
+ cacheTTL = 5 * time.Minute
63
+ }
64
+
65
+ return &MultiSourceSecret{
66
+ explicitKey: strings.TrimSpace(explicitKey),
67
+ envKey: "AMP_API_KEY",
68
+ filePath: filePath,
69
+ cacheTTL: cacheTTL,
70
+ }
71
+ }
72
+
73
+ // Get retrieves the Amp API key using precedence: config > env > file
74
+ // Results are cached for cacheTTL duration to avoid excessive file reads
75
+ func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) {
76
+ // Precedence 1: Explicit config key (highest priority, no caching needed)
77
+ if s.explicitKey != "" {
78
+ return s.explicitKey, nil
79
+ }
80
+
81
+ // Precedence 2: Environment variable
82
+ if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" {
83
+ return envValue, nil
84
+ }
85
+
86
+ // Precedence 3: File-based secret (lowest priority, cached)
87
+ // Check cache first
88
+ s.mu.RLock()
89
+ if s.cache != nil && time.Now().Before(s.cache.expiresAt) {
90
+ value := s.cache.value
91
+ s.mu.RUnlock()
92
+ return value, nil
93
+ }
94
+ s.mu.RUnlock()
95
+
96
+ // Cache miss or expired - read from file
97
+ key, err := s.readFromFile()
98
+ if err != nil {
99
+ // Cache empty result to avoid repeated file reads on missing files
100
+ s.updateCache("")
101
+ return "", err
102
+ }
103
+
104
+ // Cache the result
105
+ s.updateCache(key)
106
+ return key, nil
107
+ }
108
+
109
+ // readFromFile reads the Amp API key from the secrets file
110
+ func (s *MultiSourceSecret) readFromFile() (string, error) {
111
+ content, err := os.ReadFile(s.filePath)
112
+ if err != nil {
113
+ if os.IsNotExist(err) {
114
+ return "", nil // Missing file is not an error, just no key available
115
+ }
116
+ return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err)
117
+ }
118
+
119
+ var secrets map[string]string
120
+ if err := json.Unmarshal(content, &secrets); err != nil {
121
+ return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err)
122
+ }
123
+
124
+ key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"])
125
+ return key, nil
126
+ }
127
+
128
+ // updateCache updates the cached secret value
129
+ func (s *MultiSourceSecret) updateCache(value string) {
130
+ s.mu.Lock()
131
+ defer s.mu.Unlock()
132
+ s.cache = &cachedSecret{
133
+ value: value,
134
+ expiresAt: time.Now().Add(s.cacheTTL),
135
+ }
136
+ }
137
+
138
+ // InvalidateCache clears the cached secret, forcing a fresh read on next Get
139
+ func (s *MultiSourceSecret) InvalidateCache() {
140
+ s.mu.Lock()
141
+ defer s.mu.Unlock()
142
+ s.cache = nil
143
+ }
144
+
145
+ // UpdateExplicitKey refreshes the config-provided key and clears cache.
146
+ func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
147
+ if s == nil {
148
+ return
149
+ }
150
+ s.mu.Lock()
151
+ s.explicitKey = strings.TrimSpace(key)
152
+ s.cache = nil
153
+ s.mu.Unlock()
154
+ }
155
+
156
+ // StaticSecretSource returns a fixed API key (for testing)
157
+ type StaticSecretSource struct {
158
+ key string
159
+ }
160
+
161
+ // NewStaticSecretSource creates a secret source with a fixed key
162
+ func NewStaticSecretSource(key string) *StaticSecretSource {
163
+ return &StaticSecretSource{key: strings.TrimSpace(key)}
164
+ }
165
+
166
+ // Get returns the static API key
167
+ func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
168
+ return s.key, nil
169
+ }
170
+
171
+ // MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
172
+ // When a request context contains a client API key that matches a configured mapping,
173
+ // the corresponding upstream key is returned. Otherwise, falls back to the default source.
174
+ type MappedSecretSource struct {
175
+ defaultSource SecretSource
176
+ mu sync.RWMutex
177
+ lookup map[string]string // clientKey -> upstreamKey
178
+ }
179
+
180
+ // NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
181
+ func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
182
+ return &MappedSecretSource{
183
+ defaultSource: defaultSource,
184
+ lookup: make(map[string]string),
185
+ }
186
+ }
187
+
188
+ // Get retrieves the Amp API key, checking per-client mappings first.
189
+ // If the request context contains a client API key that matches a configured mapping,
190
+ // returns the corresponding upstream key. Otherwise, falls back to the default source.
191
+ func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
192
+ // Try to get client API key from request context
193
+ clientKey := getClientAPIKeyFromContext(ctx)
194
+ if clientKey != "" {
195
+ s.mu.RLock()
196
+ if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
197
+ s.mu.RUnlock()
198
+ return upstreamKey, nil
199
+ }
200
+ s.mu.RUnlock()
201
+ }
202
+
203
+ // Fall back to default source
204
+ return s.defaultSource.Get(ctx)
205
+ }
206
+
207
+ // UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
208
+ // If the same client key appears in multiple entries, logs a warning and uses the first one.
209
+ func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
210
+ newLookup := make(map[string]string)
211
+
212
+ for _, entry := range entries {
213
+ upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
214
+ if upstreamKey == "" {
215
+ continue
216
+ }
217
+ for _, clientKey := range entry.APIKeys {
218
+ trimmedKey := strings.TrimSpace(clientKey)
219
+ if trimmedKey == "" {
220
+ continue
221
+ }
222
+ if _, exists := newLookup[trimmedKey]; exists {
223
+ // Log warning for duplicate client key, first one wins
224
+ log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
225
+ continue
226
+ }
227
+ newLookup[trimmedKey] = upstreamKey
228
+ }
229
+ }
230
+
231
+ s.mu.Lock()
232
+ s.lookup = newLookup
233
+ s.mu.Unlock()
234
+ }
235
+
236
+ // UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
237
+ func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
238
+ if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
239
+ ms.UpdateExplicitKey(key)
240
+ }
241
+ }
242
+
243
+ // InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
244
+ func (s *MappedSecretSource) InvalidateCache() {
245
+ if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
246
+ ms.InvalidateCache()
247
+ }
248
+ }
internal/api/modules/amp/secret_test.go ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package amp
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "os"
7
+ "path/filepath"
8
+ "sync"
9
+ "testing"
10
+ "time"
11
+
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
13
+ log "github.com/sirupsen/logrus"
14
+ "github.com/sirupsen/logrus/hooks/test"
15
+ )
16
+
17
+ func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
18
+ ctx := context.Background()
19
+
20
+ cases := []struct {
21
+ name string
22
+ configKey string
23
+ envKey string
24
+ fileJSON string
25
+ want string
26
+ }{
27
+ {"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"},
28
+ {"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
29
+ {"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
30
+ {"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
31
+ {"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
32
+ {"missing_file_returns_empty", "", "", "", ""},
33
+ {"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""},
34
+ }
35
+
36
+ for _, tc := range cases {
37
+ tc := tc // capture range variable
38
+ t.Run(tc.name, func(t *testing.T) {
39
+ tmpDir := t.TempDir()
40
+ secretsPath := filepath.Join(tmpDir, "secrets.json")
41
+
42
+ if tc.fileJSON != "" {
43
+ if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil {
44
+ t.Fatal(err)
45
+ }
46
+ }
47
+
48
+ t.Setenv("AMP_API_KEY", tc.envKey)
49
+
50
+ s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond)
51
+ got, err := s.Get(ctx)
52
+ if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) {
53
+ t.Fatalf("unexpected error: %v", err)
54
+ }
55
+ if got != tc.want {
56
+ t.Fatalf("want %q, got %q", tc.want, got)
57
+ }
58
+ })
59
+ }
60
+ }
61
+
62
+ func TestMultiSourceSecret_CacheBehavior(t *testing.T) {
63
+ ctx := context.Background()
64
+ tmpDir := t.TempDir()
65
+ p := filepath.Join(tmpDir, "secrets.json")
66
+
67
+ // Initial value
68
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
69
+ t.Fatal(err)
70
+ }
71
+
72
+ s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond)
73
+
74
+ // First read - should return v1
75
+ got1, err := s.Get(ctx)
76
+ if err != nil {
77
+ t.Fatalf("Get failed: %v", err)
78
+ }
79
+ if got1 != "v1" {
80
+ t.Fatalf("expected v1, got %s", got1)
81
+ }
82
+
83
+ // Change file; within TTL we should still see v1 (cached)
84
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil {
85
+ t.Fatal(err)
86
+ }
87
+ got2, _ := s.Get(ctx)
88
+ if got2 != "v1" {
89
+ t.Fatalf("cache hit expected v1, got %s", got2)
90
+ }
91
+
92
+ // After TTL expires, should see v2
93
+ time.Sleep(60 * time.Millisecond)
94
+ got3, _ := s.Get(ctx)
95
+ if got3 != "v2" {
96
+ t.Fatalf("cache miss expected v2, got %s", got3)
97
+ }
98
+
99
+ // Invalidate forces re-read immediately
100
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil {
101
+ t.Fatal(err)
102
+ }
103
+ s.InvalidateCache()
104
+ got4, _ := s.Get(ctx)
105
+ if got4 != "v3" {
106
+ t.Fatalf("invalidate expected v3, got %s", got4)
107
+ }
108
+ }
109
+
110
+ func TestMultiSourceSecret_FileHandling(t *testing.T) {
111
+ ctx := context.Background()
112
+
113
+ t.Run("missing_file_no_error", func(t *testing.T) {
114
+ s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond)
115
+ got, err := s.Get(ctx)
116
+ if err != nil {
117
+ t.Fatalf("expected no error for missing file, got: %v", err)
118
+ }
119
+ if got != "" {
120
+ t.Fatalf("expected empty string, got %q", got)
121
+ }
122
+ })
123
+
124
+ t.Run("invalid_json", func(t *testing.T) {
125
+ tmpDir := t.TempDir()
126
+ p := filepath.Join(tmpDir, "secrets.json")
127
+ if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil {
128
+ t.Fatal(err)
129
+ }
130
+
131
+ s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
132
+ _, err := s.Get(ctx)
133
+ if err == nil {
134
+ t.Fatal("expected error for invalid JSON")
135
+ }
136
+ })
137
+
138
+ t.Run("missing_key_in_json", func(t *testing.T) {
139
+ tmpDir := t.TempDir()
140
+ p := filepath.Join(tmpDir, "secrets.json")
141
+ if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil {
142
+ t.Fatal(err)
143
+ }
144
+
145
+ s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
146
+ got, err := s.Get(ctx)
147
+ if err != nil {
148
+ t.Fatalf("unexpected error: %v", err)
149
+ }
150
+ if got != "" {
151
+ t.Fatalf("expected empty string for missing key, got %q", got)
152
+ }
153
+ })
154
+
155
+ t.Run("empty_key_value", func(t *testing.T) {
156
+ tmpDir := t.TempDir()
157
+ p := filepath.Join(tmpDir, "secrets.json")
158
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil {
159
+ t.Fatal(err)
160
+ }
161
+
162
+ s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
163
+ got, _ := s.Get(ctx)
164
+ if got != "" {
165
+ t.Fatalf("expected empty after trim, got %q", got)
166
+ }
167
+ })
168
+ }
169
+
170
+ func TestMultiSourceSecret_Concurrency(t *testing.T) {
171
+ tmpDir := t.TempDir()
172
+ p := filepath.Join(tmpDir, "secrets.json")
173
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil {
174
+ t.Fatal(err)
175
+ }
176
+
177
+ s := NewMultiSourceSecretWithPath("", p, 5*time.Second)
178
+ ctx := context.Background()
179
+
180
+ // Spawn many goroutines calling Get concurrently
181
+ const goroutines = 50
182
+ const iterations = 100
183
+
184
+ var wg sync.WaitGroup
185
+ errors := make(chan error, goroutines)
186
+
187
+ for i := 0; i < goroutines; i++ {
188
+ wg.Add(1)
189
+ go func() {
190
+ defer wg.Done()
191
+ for j := 0; j < iterations; j++ {
192
+ val, err := s.Get(ctx)
193
+ if err != nil {
194
+ errors <- err
195
+ return
196
+ }
197
+ if val != "concurrent" {
198
+ errors <- err
199
+ return
200
+ }
201
+ }
202
+ }()
203
+ }
204
+
205
+ wg.Wait()
206
+ close(errors)
207
+
208
+ for err := range errors {
209
+ t.Errorf("concurrency error: %v", err)
210
+ }
211
+ }
212
+
213
+ func TestStaticSecretSource(t *testing.T) {
214
+ ctx := context.Background()
215
+
216
+ t.Run("returns_provided_key", func(t *testing.T) {
217
+ s := NewStaticSecretSource("test-key-123")
218
+ got, err := s.Get(ctx)
219
+ if err != nil {
220
+ t.Fatalf("unexpected error: %v", err)
221
+ }
222
+ if got != "test-key-123" {
223
+ t.Fatalf("want test-key-123, got %q", got)
224
+ }
225
+ })
226
+
227
+ t.Run("trims_whitespace", func(t *testing.T) {
228
+ s := NewStaticSecretSource(" test-key ")
229
+ got, err := s.Get(ctx)
230
+ if err != nil {
231
+ t.Fatalf("unexpected error: %v", err)
232
+ }
233
+ if got != "test-key" {
234
+ t.Fatalf("want test-key, got %q", got)
235
+ }
236
+ })
237
+
238
+ t.Run("empty_string", func(t *testing.T) {
239
+ s := NewStaticSecretSource("")
240
+ got, err := s.Get(ctx)
241
+ if err != nil {
242
+ t.Fatalf("unexpected error: %v", err)
243
+ }
244
+ if got != "" {
245
+ t.Fatalf("want empty string, got %q", got)
246
+ }
247
+ })
248
+ }
249
+
250
+ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
251
+ // Test that missing file results are cached to avoid repeated file reads
252
+ tmpDir := t.TempDir()
253
+ p := filepath.Join(tmpDir, "nonexistent.json")
254
+
255
+ s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
256
+ ctx := context.Background()
257
+
258
+ // First call - file doesn't exist, should cache empty result
259
+ got1, err := s.Get(ctx)
260
+ if err != nil {
261
+ t.Fatalf("expected no error for missing file, got: %v", err)
262
+ }
263
+ if got1 != "" {
264
+ t.Fatalf("expected empty string, got %q", got1)
265
+ }
266
+
267
+ // Create the file now
268
+ if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil {
269
+ t.Fatal(err)
270
+ }
271
+
272
+ // Second call - should still return empty (cached), not read the new file
273
+ got2, _ := s.Get(ctx)
274
+ if got2 != "" {
275
+ t.Fatalf("cache should return empty, got %q", got2)
276
+ }
277
+
278
+ // After TTL expires, should see the new value
279
+ time.Sleep(110 * time.Millisecond)
280
+ got3, _ := s.Get(ctx)
281
+ if got3 != "new-value" {
282
+ t.Fatalf("after cache expiry, expected new-value, got %q", got3)
283
+ }
284
+ }
285
+
286
+ func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
287
+ defaultSource := NewStaticSecretSource("default")
288
+ s := NewMappedSecretSource(defaultSource)
289
+ s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
290
+ {
291
+ UpstreamAPIKey: "u1",
292
+ APIKeys: []string{"k1"},
293
+ },
294
+ })
295
+
296
+ ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
297
+ got, err := s.Get(ctx)
298
+ if err != nil {
299
+ t.Fatalf("unexpected error: %v", err)
300
+ }
301
+ if got != "u1" {
302
+ t.Fatalf("want u1, got %q", got)
303
+ }
304
+
305
+ ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
306
+ got, err = s.Get(ctx)
307
+ if err != nil {
308
+ t.Fatalf("unexpected error: %v", err)
309
+ }
310
+ if got != "default" {
311
+ t.Fatalf("want default fallback, got %q", got)
312
+ }
313
+ }
314
+
315
+ func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
316
+ defaultSource := NewStaticSecretSource("default")
317
+ s := NewMappedSecretSource(defaultSource)
318
+ s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
319
+ {
320
+ UpstreamAPIKey: "u1",
321
+ APIKeys: []string{"k1"},
322
+ },
323
+ {
324
+ UpstreamAPIKey: "u2",
325
+ APIKeys: []string{"k1"},
326
+ },
327
+ })
328
+
329
+ ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
330
+ got, err := s.Get(ctx)
331
+ if err != nil {
332
+ t.Fatalf("unexpected error: %v", err)
333
+ }
334
+ if got != "u1" {
335
+ t.Fatalf("want u1 (first wins), got %q", got)
336
+ }
337
+ }
338
+
339
+ func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
340
+ hook := test.NewLocal(log.StandardLogger())
341
+ defer hook.Reset()
342
+
343
+ defaultSource := NewStaticSecretSource("default")
344
+ s := NewMappedSecretSource(defaultSource)
345
+ s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
346
+ {
347
+ UpstreamAPIKey: "u1",
348
+ APIKeys: []string{"k1"},
349
+ },
350
+ {
351
+ UpstreamAPIKey: "u2",
352
+ APIKeys: []string{"k1"},
353
+ },
354
+ })
355
+
356
+ foundWarning := false
357
+ for _, entry := range hook.AllEntries() {
358
+ if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
359
+ foundWarning = true
360
+ break
361
+ }
362
+ }
363
+ if !foundWarning {
364
+ t.Fatal("expected warning log for duplicate client key, but none was found")
365
+ }
366
+ }
internal/api/modules/modules.go ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package modules provides a pluggable routing module system for extending
2
+ // the API server with optional features without modifying core routing logic.
3
+ package modules
4
+
5
+ import (
6
+ "fmt"
7
+
8
+ "github.com/gin-gonic/gin"
9
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
10
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
11
+ )
12
+
13
+ // Context encapsulates the dependencies exposed to routing modules during
14
+ // registration. Modules can use the Gin engine to attach routes, the shared
15
+ // BaseAPIHandler for constructing SDK-specific handlers, and the resolved
16
+ // authentication middleware for protecting routes that require API keys.
17
+ type Context struct {
18
+ Engine *gin.Engine
19
+ BaseHandler *handlers.BaseAPIHandler
20
+ Config *config.Config
21
+ AuthMiddleware gin.HandlerFunc
22
+ }
23
+
24
+ // RouteModule represents a pluggable routing module that can register routes
25
+ // and handle configuration updates independently of the core server.
26
+ //
27
+ // DEPRECATED: Use RouteModuleV2 for new modules. This interface is kept for
28
+ // backwards compatibility and will be removed in a future version.
29
+ type RouteModule interface {
30
+ // Name returns a human-readable identifier for the module
31
+ Name() string
32
+
33
+ // Register sets up routes and handlers for this module.
34
+ // It receives the Gin engine, base handlers, and current configuration.
35
+ // Returns an error if registration fails (errors are logged but don't stop the server).
36
+ Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error
37
+
38
+ // OnConfigUpdated is called when the configuration is reloaded.
39
+ // Modules can respond to configuration changes here.
40
+ // Returns an error if the update cannot be applied.
41
+ OnConfigUpdated(cfg *config.Config) error
42
+ }
43
+
44
+ // RouteModuleV2 represents a pluggable bundle of routes that can integrate with
45
+ // the API server without modifying its core routing logic. Implementations can
46
+ // attach routes during Register and react to configuration updates via
47
+ // OnConfigUpdated.
48
+ //
49
+ // This is the preferred interface for new modules. It uses Context for cleaner
50
+ // dependency injection and supports idempotent registration.
51
+ type RouteModuleV2 interface {
52
+ // Name returns a unique identifier for logging and diagnostics.
53
+ Name() string
54
+
55
+ // Register wires the module's routes into the provided Gin engine. Modules
56
+ // should treat multiple calls as idempotent and avoid duplicate route
57
+ // registration when invoked more than once.
58
+ Register(ctx Context) error
59
+
60
+ // OnConfigUpdated notifies the module when the server configuration changes
61
+ // via hot reload. Implementations can refresh cached state or emit warnings.
62
+ OnConfigUpdated(cfg *config.Config) error
63
+ }
64
+
65
+ // RegisterModule is a helper that registers a module using either the V1 or V2
66
+ // interface. This allows gradual migration from V1 to V2 without breaking
67
+ // existing modules.
68
+ //
69
+ // Example usage:
70
+ //
71
+ // ctx := modules.Context{
72
+ // Engine: engine,
73
+ // BaseHandler: baseHandler,
74
+ // Config: cfg,
75
+ // AuthMiddleware: authMiddleware,
76
+ // }
77
+ // if err := modules.RegisterModule(ctx, ampModule); err != nil {
78
+ // log.Errorf("Failed to register module: %v", err)
79
+ // }
80
+ func RegisterModule(ctx Context, mod interface{}) error {
81
+ // Try V2 interface first (preferred)
82
+ if v2, ok := mod.(RouteModuleV2); ok {
83
+ return v2.Register(ctx)
84
+ }
85
+
86
+ // Fall back to V1 interface for backwards compatibility
87
+ if v1, ok := mod.(RouteModule); ok {
88
+ return v1.Register(ctx.Engine, ctx.BaseHandler, ctx.Config)
89
+ }
90
+
91
+ return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod)
92
+ }
internal/api/server.go ADDED
@@ -0,0 +1,1056 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package api provides the HTTP API server implementation for the CLI Proxy API.
2
+ // It includes the main server struct, routing setup, middleware for CORS and authentication,
3
+ // and integration with various AI API handlers (OpenAI, Claude, Gemini).
4
+ // The server supports hot-reloading of clients and configuration.
5
+ package api
6
+
7
+ import (
8
+ "context"
9
+ "crypto/subtle"
10
+ "errors"
11
+ "fmt"
12
+ "net/http"
13
+ "os"
14
+ "path/filepath"
15
+ "strings"
16
+ "sync"
17
+ "sync/atomic"
18
+ "time"
19
+
20
+ "github.com/gin-gonic/gin"
21
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/access"
22
+ managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
23
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
24
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
25
+ ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
26
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
27
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
28
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
29
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
30
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
31
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
32
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
33
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
34
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
35
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
36
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
37
+ log "github.com/sirupsen/logrus"
38
+ "gopkg.in/yaml.v3"
39
+ )
40
+
41
+ const oauthCallbackSuccessHTML = `<html><head><meta charset="utf-8"><title>Authentication successful</title><script>setTimeout(function(){window.close();},5000);</script></head><body><h1>Authentication successful!</h1><p>You can close this window.</p><p>This window will close automatically in 5 seconds.</p></body></html>`
42
+
43
+ type serverOptionConfig struct {
44
+ extraMiddleware []gin.HandlerFunc
45
+ engineConfigurator func(*gin.Engine)
46
+ routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
47
+ requestLoggerFactory func(*config.Config, string) logging.RequestLogger
48
+ localPassword string
49
+ keepAliveEnabled bool
50
+ keepAliveTimeout time.Duration
51
+ keepAliveOnTimeout func()
52
+ }
53
+
54
+ // ServerOption customises HTTP server construction.
55
+ type ServerOption func(*serverOptionConfig)
56
+
57
+ func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
58
+ configDir := filepath.Dir(configPath)
59
+ if base := util.WritablePath(); base != "" {
60
+ return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
61
+ }
62
+ return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
63
+ }
64
+
65
+ // WithMiddleware appends additional Gin middleware during server construction.
66
+ func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
67
+ return func(cfg *serverOptionConfig) {
68
+ cfg.extraMiddleware = append(cfg.extraMiddleware, mw...)
69
+ }
70
+ }
71
+
72
+ // WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
73
+ func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
74
+ return func(cfg *serverOptionConfig) {
75
+ cfg.engineConfigurator = fn
76
+ }
77
+ }
78
+
79
+ // WithRouterConfigurator appends a callback after default routes are registered.
80
+ func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
81
+ return func(cfg *serverOptionConfig) {
82
+ cfg.routerConfigurator = fn
83
+ }
84
+ }
85
+
86
+ // WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
87
+ func WithLocalManagementPassword(password string) ServerOption {
88
+ return func(cfg *serverOptionConfig) {
89
+ cfg.localPassword = password
90
+ }
91
+ }
92
+
93
+ // WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
94
+ func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
95
+ return func(cfg *serverOptionConfig) {
96
+ if timeout <= 0 || onTimeout == nil {
97
+ return
98
+ }
99
+ cfg.keepAliveEnabled = true
100
+ cfg.keepAliveTimeout = timeout
101
+ cfg.keepAliveOnTimeout = onTimeout
102
+ }
103
+ }
104
+
105
+ // WithRequestLoggerFactory customises request logger creation.
106
+ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
107
+ return func(cfg *serverOptionConfig) {
108
+ cfg.requestLoggerFactory = factory
109
+ }
110
+ }
111
+
112
+ // Server represents the main API server.
113
+ // It encapsulates the Gin engine, HTTP server, handlers, and configuration.
114
+ type Server struct {
115
+ // engine is the Gin web framework engine instance.
116
+ engine *gin.Engine
117
+
118
+ // server is the underlying HTTP server.
119
+ server *http.Server
120
+
121
+ // handlers contains the API handlers for processing requests.
122
+ handlers *handlers.BaseAPIHandler
123
+
124
+ // cfg holds the current server configuration.
125
+ cfg *config.Config
126
+
127
+ // oldConfigYaml stores a YAML snapshot of the previous configuration for change detection.
128
+ // This prevents issues when the config object is modified in place by Management API.
129
+ oldConfigYaml []byte
130
+
131
+ // accessManager handles request authentication providers.
132
+ accessManager *sdkaccess.Manager
133
+
134
+ // requestLogger is the request logger instance for dynamic configuration updates.
135
+ requestLogger logging.RequestLogger
136
+ loggerToggle func(bool)
137
+
138
+ // configFilePath is the absolute path to the YAML config file for persistence.
139
+ configFilePath string
140
+
141
+ // currentPath is the absolute path to the current working directory.
142
+ currentPath string
143
+
144
+ // wsRoutes tracks registered websocket upgrade paths.
145
+ wsRouteMu sync.Mutex
146
+ wsRoutes map[string]struct{}
147
+ wsAuthChanged func(bool, bool)
148
+ wsAuthEnabled atomic.Bool
149
+
150
+ // management handler
151
+ mgmt *managementHandlers.Handler
152
+
153
+ // ampModule is the Amp routing module for model mapping hot-reload
154
+ ampModule *ampmodule.AmpModule
155
+
156
+ // managementRoutesRegistered tracks whether the management routes have been attached to the engine.
157
+ managementRoutesRegistered atomic.Bool
158
+ // managementRoutesEnabled controls whether management endpoints serve real handlers.
159
+ managementRoutesEnabled atomic.Bool
160
+
161
+ // envManagementSecret indicates whether MANAGEMENT_PASSWORD is configured.
162
+ envManagementSecret bool
163
+
164
+ localPassword string
165
+
166
+ keepAliveEnabled bool
167
+ keepAliveTimeout time.Duration
168
+ keepAliveOnTimeout func()
169
+ keepAliveHeartbeat chan struct{}
170
+ keepAliveStop chan struct{}
171
+ }
172
+
173
+ // NewServer creates and initializes a new API server instance.
174
+ // It sets up the Gin engine, middleware, routes, and handlers.
175
+ //
176
+ // Parameters:
177
+ // - cfg: The server configuration
178
+ // - authManager: core runtime auth manager
179
+ // - accessManager: request authentication manager
180
+ //
181
+ // Returns:
182
+ // - *Server: A new server instance
183
+ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server {
184
+ optionState := &serverOptionConfig{
185
+ requestLoggerFactory: defaultRequestLoggerFactory,
186
+ }
187
+ for i := range opts {
188
+ opts[i](optionState)
189
+ }
190
+ // Set gin mode
191
+ if !cfg.Debug {
192
+ gin.SetMode(gin.ReleaseMode)
193
+ }
194
+
195
+ // Create gin engine
196
+ engine := gin.New()
197
+ if optionState.engineConfigurator != nil {
198
+ optionState.engineConfigurator(engine)
199
+ }
200
+
201
+ // Add middleware
202
+ engine.Use(logging.GinLogrusLogger())
203
+ engine.Use(logging.GinLogrusRecovery())
204
+ for _, mw := range optionState.extraMiddleware {
205
+ engine.Use(mw)
206
+ }
207
+
208
+ // Add request logging middleware (positioned after recovery, before auth)
209
+ // Resolve logs directory relative to the configuration file directory.
210
+ var requestLogger logging.RequestLogger
211
+ var toggle func(bool)
212
+ if !cfg.CommercialMode {
213
+ if optionState.requestLoggerFactory != nil {
214
+ requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
215
+ }
216
+ if requestLogger != nil {
217
+ engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
218
+ if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
219
+ toggle = setter.SetEnabled
220
+ }
221
+ }
222
+ }
223
+
224
+ engine.Use(corsMiddleware())
225
+ wd, err := os.Getwd()
226
+ if err != nil {
227
+ wd = configFilePath
228
+ }
229
+
230
+ envAdminPassword, envAdminPasswordSet := os.LookupEnv("MANAGEMENT_PASSWORD")
231
+ envAdminPassword = strings.TrimSpace(envAdminPassword)
232
+ envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
233
+
234
+ // Create server instance
235
+ s := &Server{
236
+ engine: engine,
237
+ handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
238
+ cfg: cfg,
239
+ accessManager: accessManager,
240
+ requestLogger: requestLogger,
241
+ loggerToggle: toggle,
242
+ configFilePath: configFilePath,
243
+ currentPath: wd,
244
+ envManagementSecret: envManagementSecret,
245
+ wsRoutes: make(map[string]struct{}),
246
+ }
247
+ s.wsAuthEnabled.Store(cfg.WebsocketAuth)
248
+ // Save initial YAML snapshot
249
+ s.oldConfigYaml, _ = yaml.Marshal(cfg)
250
+ s.applyAccessConfig(nil, cfg)
251
+ if authManager != nil {
252
+ authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
253
+ }
254
+ managementasset.SetCurrentConfig(cfg)
255
+ auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
256
+ // Initialize management handler
257
+ s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
258
+ if optionState.localPassword != "" {
259
+ s.mgmt.SetLocalPassword(optionState.localPassword)
260
+ }
261
+ logDir := filepath.Join(s.currentPath, "logs")
262
+ if base := util.WritablePath(); base != "" {
263
+ logDir = filepath.Join(base, "logs")
264
+ }
265
+ s.mgmt.SetLogDirectory(logDir)
266
+ s.localPassword = optionState.localPassword
267
+
268
+ // Setup routes
269
+ s.setupRoutes()
270
+
271
+ // Register Amp module using V2 interface with Context
272
+ s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
273
+ ctx := modules.Context{
274
+ Engine: engine,
275
+ BaseHandler: s.handlers,
276
+ Config: cfg,
277
+ AuthMiddleware: AuthMiddleware(accessManager),
278
+ }
279
+ if err := modules.RegisterModule(ctx, s.ampModule); err != nil {
280
+ log.Errorf("Failed to register Amp module: %v", err)
281
+ }
282
+
283
+ // Apply additional router configurators from options
284
+ if optionState.routerConfigurator != nil {
285
+ optionState.routerConfigurator(engine, s.handlers, cfg)
286
+ }
287
+
288
+ // Register management routes when configuration or environment secrets are available.
289
+ hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
290
+ s.managementRoutesEnabled.Store(hasManagementSecret)
291
+ if hasManagementSecret {
292
+ s.registerManagementRoutes()
293
+ }
294
+
295
+ if optionState.keepAliveEnabled {
296
+ s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
297
+ }
298
+
299
+ // Create HTTP server
300
+ s.server = &http.Server{
301
+ Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
302
+ Handler: engine,
303
+ }
304
+
305
+ return s
306
+ }
307
+
308
+ // setupRoutes configures the API routes for the server.
309
+ // It defines the endpoints and associates them with their respective handlers.
310
+ func (s *Server) setupRoutes() {
311
+ s.engine.GET("/management.html", s.serveManagementControlPanel)
312
+ openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
313
+ geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
314
+ geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers)
315
+ claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers)
316
+ openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers)
317
+
318
+ // OpenAI compatible API routes
319
+ v1 := s.engine.Group("/v1")
320
+ v1.Use(AuthMiddleware(s.accessManager))
321
+ {
322
+ v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
323
+ v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
324
+ v1.POST("/completions", openaiHandlers.Completions)
325
+ v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
326
+ v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
327
+ v1.POST("/responses", openaiResponsesHandlers.Responses)
328
+ }
329
+
330
+ // Gemini compatible API routes
331
+ v1beta := s.engine.Group("/v1beta")
332
+ v1beta.Use(AuthMiddleware(s.accessManager))
333
+ {
334
+ v1beta.GET("/models", geminiHandlers.GeminiModels)
335
+ v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
336
+ v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
337
+ }
338
+
339
+ // Root endpoint
340
+ s.engine.GET("/", func(c *gin.Context) {
341
+ c.JSON(http.StatusOK, gin.H{
342
+ "message": "CLI Proxy API Server",
343
+ "endpoints": []string{
344
+ "POST /v1/chat/completions",
345
+ "POST /v1/completions",
346
+ "GET /v1/models",
347
+ },
348
+ })
349
+ })
350
+
351
+ // Event logging endpoint - handles Claude Code telemetry requests
352
+ // Returns 200 OK to prevent 404 errors in logs
353
+ s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
354
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
355
+ })
356
+ s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
357
+
358
+ // OAuth callback endpoints (reuse main server port)
359
+ // These endpoints receive provider redirects and persist
360
+ // the short-lived code/state for the waiting goroutine.
361
+ s.engine.GET("/anthropic/callback", func(c *gin.Context) {
362
+ code := c.Query("code")
363
+ state := c.Query("state")
364
+ errStr := c.Query("error")
365
+ if errStr == "" {
366
+ errStr = c.Query("error_description")
367
+ }
368
+ if state != "" {
369
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
370
+ }
371
+ c.Header("Content-Type", "text/html; charset=utf-8")
372
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
373
+ })
374
+
375
+ s.engine.GET("/codex/callback", func(c *gin.Context) {
376
+ code := c.Query("code")
377
+ state := c.Query("state")
378
+ errStr := c.Query("error")
379
+ if errStr == "" {
380
+ errStr = c.Query("error_description")
381
+ }
382
+ if state != "" {
383
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
384
+ }
385
+ c.Header("Content-Type", "text/html; charset=utf-8")
386
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
387
+ })
388
+
389
+ s.engine.GET("/google/callback", func(c *gin.Context) {
390
+ code := c.Query("code")
391
+ state := c.Query("state")
392
+ errStr := c.Query("error")
393
+ if errStr == "" {
394
+ errStr = c.Query("error_description")
395
+ }
396
+ if state != "" {
397
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
398
+ }
399
+ c.Header("Content-Type", "text/html; charset=utf-8")
400
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
401
+ })
402
+
403
+ s.engine.GET("/iflow/callback", func(c *gin.Context) {
404
+ code := c.Query("code")
405
+ state := c.Query("state")
406
+ errStr := c.Query("error")
407
+ if errStr == "" {
408
+ errStr = c.Query("error_description")
409
+ }
410
+ if state != "" {
411
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
412
+ }
413
+ c.Header("Content-Type", "text/html; charset=utf-8")
414
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
415
+ })
416
+
417
+ s.engine.GET("/antigravity/callback", func(c *gin.Context) {
418
+ code := c.Query("code")
419
+ state := c.Query("state")
420
+ errStr := c.Query("error")
421
+ if errStr == "" {
422
+ errStr = c.Query("error_description")
423
+ }
424
+ if state != "" {
425
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
426
+ }
427
+ c.Header("Content-Type", "text/html; charset=utf-8")
428
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
429
+ })
430
+
431
+ s.engine.GET("/kiro/callback", func(c *gin.Context) {
432
+ code := c.Query("code")
433
+ state := c.Query("state")
434
+ errStr := c.Query("error")
435
+ if errStr == "" {
436
+ errStr = c.Query("error_description")
437
+ }
438
+ if state != "" {
439
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
440
+ }
441
+ c.Header("Content-Type", "text/html; charset=utf-8")
442
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
443
+ })
444
+
445
+ // Management routes are registered lazily by registerManagementRoutes when a secret is configured.
446
+ }
447
+
448
+ // AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
449
+ // The handler is served as-is without additional middleware beyond the standard stack already configured.
450
+ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
451
+ if s == nil || s.engine == nil || handler == nil {
452
+ return
453
+ }
454
+ trimmed := strings.TrimSpace(path)
455
+ if trimmed == "" {
456
+ trimmed = "/v1/ws"
457
+ }
458
+ if !strings.HasPrefix(trimmed, "/") {
459
+ trimmed = "/" + trimmed
460
+ }
461
+ s.wsRouteMu.Lock()
462
+ if _, exists := s.wsRoutes[trimmed]; exists {
463
+ s.wsRouteMu.Unlock()
464
+ return
465
+ }
466
+ s.wsRoutes[trimmed] = struct{}{}
467
+ s.wsRouteMu.Unlock()
468
+
469
+ authMiddleware := AuthMiddleware(s.accessManager)
470
+ conditionalAuth := func(c *gin.Context) {
471
+ if !s.wsAuthEnabled.Load() {
472
+ c.Next()
473
+ return
474
+ }
475
+ authMiddleware(c)
476
+ }
477
+ finalHandler := func(c *gin.Context) {
478
+ handler.ServeHTTP(c.Writer, c.Request)
479
+ c.Abort()
480
+ }
481
+
482
+ s.engine.GET(trimmed, conditionalAuth, finalHandler)
483
+ }
484
+
485
+ func (s *Server) registerManagementRoutes() {
486
+ if s == nil || s.engine == nil || s.mgmt == nil {
487
+ return
488
+ }
489
+ if !s.managementRoutesRegistered.CompareAndSwap(false, true) {
490
+ return
491
+ }
492
+
493
+ log.Info("management routes registered after secret key configuration")
494
+
495
+ mgmt := s.engine.Group("/v0/management")
496
+ mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
497
+ {
498
+ mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
499
+ mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
500
+ mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
501
+ mgmt.GET("/config", s.mgmt.GetConfig)
502
+ mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
503
+ mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
504
+ mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
505
+
506
+ mgmt.GET("/debug", s.mgmt.GetDebug)
507
+ mgmt.PUT("/debug", s.mgmt.PutDebug)
508
+ mgmt.PATCH("/debug", s.mgmt.PutDebug)
509
+
510
+ mgmt.GET("/logging-to-file", s.mgmt.GetLoggingToFile)
511
+ mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
512
+ mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
513
+
514
+ mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
515
+ mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
516
+ mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
517
+
518
+ mgmt.GET("/proxy-url", s.mgmt.GetProxyURL)
519
+ mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL)
520
+ mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
521
+ mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
522
+
523
+ mgmt.POST("/api-call", s.mgmt.APICall)
524
+
525
+ mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
526
+ mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
527
+ mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
528
+
529
+ mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel)
530
+ mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
531
+ mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
532
+
533
+ mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
534
+ mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
535
+ mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
536
+ mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
537
+
538
+ mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
539
+ mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
540
+ mgmt.PATCH("/gemini-api-key", s.mgmt.PatchGeminiKey)
541
+ mgmt.DELETE("/gemini-api-key", s.mgmt.DeleteGeminiKey)
542
+
543
+ mgmt.GET("/logs", s.mgmt.GetLogs)
544
+ mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
545
+ mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
546
+ mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
547
+ mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
548
+ mgmt.GET("/request-log", s.mgmt.GetRequestLog)
549
+ mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
550
+ mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
551
+ mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
552
+ mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
553
+ mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
554
+
555
+ mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
556
+ mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
557
+ mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
558
+ mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
559
+ mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
560
+ mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
561
+ mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
562
+ mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
563
+ mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
564
+ mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
565
+ mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
566
+ mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
567
+ mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
568
+ mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
569
+ mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
570
+ mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
571
+ mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
572
+ mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
573
+ mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
574
+ mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
575
+ mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
576
+ mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
577
+ mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
578
+
579
+ mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
580
+ mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
581
+ mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
582
+ mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
583
+ mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
584
+ mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
585
+
586
+ mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
587
+ mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
588
+ mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
589
+ mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey)
590
+
591
+ mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys)
592
+ mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys)
593
+ mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey)
594
+ mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey)
595
+
596
+ mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat)
597
+ mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat)
598
+ mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
599
+ mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
600
+
601
+ mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
602
+ mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
603
+ mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
604
+ mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
605
+
606
+ mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
607
+ mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
608
+ mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
609
+ mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
610
+ mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
611
+ mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
612
+
613
+ mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
614
+ mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
615
+ mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
616
+ mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
617
+ mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
618
+ mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
619
+ mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
620
+ mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
621
+ mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
622
+ mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
623
+ }
624
+ }
625
+
626
+ func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
627
+ return func(c *gin.Context) {
628
+ if !s.managementRoutesEnabled.Load() {
629
+ c.AbortWithStatus(http.StatusNotFound)
630
+ return
631
+ }
632
+ c.Next()
633
+ }
634
+ }
635
+
636
+ func (s *Server) serveManagementControlPanel(c *gin.Context) {
637
+ cfg := s.cfg
638
+ if cfg == nil || cfg.RemoteManagement.DisableControlPanel {
639
+ c.AbortWithStatus(http.StatusNotFound)
640
+ return
641
+ }
642
+ filePath := managementasset.FilePath(s.configFilePath)
643
+ if strings.TrimSpace(filePath) == "" {
644
+ c.AbortWithStatus(http.StatusNotFound)
645
+ return
646
+ }
647
+
648
+ if _, err := os.Stat(filePath); err != nil {
649
+ if os.IsNotExist(err) {
650
+ go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
651
+ c.AbortWithStatus(http.StatusNotFound)
652
+ return
653
+ }
654
+
655
+ log.WithError(err).Error("failed to stat management control panel asset")
656
+ c.AbortWithStatus(http.StatusInternalServerError)
657
+ return
658
+ }
659
+
660
+ c.File(filePath)
661
+ }
662
+
663
+ func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
664
+ if timeout <= 0 || onTimeout == nil {
665
+ return
666
+ }
667
+
668
+ s.keepAliveEnabled = true
669
+ s.keepAliveTimeout = timeout
670
+ s.keepAliveOnTimeout = onTimeout
671
+ s.keepAliveHeartbeat = make(chan struct{}, 1)
672
+ s.keepAliveStop = make(chan struct{}, 1)
673
+
674
+ s.engine.GET("/keep-alive", s.handleKeepAlive)
675
+
676
+ go s.watchKeepAlive()
677
+ }
678
+
679
+ func (s *Server) handleKeepAlive(c *gin.Context) {
680
+ if s.localPassword != "" {
681
+ provided := strings.TrimSpace(c.GetHeader("Authorization"))
682
+ if provided != "" {
683
+ parts := strings.SplitN(provided, " ", 2)
684
+ if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
685
+ provided = parts[1]
686
+ }
687
+ }
688
+ if provided == "" {
689
+ provided = strings.TrimSpace(c.GetHeader("X-Local-Password"))
690
+ }
691
+ if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 {
692
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
693
+ return
694
+ }
695
+ }
696
+
697
+ s.signalKeepAlive()
698
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
699
+ }
700
+
701
+ func (s *Server) signalKeepAlive() {
702
+ if !s.keepAliveEnabled {
703
+ return
704
+ }
705
+ select {
706
+ case s.keepAliveHeartbeat <- struct{}{}:
707
+ default:
708
+ }
709
+ }
710
+
711
+ func (s *Server) watchKeepAlive() {
712
+ if !s.keepAliveEnabled {
713
+ return
714
+ }
715
+
716
+ timer := time.NewTimer(s.keepAliveTimeout)
717
+ defer timer.Stop()
718
+
719
+ for {
720
+ select {
721
+ case <-timer.C:
722
+ log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
723
+ if s.keepAliveOnTimeout != nil {
724
+ s.keepAliveOnTimeout()
725
+ }
726
+ return
727
+ case <-s.keepAliveHeartbeat:
728
+ if !timer.Stop() {
729
+ select {
730
+ case <-timer.C:
731
+ default:
732
+ }
733
+ }
734
+ timer.Reset(s.keepAliveTimeout)
735
+ case <-s.keepAliveStop:
736
+ return
737
+ }
738
+ }
739
+ }
740
+
741
+ // unifiedModelsHandler creates a unified handler for the /v1/models endpoint
742
+ // that routes to different handlers based on the User-Agent header.
743
+ // If User-Agent starts with "claude-cli", it routes to Claude handler,
744
+ // otherwise it routes to OpenAI handler.
745
+ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
746
+ return func(c *gin.Context) {
747
+ userAgent := c.GetHeader("User-Agent")
748
+
749
+ // Route to Claude handler if User-Agent starts with "claude-cli"
750
+ if strings.HasPrefix(userAgent, "claude-cli") {
751
+ // log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
752
+ claudeHandler.ClaudeModels(c)
753
+ } else {
754
+ // log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
755
+ openaiHandler.OpenAIModels(c)
756
+ }
757
+ }
758
+ }
759
+
760
+ // Start begins listening for and serving HTTP or HTTPS requests.
761
+ // It's a blocking call and will only return on an unrecoverable error.
762
+ //
763
+ // Returns:
764
+ // - error: An error if the server fails to start
765
+ func (s *Server) Start() error {
766
+ if s == nil || s.server == nil {
767
+ return fmt.Errorf("failed to start HTTP server: server not initialized")
768
+ }
769
+
770
+ useTLS := s.cfg != nil && s.cfg.TLS.Enable
771
+ if useTLS {
772
+ cert := strings.TrimSpace(s.cfg.TLS.Cert)
773
+ key := strings.TrimSpace(s.cfg.TLS.Key)
774
+ if cert == "" || key == "" {
775
+ return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
776
+ }
777
+ log.Debugf("Starting API server on %s with TLS", s.server.Addr)
778
+ if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
779
+ return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
780
+ }
781
+ return nil
782
+ }
783
+
784
+ log.Debugf("Starting API server on %s", s.server.Addr)
785
+ if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
786
+ return fmt.Errorf("failed to start HTTP server: %v", errServe)
787
+ }
788
+
789
+ return nil
790
+ }
791
+
792
+ // Stop gracefully shuts down the API server without interrupting any
793
+ // active connections.
794
+ //
795
+ // Parameters:
796
+ // - ctx: The context for graceful shutdown
797
+ //
798
+ // Returns:
799
+ // - error: An error if the server fails to stop
800
+ func (s *Server) Stop(ctx context.Context) error {
801
+ log.Debug("Stopping API server...")
802
+
803
+ if s.keepAliveEnabled {
804
+ select {
805
+ case s.keepAliveStop <- struct{}{}:
806
+ default:
807
+ }
808
+ }
809
+
810
+ // Shutdown the HTTP server.
811
+ if err := s.server.Shutdown(ctx); err != nil {
812
+ return fmt.Errorf("failed to shutdown HTTP server: %v", err)
813
+ }
814
+
815
+ log.Debug("API server stopped")
816
+ return nil
817
+ }
818
+
819
+ // corsMiddleware returns a Gin middleware handler that adds CORS headers
820
+ // to every response, allowing cross-origin requests.
821
+ //
822
+ // Returns:
823
+ // - gin.HandlerFunc: The CORS middleware handler
824
+ func corsMiddleware() gin.HandlerFunc {
825
+ return func(c *gin.Context) {
826
+ c.Header("Access-Control-Allow-Origin", "*")
827
+ c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
828
+ c.Header("Access-Control-Allow-Headers", "*")
829
+
830
+ if c.Request.Method == "OPTIONS" {
831
+ c.AbortWithStatus(http.StatusNoContent)
832
+ return
833
+ }
834
+
835
+ c.Next()
836
+ }
837
+ }
838
+
839
+ func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
840
+ if s == nil || s.accessManager == nil || newCfg == nil {
841
+ return
842
+ }
843
+ if _, err := access.ApplyAccessProviders(s.accessManager, oldCfg, newCfg); err != nil {
844
+ return
845
+ }
846
+ }
847
+
848
+ // UpdateClients updates the server's client list and configuration.
849
+ // This method is called when the configuration or authentication tokens change.
850
+ //
851
+ // Parameters:
852
+ // - clients: The new slice of AI service clients
853
+ // - cfg: The new application configuration
854
+ func (s *Server) UpdateClients(cfg *config.Config) {
855
+ // Reconstruct old config from YAML snapshot to avoid reference sharing issues
856
+ var oldCfg *config.Config
857
+ if len(s.oldConfigYaml) > 0 {
858
+ _ = yaml.Unmarshal(s.oldConfigYaml, &oldCfg)
859
+ }
860
+
861
+ // Update request logger enabled state if it has changed
862
+ previousRequestLog := false
863
+ if oldCfg != nil {
864
+ previousRequestLog = oldCfg.RequestLog
865
+ }
866
+ if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
867
+ if s.loggerToggle != nil {
868
+ s.loggerToggle(cfg.RequestLog)
869
+ } else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
870
+ toggler.SetEnabled(cfg.RequestLog)
871
+ }
872
+ if oldCfg != nil {
873
+ log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
874
+ } else {
875
+ log.Debugf("request logging toggled to %t", cfg.RequestLog)
876
+ }
877
+ }
878
+
879
+ if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
880
+ if err := logging.ConfigureLogOutput(cfg); err != nil {
881
+ log.Errorf("failed to reconfigure log output: %v", err)
882
+ } else {
883
+ if oldCfg == nil {
884
+ log.Debug("log output configuration refreshed")
885
+ } else {
886
+ if oldCfg.LoggingToFile != cfg.LoggingToFile {
887
+ log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
888
+ }
889
+ if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
890
+ log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
891
+ }
892
+ }
893
+ }
894
+ }
895
+
896
+ if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
897
+ usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
898
+ if oldCfg != nil {
899
+ log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
900
+ } else {
901
+ log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
902
+ }
903
+ }
904
+
905
+ if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
906
+ auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
907
+ if oldCfg != nil {
908
+ log.Debugf("disable_cooling updated from %t to %t", oldCfg.DisableCooling, cfg.DisableCooling)
909
+ } else {
910
+ log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
911
+ }
912
+ }
913
+ if s.handlers != nil && s.handlers.AuthManager != nil {
914
+ s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
915
+ }
916
+
917
+ // Update log level dynamically when debug flag changes
918
+ if oldCfg == nil || oldCfg.Debug != cfg.Debug {
919
+ util.SetLogLevel(cfg)
920
+ if oldCfg != nil {
921
+ log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
922
+ } else {
923
+ log.Debugf("debug mode toggled to %t", cfg.Debug)
924
+ }
925
+ }
926
+
927
+ prevSecretEmpty := true
928
+ if oldCfg != nil {
929
+ prevSecretEmpty = oldCfg.RemoteManagement.SecretKey == ""
930
+ }
931
+ newSecretEmpty := cfg.RemoteManagement.SecretKey == ""
932
+ if s.envManagementSecret {
933
+ s.registerManagementRoutes()
934
+ if s.managementRoutesEnabled.CompareAndSwap(false, true) {
935
+ log.Info("management routes enabled via MANAGEMENT_PASSWORD")
936
+ } else {
937
+ s.managementRoutesEnabled.Store(true)
938
+ }
939
+ } else {
940
+ switch {
941
+ case prevSecretEmpty && !newSecretEmpty:
942
+ s.registerManagementRoutes()
943
+ if s.managementRoutesEnabled.CompareAndSwap(false, true) {
944
+ log.Info("management routes enabled after secret key update")
945
+ } else {
946
+ s.managementRoutesEnabled.Store(true)
947
+ }
948
+ case !prevSecretEmpty && newSecretEmpty:
949
+ if s.managementRoutesEnabled.CompareAndSwap(true, false) {
950
+ log.Info("management routes disabled after secret key removal")
951
+ } else {
952
+ s.managementRoutesEnabled.Store(false)
953
+ }
954
+ default:
955
+ s.managementRoutesEnabled.Store(!newSecretEmpty)
956
+ }
957
+ }
958
+
959
+ s.applyAccessConfig(oldCfg, cfg)
960
+ s.cfg = cfg
961
+ s.wsAuthEnabled.Store(cfg.WebsocketAuth)
962
+ if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
963
+ s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
964
+ }
965
+ managementasset.SetCurrentConfig(cfg)
966
+ // Save YAML snapshot for next comparison
967
+ s.oldConfigYaml, _ = yaml.Marshal(cfg)
968
+
969
+ s.handlers.UpdateClients(&cfg.SDKConfig)
970
+
971
+ if !cfg.RemoteManagement.DisableControlPanel {
972
+ staticDir := managementasset.StaticDir(s.configFilePath)
973
+ go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
974
+ }
975
+ if s.mgmt != nil {
976
+ s.mgmt.SetConfig(cfg)
977
+ s.mgmt.SetAuthManager(s.handlers.AuthManager)
978
+ }
979
+
980
+ // Notify Amp module of config changes (for model mapping hot-reload)
981
+ if s.ampModule != nil {
982
+ log.Debugf("triggering amp module config update")
983
+ if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
984
+ log.Errorf("failed to update Amp module config: %v", err)
985
+ }
986
+ } else {
987
+ log.Warnf("amp module is nil, skipping config update")
988
+ }
989
+
990
+ // Count client sources from configuration and auth directory
991
+ authFiles := util.CountAuthFiles(cfg.AuthDir)
992
+ geminiAPIKeyCount := len(cfg.GeminiKey)
993
+ claudeAPIKeyCount := len(cfg.ClaudeKey)
994
+ codexAPIKeyCount := len(cfg.CodexKey)
995
+ vertexAICompatCount := len(cfg.VertexCompatAPIKey)
996
+ openAICompatCount := 0
997
+ for i := range cfg.OpenAICompatibility {
998
+ entry := cfg.OpenAICompatibility[i]
999
+ openAICompatCount += len(entry.APIKeyEntries)
1000
+ }
1001
+
1002
+ total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
1003
+ fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
1004
+ total,
1005
+ authFiles,
1006
+ geminiAPIKeyCount,
1007
+ claudeAPIKeyCount,
1008
+ codexAPIKeyCount,
1009
+ vertexAICompatCount,
1010
+ openAICompatCount,
1011
+ )
1012
+ }
1013
+
1014
+ func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
1015
+ if s == nil {
1016
+ return
1017
+ }
1018
+ s.wsAuthChanged = fn
1019
+ }
1020
+
1021
+ // (management handlers moved to internal/api/handlers/management)
1022
+
1023
+ // AuthMiddleware returns a Gin middleware handler that authenticates requests
1024
+ // using the configured authentication providers. When no providers are available,
1025
+ // it allows all requests (legacy behaviour).
1026
+ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
1027
+ return func(c *gin.Context) {
1028
+ if manager == nil {
1029
+ c.Next()
1030
+ return
1031
+ }
1032
+
1033
+ result, err := manager.Authenticate(c.Request.Context(), c.Request)
1034
+ if err == nil {
1035
+ if result != nil {
1036
+ c.Set("apiKey", result.Principal)
1037
+ c.Set("accessProvider", result.Provider)
1038
+ if len(result.Metadata) > 0 {
1039
+ c.Set("accessMetadata", result.Metadata)
1040
+ }
1041
+ }
1042
+ c.Next()
1043
+ return
1044
+ }
1045
+
1046
+ switch {
1047
+ case errors.Is(err, sdkaccess.ErrNoCredentials):
1048
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
1049
+ case errors.Is(err, sdkaccess.ErrInvalidCredential):
1050
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
1051
+ default:
1052
+ log.Errorf("authentication middleware error: %v", err)
1053
+ c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
1054
+ }
1055
+ }
1056
+ }
internal/api/server_test.go ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package api
2
+
3
+ import (
4
+ "net/http"
5
+ "net/http/httptest"
6
+ "os"
7
+ "path/filepath"
8
+ "strings"
9
+ "testing"
10
+
11
+ gin "github.com/gin-gonic/gin"
12
+ proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
13
+ sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
14
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
15
+ sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
16
+ )
17
+
18
+ func newTestServer(t *testing.T) *Server {
19
+ t.Helper()
20
+
21
+ gin.SetMode(gin.TestMode)
22
+
23
+ tmpDir := t.TempDir()
24
+ authDir := filepath.Join(tmpDir, "auth")
25
+ if err := os.MkdirAll(authDir, 0o700); err != nil {
26
+ t.Fatalf("failed to create auth dir: %v", err)
27
+ }
28
+
29
+ cfg := &proxyconfig.Config{
30
+ SDKConfig: sdkconfig.SDKConfig{
31
+ APIKeys: []string{"test-key"},
32
+ },
33
+ Port: 0,
34
+ AuthDir: authDir,
35
+ Debug: true,
36
+ LoggingToFile: false,
37
+ UsageStatisticsEnabled: false,
38
+ }
39
+
40
+ authManager := auth.NewManager(nil, nil, nil)
41
+ accessManager := sdkaccess.NewManager()
42
+
43
+ configPath := filepath.Join(tmpDir, "config.yaml")
44
+ return NewServer(cfg, authManager, accessManager, configPath)
45
+ }
46
+
47
+ func TestAmpProviderModelRoutes(t *testing.T) {
48
+ testCases := []struct {
49
+ name string
50
+ path string
51
+ wantStatus int
52
+ wantContains string
53
+ }{
54
+ {
55
+ name: "openai root models",
56
+ path: "/api/provider/openai/models",
57
+ wantStatus: http.StatusOK,
58
+ wantContains: `"object":"list"`,
59
+ },
60
+ {
61
+ name: "groq root models",
62
+ path: "/api/provider/groq/models",
63
+ wantStatus: http.StatusOK,
64
+ wantContains: `"object":"list"`,
65
+ },
66
+ {
67
+ name: "openai models",
68
+ path: "/api/provider/openai/v1/models",
69
+ wantStatus: http.StatusOK,
70
+ wantContains: `"object":"list"`,
71
+ },
72
+ {
73
+ name: "anthropic models",
74
+ path: "/api/provider/anthropic/v1/models",
75
+ wantStatus: http.StatusOK,
76
+ wantContains: `"data"`,
77
+ },
78
+ {
79
+ name: "google models v1",
80
+ path: "/api/provider/google/v1/models",
81
+ wantStatus: http.StatusOK,
82
+ wantContains: `"models"`,
83
+ },
84
+ {
85
+ name: "google models v1beta",
86
+ path: "/api/provider/google/v1beta/models",
87
+ wantStatus: http.StatusOK,
88
+ wantContains: `"models"`,
89
+ },
90
+ }
91
+
92
+ for _, tc := range testCases {
93
+ tc := tc
94
+ t.Run(tc.name, func(t *testing.T) {
95
+ server := newTestServer(t)
96
+
97
+ req := httptest.NewRequest(http.MethodGet, tc.path, nil)
98
+ req.Header.Set("Authorization", "Bearer test-key")
99
+
100
+ rr := httptest.NewRecorder()
101
+ server.engine.ServeHTTP(rr, req)
102
+
103
+ if rr.Code != tc.wantStatus {
104
+ t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String())
105
+ }
106
+ if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) {
107
+ t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body)
108
+ }
109
+ })
110
+ }
111
+ }
internal/auth/claude/anthropic.go ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package claude
2
+
3
+ // PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
4
+ type PKCECodes struct {
5
+ // CodeVerifier is the cryptographically random string used to correlate
6
+ // the authorization request to the token request
7
+ CodeVerifier string `json:"code_verifier"`
8
+ // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
9
+ CodeChallenge string `json:"code_challenge"`
10
+ }
11
+
12
+ // ClaudeTokenData holds OAuth token information from Anthropic
13
+ type ClaudeTokenData struct {
14
+ // AccessToken is the OAuth2 access token for API access
15
+ AccessToken string `json:"access_token"`
16
+ // RefreshToken is used to obtain new access tokens
17
+ RefreshToken string `json:"refresh_token"`
18
+ // Email is the Anthropic account email
19
+ Email string `json:"email"`
20
+ // Expire is the timestamp of the token expire
21
+ Expire string `json:"expired"`
22
+ }
23
+
24
+ // ClaudeAuthBundle aggregates authentication data after OAuth flow completion
25
+ type ClaudeAuthBundle struct {
26
+ // APIKey is the Anthropic API key obtained from token exchange
27
+ APIKey string `json:"api_key"`
28
+ // TokenData contains the OAuth tokens from the authentication flow
29
+ TokenData ClaudeTokenData `json:"token_data"`
30
+ // LastRefresh is the timestamp of the last token refresh
31
+ LastRefresh string `json:"last_refresh"`
32
+ }
internal/auth/claude/anthropic_auth.go ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides OAuth2 authentication functionality for Anthropic's Claude API.
2
+ // This package implements the complete OAuth2 flow with PKCE (Proof Key for Code Exchange)
3
+ // for secure authentication with Claude API, including token exchange, refresh, and storage.
4
+ package claude
5
+
6
+ import (
7
+ "context"
8
+ "encoding/json"
9
+ "fmt"
10
+ "io"
11
+ "net/http"
12
+ "net/url"
13
+ "strings"
14
+ "time"
15
+
16
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
17
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
18
+ log "github.com/sirupsen/logrus"
19
+ )
20
+
21
+ const (
22
+ anthropicAuthURL = "https://claude.ai/oauth/authorize"
23
+ anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
24
+ anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
25
+ redirectURI = "http://localhost:54545/callback"
26
+ )
27
+
28
+ // tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
29
+ // It contains access token, refresh token, and associated user/organization information.
30
+ type tokenResponse struct {
31
+ AccessToken string `json:"access_token"`
32
+ RefreshToken string `json:"refresh_token"`
33
+ TokenType string `json:"token_type"`
34
+ ExpiresIn int `json:"expires_in"`
35
+ Organization struct {
36
+ UUID string `json:"uuid"`
37
+ Name string `json:"name"`
38
+ } `json:"organization"`
39
+ Account struct {
40
+ UUID string `json:"uuid"`
41
+ EmailAddress string `json:"email_address"`
42
+ } `json:"account"`
43
+ }
44
+
45
+ // ClaudeAuth handles Anthropic OAuth2 authentication flow.
46
+ // It provides methods for generating authorization URLs, exchanging codes for tokens,
47
+ // and refreshing expired tokens using PKCE for enhanced security.
48
+ type ClaudeAuth struct {
49
+ httpClient *http.Client
50
+ }
51
+
52
+ // NewClaudeAuth creates a new Anthropic authentication service.
53
+ // It initializes the HTTP client with proxy settings from the configuration.
54
+ //
55
+ // Parameters:
56
+ // - cfg: The application configuration containing proxy settings
57
+ //
58
+ // Returns:
59
+ // - *ClaudeAuth: A new Claude authentication service instance
60
+ func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
61
+ return &ClaudeAuth{
62
+ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
63
+ }
64
+ }
65
+
66
+ // GenerateAuthURL creates the OAuth authorization URL with PKCE.
67
+ // This method generates a secure authorization URL including PKCE challenge codes
68
+ // for the OAuth2 flow with Anthropic's API.
69
+ //
70
+ // Parameters:
71
+ // - state: A random state parameter for CSRF protection
72
+ // - pkceCodes: The PKCE codes for secure code exchange
73
+ //
74
+ // Returns:
75
+ // - string: The complete authorization URL
76
+ // - string: The state parameter for verification
77
+ // - error: An error if PKCE codes are missing or URL generation fails
78
+ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, string, error) {
79
+ if pkceCodes == nil {
80
+ return "", "", fmt.Errorf("PKCE codes are required")
81
+ }
82
+
83
+ params := url.Values{
84
+ "code": {"true"},
85
+ "client_id": {anthropicClientID},
86
+ "response_type": {"code"},
87
+ "redirect_uri": {redirectURI},
88
+ "scope": {"org:create_api_key user:profile user:inference"},
89
+ "code_challenge": {pkceCodes.CodeChallenge},
90
+ "code_challenge_method": {"S256"},
91
+ "state": {state},
92
+ }
93
+
94
+ authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
95
+ return authURL, state, nil
96
+ }
97
+
98
+ // parseCodeAndState extracts the authorization code and state from the callback response.
99
+ // It handles the parsing of the code parameter which may contain additional fragments.
100
+ //
101
+ // Parameters:
102
+ // - code: The raw code parameter from the OAuth callback
103
+ //
104
+ // Returns:
105
+ // - parsedCode: The extracted authorization code
106
+ // - parsedState: The extracted state parameter if present
107
+ func (c *ClaudeAuth) parseCodeAndState(code string) (parsedCode, parsedState string) {
108
+ splits := strings.Split(code, "#")
109
+ parsedCode = splits[0]
110
+ if len(splits) > 1 {
111
+ parsedState = splits[1]
112
+ }
113
+ return
114
+ }
115
+
116
+ // ExchangeCodeForTokens exchanges authorization code for access tokens.
117
+ // This method implements the OAuth2 token exchange flow using PKCE for security.
118
+ // It sends the authorization code along with PKCE verifier to get access and refresh tokens.
119
+ //
120
+ // Parameters:
121
+ // - ctx: The context for the request
122
+ // - code: The authorization code received from OAuth callback
123
+ // - state: The state parameter for verification
124
+ // - pkceCodes: The PKCE codes for secure verification
125
+ //
126
+ // Returns:
127
+ // - *ClaudeAuthBundle: The complete authentication bundle with tokens
128
+ // - error: An error if token exchange fails
129
+ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state string, pkceCodes *PKCECodes) (*ClaudeAuthBundle, error) {
130
+ if pkceCodes == nil {
131
+ return nil, fmt.Errorf("PKCE codes are required for token exchange")
132
+ }
133
+ newCode, newState := o.parseCodeAndState(code)
134
+
135
+ // Prepare token exchange request
136
+ reqBody := map[string]interface{}{
137
+ "code": newCode,
138
+ "state": state,
139
+ "grant_type": "authorization_code",
140
+ "client_id": anthropicClientID,
141
+ "redirect_uri": redirectURI,
142
+ "code_verifier": pkceCodes.CodeVerifier,
143
+ }
144
+
145
+ // Include state if present
146
+ if newState != "" {
147
+ reqBody["state"] = newState
148
+ }
149
+
150
+ jsonBody, err := json.Marshal(reqBody)
151
+ if err != nil {
152
+ return nil, fmt.Errorf("failed to marshal request body: %w", err)
153
+ }
154
+
155
+ // log.Debugf("Token exchange request: %s", string(jsonBody))
156
+
157
+ req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
158
+ if err != nil {
159
+ return nil, fmt.Errorf("failed to create token request: %w", err)
160
+ }
161
+ req.Header.Set("Content-Type", "application/json")
162
+ req.Header.Set("Accept", "application/json")
163
+
164
+ resp, err := o.httpClient.Do(req)
165
+ if err != nil {
166
+ return nil, fmt.Errorf("token exchange request failed: %w", err)
167
+ }
168
+ defer func() {
169
+ if errClose := resp.Body.Close(); errClose != nil {
170
+ log.Errorf("failed to close response body: %v", errClose)
171
+ }
172
+ }()
173
+
174
+ body, err := io.ReadAll(resp.Body)
175
+ if err != nil {
176
+ return nil, fmt.Errorf("failed to read token response: %w", err)
177
+ }
178
+ // log.Debugf("Token response: %s", string(body))
179
+
180
+ if resp.StatusCode != http.StatusOK {
181
+ return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
182
+ }
183
+ // log.Debugf("Token response: %s", string(body))
184
+
185
+ var tokenResp tokenResponse
186
+ if err = json.Unmarshal(body, &tokenResp); err != nil {
187
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
188
+ }
189
+
190
+ // Create token data
191
+ tokenData := ClaudeTokenData{
192
+ AccessToken: tokenResp.AccessToken,
193
+ RefreshToken: tokenResp.RefreshToken,
194
+ Email: tokenResp.Account.EmailAddress,
195
+ Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
196
+ }
197
+
198
+ // Create auth bundle
199
+ bundle := &ClaudeAuthBundle{
200
+ TokenData: tokenData,
201
+ LastRefresh: time.Now().Format(time.RFC3339),
202
+ }
203
+
204
+ return bundle, nil
205
+ }
206
+
207
+ // RefreshTokens refreshes the access token using the refresh token.
208
+ // This method exchanges a valid refresh token for a new access token,
209
+ // extending the user's authenticated session.
210
+ //
211
+ // Parameters:
212
+ // - ctx: The context for the request
213
+ // - refreshToken: The refresh token to use for getting new access token
214
+ //
215
+ // Returns:
216
+ // - *ClaudeTokenData: The new token data with updated access token
217
+ // - error: An error if token refresh fails
218
+ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
219
+ if refreshToken == "" {
220
+ return nil, fmt.Errorf("refresh token is required")
221
+ }
222
+
223
+ reqBody := map[string]interface{}{
224
+ "client_id": anthropicClientID,
225
+ "grant_type": "refresh_token",
226
+ "refresh_token": refreshToken,
227
+ }
228
+
229
+ jsonBody, err := json.Marshal(reqBody)
230
+ if err != nil {
231
+ return nil, fmt.Errorf("failed to marshal request body: %w", err)
232
+ }
233
+
234
+ req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
235
+ if err != nil {
236
+ return nil, fmt.Errorf("failed to create refresh request: %w", err)
237
+ }
238
+
239
+ req.Header.Set("Content-Type", "application/json")
240
+ req.Header.Set("Accept", "application/json")
241
+
242
+ resp, err := o.httpClient.Do(req)
243
+ if err != nil {
244
+ return nil, fmt.Errorf("token refresh request failed: %w", err)
245
+ }
246
+ defer func() {
247
+ _ = resp.Body.Close()
248
+ }()
249
+
250
+ body, err := io.ReadAll(resp.Body)
251
+ if err != nil {
252
+ return nil, fmt.Errorf("failed to read refresh response: %w", err)
253
+ }
254
+
255
+ if resp.StatusCode != http.StatusOK {
256
+ return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
257
+ }
258
+
259
+ // log.Debugf("Token response: %s", string(body))
260
+
261
+ var tokenResp tokenResponse
262
+ if err = json.Unmarshal(body, &tokenResp); err != nil {
263
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
264
+ }
265
+
266
+ // Create token data
267
+ return &ClaudeTokenData{
268
+ AccessToken: tokenResp.AccessToken,
269
+ RefreshToken: tokenResp.RefreshToken,
270
+ Email: tokenResp.Account.EmailAddress,
271
+ Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
272
+ }, nil
273
+ }
274
+
275
+ // CreateTokenStorage creates a new ClaudeTokenStorage from auth bundle and user info.
276
+ // This method converts the authentication bundle into a token storage structure
277
+ // suitable for persistence and later use.
278
+ //
279
+ // Parameters:
280
+ // - bundle: The authentication bundle containing token data
281
+ //
282
+ // Returns:
283
+ // - *ClaudeTokenStorage: A new token storage instance
284
+ func (o *ClaudeAuth) CreateTokenStorage(bundle *ClaudeAuthBundle) *ClaudeTokenStorage {
285
+ storage := &ClaudeTokenStorage{
286
+ AccessToken: bundle.TokenData.AccessToken,
287
+ RefreshToken: bundle.TokenData.RefreshToken,
288
+ LastRefresh: bundle.LastRefresh,
289
+ Email: bundle.TokenData.Email,
290
+ Expire: bundle.TokenData.Expire,
291
+ }
292
+
293
+ return storage
294
+ }
295
+
296
+ // RefreshTokensWithRetry refreshes tokens with automatic retry logic.
297
+ // This method implements exponential backoff retry logic for token refresh operations,
298
+ // providing resilience against temporary network or service issues.
299
+ //
300
+ // Parameters:
301
+ // - ctx: The context for the request
302
+ // - refreshToken: The refresh token to use
303
+ // - maxRetries: The maximum number of retry attempts
304
+ //
305
+ // Returns:
306
+ // - *ClaudeTokenData: The refreshed token data
307
+ // - error: An error if all retry attempts fail
308
+ func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*ClaudeTokenData, error) {
309
+ var lastErr error
310
+
311
+ for attempt := 0; attempt < maxRetries; attempt++ {
312
+ if attempt > 0 {
313
+ // Wait before retry
314
+ select {
315
+ case <-ctx.Done():
316
+ return nil, ctx.Err()
317
+ case <-time.After(time.Duration(attempt) * time.Second):
318
+ }
319
+ }
320
+
321
+ tokenData, err := o.RefreshTokens(ctx, refreshToken)
322
+ if err == nil {
323
+ return tokenData, nil
324
+ }
325
+
326
+ lastErr = err
327
+ log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
328
+ }
329
+
330
+ return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
331
+ }
332
+
333
+ // UpdateTokenStorage updates an existing token storage with new token data.
334
+ // This method refreshes the token storage with newly obtained access and refresh tokens,
335
+ // updating timestamps and expiration information.
336
+ //
337
+ // Parameters:
338
+ // - storage: The existing token storage to update
339
+ // - tokenData: The new token data to apply
340
+ func (o *ClaudeAuth) UpdateTokenStorage(storage *ClaudeTokenStorage, tokenData *ClaudeTokenData) {
341
+ storage.AccessToken = tokenData.AccessToken
342
+ storage.RefreshToken = tokenData.RefreshToken
343
+ storage.LastRefresh = time.Now().Format(time.RFC3339)
344
+ storage.Email = tokenData.Email
345
+ storage.Expire = tokenData.Expire
346
+ }
internal/auth/claude/errors.go ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides authentication and token management functionality
2
+ // for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
3
+ // and retrieval for maintaining authenticated sessions with the Claude API.
4
+ package claude
5
+
6
+ import (
7
+ "errors"
8
+ "fmt"
9
+ "net/http"
10
+ )
11
+
12
+ // OAuthError represents an OAuth-specific error.
13
+ type OAuthError struct {
14
+ // Code is the OAuth error code.
15
+ Code string `json:"error"`
16
+ // Description is a human-readable description of the error.
17
+ Description string `json:"error_description,omitempty"`
18
+ // URI is a URI identifying a human-readable web page with information about the error.
19
+ URI string `json:"error_uri,omitempty"`
20
+ // StatusCode is the HTTP status code associated with the error.
21
+ StatusCode int `json:"-"`
22
+ }
23
+
24
+ // Error returns a string representation of the OAuth error.
25
+ func (e *OAuthError) Error() string {
26
+ if e.Description != "" {
27
+ return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
28
+ }
29
+ return fmt.Sprintf("OAuth error: %s", e.Code)
30
+ }
31
+
32
+ // NewOAuthError creates a new OAuth error with the specified code, description, and status code.
33
+ func NewOAuthError(code, description string, statusCode int) *OAuthError {
34
+ return &OAuthError{
35
+ Code: code,
36
+ Description: description,
37
+ StatusCode: statusCode,
38
+ }
39
+ }
40
+
41
+ // AuthenticationError represents authentication-related errors.
42
+ type AuthenticationError struct {
43
+ // Type is the type of authentication error.
44
+ Type string `json:"type"`
45
+ // Message is a human-readable message describing the error.
46
+ Message string `json:"message"`
47
+ // Code is the HTTP status code associated with the error.
48
+ Code int `json:"code"`
49
+ // Cause is the underlying error that caused this authentication error.
50
+ Cause error `json:"-"`
51
+ }
52
+
53
+ // Error returns a string representation of the authentication error.
54
+ func (e *AuthenticationError) Error() string {
55
+ if e.Cause != nil {
56
+ return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
57
+ }
58
+ return fmt.Sprintf("%s: %s", e.Type, e.Message)
59
+ }
60
+
61
+ // Common authentication error types.
62
+ var (
63
+ // ErrTokenExpired = &AuthenticationError{
64
+ // Type: "token_expired",
65
+ // Message: "Access token has expired",
66
+ // Code: http.StatusUnauthorized,
67
+ // }
68
+
69
+ // ErrInvalidState represents an error for invalid OAuth state parameter.
70
+ ErrInvalidState = &AuthenticationError{
71
+ Type: "invalid_state",
72
+ Message: "OAuth state parameter is invalid",
73
+ Code: http.StatusBadRequest,
74
+ }
75
+
76
+ // ErrCodeExchangeFailed represents an error when exchanging authorization code for tokens fails.
77
+ ErrCodeExchangeFailed = &AuthenticationError{
78
+ Type: "code_exchange_failed",
79
+ Message: "Failed to exchange authorization code for tokens",
80
+ Code: http.StatusBadRequest,
81
+ }
82
+
83
+ // ErrServerStartFailed represents an error when starting the OAuth callback server fails.
84
+ ErrServerStartFailed = &AuthenticationError{
85
+ Type: "server_start_failed",
86
+ Message: "Failed to start OAuth callback server",
87
+ Code: http.StatusInternalServerError,
88
+ }
89
+
90
+ // ErrPortInUse represents an error when the OAuth callback port is already in use.
91
+ ErrPortInUse = &AuthenticationError{
92
+ Type: "port_in_use",
93
+ Message: "OAuth callback port is already in use",
94
+ Code: 13, // Special exit code for port-in-use
95
+ }
96
+
97
+ // ErrCallbackTimeout represents an error when waiting for OAuth callback times out.
98
+ ErrCallbackTimeout = &AuthenticationError{
99
+ Type: "callback_timeout",
100
+ Message: "Timeout waiting for OAuth callback",
101
+ Code: http.StatusRequestTimeout,
102
+ }
103
+ )
104
+
105
+ // NewAuthenticationError creates a new authentication error with a cause based on a base error.
106
+ func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
107
+ return &AuthenticationError{
108
+ Type: baseErr.Type,
109
+ Message: baseErr.Message,
110
+ Code: baseErr.Code,
111
+ Cause: cause,
112
+ }
113
+ }
114
+
115
+ // IsAuthenticationError checks if an error is an authentication error.
116
+ func IsAuthenticationError(err error) bool {
117
+ var authenticationError *AuthenticationError
118
+ ok := errors.As(err, &authenticationError)
119
+ return ok
120
+ }
121
+
122
+ // IsOAuthError checks if an error is an OAuth error.
123
+ func IsOAuthError(err error) bool {
124
+ var oAuthError *OAuthError
125
+ ok := errors.As(err, &oAuthError)
126
+ return ok
127
+ }
128
+
129
+ // GetUserFriendlyMessage returns a user-friendly error message based on the error type.
130
+ func GetUserFriendlyMessage(err error) string {
131
+ switch {
132
+ case IsAuthenticationError(err):
133
+ var authErr *AuthenticationError
134
+ errors.As(err, &authErr)
135
+ switch authErr.Type {
136
+ case "token_expired":
137
+ return "Your authentication has expired. Please log in again."
138
+ case "token_invalid":
139
+ return "Your authentication is invalid. Please log in again."
140
+ case "authentication_required":
141
+ return "Please log in to continue."
142
+ case "port_in_use":
143
+ return "The required port is already in use. Please close any applications using port 3000 and try again."
144
+ case "callback_timeout":
145
+ return "Authentication timed out. Please try again."
146
+ case "browser_open_failed":
147
+ return "Could not open your browser automatically. Please copy and paste the URL manually."
148
+ default:
149
+ return "Authentication failed. Please try again."
150
+ }
151
+ case IsOAuthError(err):
152
+ var oauthErr *OAuthError
153
+ errors.As(err, &oauthErr)
154
+ switch oauthErr.Code {
155
+ case "access_denied":
156
+ return "Authentication was cancelled or denied."
157
+ case "invalid_request":
158
+ return "Invalid authentication request. Please try again."
159
+ case "server_error":
160
+ return "Authentication server error. Please try again later."
161
+ default:
162
+ return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
163
+ }
164
+ default:
165
+ return "An unexpected error occurred. Please try again."
166
+ }
167
+ }
internal/auth/claude/html_templates.go ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides authentication and token management functionality
2
+ // for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
3
+ // and retrieval for maintaining authenticated sessions with the Claude API.
4
+ package claude
5
+
6
+ // LoginSuccessHtml is the HTML template displayed to users after successful OAuth authentication.
7
+ // This template provides a user-friendly success page with options to close the window
8
+ // or navigate to the Claude platform. It includes automatic window closing functionality
9
+ // and keyboard accessibility features.
10
+ const LoginSuccessHtml = `<!DOCTYPE html>
11
+ <html lang="en">
12
+ <head>
13
+ <meta charset="UTF-8">
14
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
15
+ <title>Authentication Successful - Claude</title>
16
+ <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%2310b981'%3E%3Cpath d='M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E">
17
+ <style>
18
+ * {
19
+ box-sizing: border-box;
20
+ }
21
+ body {
22
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
23
+ display: flex;
24
+ justify-content: center;
25
+ align-items: center;
26
+ min-height: 100vh;
27
+ margin: 0;
28
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
29
+ padding: 1rem;
30
+ }
31
+ .container {
32
+ text-align: center;
33
+ background: white;
34
+ padding: 2.5rem;
35
+ border-radius: 12px;
36
+ box-shadow: 0 10px 25px rgba(0,0,0,0.1);
37
+ max-width: 480px;
38
+ width: 100%;
39
+ animation: slideIn 0.3s ease-out;
40
+ }
41
+ @keyframes slideIn {
42
+ from {
43
+ opacity: 0;
44
+ transform: translateY(-20px);
45
+ }
46
+ to {
47
+ opacity: 1;
48
+ transform: translateY(0);
49
+ }
50
+ }
51
+ .success-icon {
52
+ width: 64px;
53
+ height: 64px;
54
+ margin: 0 auto 1.5rem;
55
+ background: #10b981;
56
+ border-radius: 50%;
57
+ display: flex;
58
+ align-items: center;
59
+ justify-content: center;
60
+ color: white;
61
+ font-size: 2rem;
62
+ font-weight: bold;
63
+ }
64
+ h1 {
65
+ color: #1f2937;
66
+ margin-bottom: 1rem;
67
+ font-size: 1.75rem;
68
+ font-weight: 600;
69
+ }
70
+ .subtitle {
71
+ color: #6b7280;
72
+ margin-bottom: 1.5rem;
73
+ font-size: 1rem;
74
+ line-height: 1.5;
75
+ }
76
+ .setup-notice {
77
+ background: #fef3c7;
78
+ border: 1px solid #f59e0b;
79
+ border-radius: 6px;
80
+ padding: 1rem;
81
+ margin: 1rem 0;
82
+ }
83
+ .setup-notice h3 {
84
+ color: #92400e;
85
+ margin: 0 0 0.5rem 0;
86
+ font-size: 1rem;
87
+ }
88
+ .setup-notice p {
89
+ color: #92400e;
90
+ margin: 0;
91
+ font-size: 0.875rem;
92
+ }
93
+ .setup-notice a {
94
+ color: #1d4ed8;
95
+ text-decoration: none;
96
+ }
97
+ .setup-notice a:hover {
98
+ text-decoration: underline;
99
+ }
100
+ .actions {
101
+ display: flex;
102
+ gap: 1rem;
103
+ justify-content: center;
104
+ flex-wrap: wrap;
105
+ margin-top: 2rem;
106
+ }
107
+ .button {
108
+ padding: 0.75rem 1.5rem;
109
+ border-radius: 8px;
110
+ font-size: 0.875rem;
111
+ font-weight: 500;
112
+ text-decoration: none;
113
+ transition: all 0.2s;
114
+ cursor: pointer;
115
+ border: none;
116
+ display: inline-flex;
117
+ align-items: center;
118
+ gap: 0.5rem;
119
+ }
120
+ .button-primary {
121
+ background: #3b82f6;
122
+ color: white;
123
+ }
124
+ .button-primary:hover {
125
+ background: #2563eb;
126
+ transform: translateY(-1px);
127
+ }
128
+ .button-secondary {
129
+ background: #f3f4f6;
130
+ color: #374151;
131
+ border: 1px solid #d1d5db;
132
+ }
133
+ .button-secondary:hover {
134
+ background: #e5e7eb;
135
+ }
136
+ .countdown {
137
+ color: #9ca3af;
138
+ font-size: 0.75rem;
139
+ margin-top: 1rem;
140
+ }
141
+ .footer {
142
+ margin-top: 2rem;
143
+ padding-top: 1.5rem;
144
+ border-top: 1px solid #e5e7eb;
145
+ color: #9ca3af;
146
+ font-size: 0.75rem;
147
+ }
148
+ .footer a {
149
+ color: #3b82f6;
150
+ text-decoration: none;
151
+ }
152
+ .footer a:hover {
153
+ text-decoration: underline;
154
+ }
155
+ </style>
156
+ </head>
157
+ <body>
158
+ <div class="container">
159
+ <div class="success-icon">✓</div>
160
+ <h1>Authentication Successful!</h1>
161
+ <p class="subtitle">You have successfully authenticated with Claude. You can now close this window and return to your terminal to continue.</p>
162
+
163
+ {{SETUP_NOTICE}}
164
+
165
+ <div class="actions">
166
+ <button class="button button-primary" onclick="window.close()">
167
+ <span>Close Window</span>
168
+ </button>
169
+ <a href="{{PLATFORM_URL}}" target="_blank" class="button button-secondary">
170
+ <span>Open Platform</span>
171
+ <span>↗</span>
172
+ </a>
173
+ </div>
174
+
175
+ <div class="countdown">
176
+ This window will close automatically in <span id="countdown">10</span> seconds
177
+ </div>
178
+
179
+ <div class="footer">
180
+ <p>Powered by <a href="https://chatgpt.com" target="_blank">ChatGPT</a></p>
181
+ </div>
182
+ </div>
183
+
184
+ <script>
185
+ let countdown = 10;
186
+ const countdownElement = document.getElementById('countdown');
187
+
188
+ const timer = setInterval(() => {
189
+ countdown--;
190
+ countdownElement.textContent = countdown;
191
+
192
+ if (countdown <= 0) {
193
+ clearInterval(timer);
194
+ window.close();
195
+ }
196
+ }, 1000);
197
+
198
+ // Close window when user presses Escape
199
+ document.addEventListener('keydown', (e) => {
200
+ if (e.key === 'Escape') {
201
+ window.close();
202
+ }
203
+ });
204
+
205
+ // Focus the close button for keyboard accessibility
206
+ document.querySelector('.button-primary').focus();
207
+ </script>
208
+ </body>
209
+ </html>`
210
+
211
+ // SetupNoticeHtml is the HTML template for the setup notice section.
212
+ // This template is embedded within the success page to inform users about
213
+ // additional setup steps required to complete their Claude account configuration.
214
+ const SetupNoticeHtml = `
215
+ <div class="setup-notice">
216
+ <h3>Additional Setup Required</h3>
217
+ <p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}" target="_blank">Claude</a> to configure your account.</p>
218
+ </div>`
internal/auth/claude/oauth_server.go ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides authentication and token management functionality
2
+ // for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
3
+ // and retrieval for maintaining authenticated sessions with the Claude API.
4
+ package claude
5
+
6
+ import (
7
+ "context"
8
+ "errors"
9
+ "fmt"
10
+ "net"
11
+ "net/http"
12
+ "strings"
13
+ "sync"
14
+ "time"
15
+
16
+ log "github.com/sirupsen/logrus"
17
+ )
18
+
19
+ // OAuthServer handles the local HTTP server for OAuth callbacks.
20
+ // It listens for the authorization code response from the OAuth provider
21
+ // and captures the necessary parameters to complete the authentication flow.
22
+ type OAuthServer struct {
23
+ // server is the underlying HTTP server instance
24
+ server *http.Server
25
+ // port is the port number on which the server listens
26
+ port int
27
+ // resultChan is a channel for sending OAuth results
28
+ resultChan chan *OAuthResult
29
+ // errorChan is a channel for sending OAuth errors
30
+ errorChan chan error
31
+ // mu is a mutex for protecting server state
32
+ mu sync.Mutex
33
+ // running indicates whether the server is currently running
34
+ running bool
35
+ }
36
+
37
+ // OAuthResult contains the result of the OAuth callback.
38
+ // It holds either the authorization code and state for successful authentication
39
+ // or an error message if the authentication failed.
40
+ type OAuthResult struct {
41
+ // Code is the authorization code received from the OAuth provider
42
+ Code string
43
+ // State is the state parameter used to prevent CSRF attacks
44
+ State string
45
+ // Error contains any error message if the OAuth flow failed
46
+ Error string
47
+ }
48
+
49
+ // NewOAuthServer creates a new OAuth callback server.
50
+ // It initializes the server with the specified port and creates channels
51
+ // for handling OAuth results and errors.
52
+ //
53
+ // Parameters:
54
+ // - port: The port number on which the server should listen
55
+ //
56
+ // Returns:
57
+ // - *OAuthServer: A new OAuthServer instance
58
+ func NewOAuthServer(port int) *OAuthServer {
59
+ return &OAuthServer{
60
+ port: port,
61
+ resultChan: make(chan *OAuthResult, 1),
62
+ errorChan: make(chan error, 1),
63
+ }
64
+ }
65
+
66
+ // Start starts the OAuth callback server.
67
+ // It sets up the HTTP handlers for the callback and success endpoints,
68
+ // and begins listening on the specified port.
69
+ //
70
+ // Returns:
71
+ // - error: An error if the server fails to start
72
+ func (s *OAuthServer) Start() error {
73
+ s.mu.Lock()
74
+ defer s.mu.Unlock()
75
+
76
+ if s.running {
77
+ return fmt.Errorf("server is already running")
78
+ }
79
+
80
+ // Check if port is available
81
+ if !s.isPortAvailable() {
82
+ return fmt.Errorf("port %d is already in use", s.port)
83
+ }
84
+
85
+ mux := http.NewServeMux()
86
+ mux.HandleFunc("/callback", s.handleCallback)
87
+ mux.HandleFunc("/success", s.handleSuccess)
88
+
89
+ s.server = &http.Server{
90
+ Addr: fmt.Sprintf(":%d", s.port),
91
+ Handler: mux,
92
+ ReadTimeout: 10 * time.Second,
93
+ WriteTimeout: 10 * time.Second,
94
+ }
95
+
96
+ s.running = true
97
+
98
+ // Start server in goroutine
99
+ go func() {
100
+ if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
101
+ s.errorChan <- fmt.Errorf("server failed to start: %w", err)
102
+ }
103
+ }()
104
+
105
+ // Give server a moment to start
106
+ time.Sleep(100 * time.Millisecond)
107
+
108
+ return nil
109
+ }
110
+
111
+ // Stop gracefully stops the OAuth callback server.
112
+ // It performs a graceful shutdown of the HTTP server with a timeout.
113
+ //
114
+ // Parameters:
115
+ // - ctx: The context for controlling the shutdown process
116
+ //
117
+ // Returns:
118
+ // - error: An error if the server fails to stop gracefully
119
+ func (s *OAuthServer) Stop(ctx context.Context) error {
120
+ s.mu.Lock()
121
+ defer s.mu.Unlock()
122
+
123
+ if !s.running || s.server == nil {
124
+ return nil
125
+ }
126
+
127
+ log.Debug("Stopping OAuth callback server")
128
+
129
+ // Create a context with timeout for shutdown
130
+ shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
131
+ defer cancel()
132
+
133
+ err := s.server.Shutdown(shutdownCtx)
134
+ s.running = false
135
+ s.server = nil
136
+
137
+ return err
138
+ }
139
+
140
+ // WaitForCallback waits for the OAuth callback with a timeout.
141
+ // It blocks until either an OAuth result is received, an error occurs,
142
+ // or the specified timeout is reached.
143
+ //
144
+ // Parameters:
145
+ // - timeout: The maximum time to wait for the callback
146
+ //
147
+ // Returns:
148
+ // - *OAuthResult: The OAuth result if successful
149
+ // - error: An error if the callback times out or an error occurs
150
+ func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
151
+ select {
152
+ case result := <-s.resultChan:
153
+ return result, nil
154
+ case err := <-s.errorChan:
155
+ return nil, err
156
+ case <-time.After(timeout):
157
+ return nil, fmt.Errorf("timeout waiting for OAuth callback")
158
+ }
159
+ }
160
+
161
+ // handleCallback handles the OAuth callback endpoint.
162
+ // It extracts the authorization code and state from the callback URL,
163
+ // validates the parameters, and sends the result to the waiting channel.
164
+ //
165
+ // Parameters:
166
+ // - w: The HTTP response writer
167
+ // - r: The HTTP request
168
+ func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
169
+ log.Debug("Received OAuth callback")
170
+
171
+ // Validate request method
172
+ if r.Method != http.MethodGet {
173
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
174
+ return
175
+ }
176
+
177
+ // Extract parameters
178
+ query := r.URL.Query()
179
+ code := query.Get("code")
180
+ state := query.Get("state")
181
+ errorParam := query.Get("error")
182
+
183
+ // Validate required parameters
184
+ if errorParam != "" {
185
+ log.Errorf("OAuth error received: %s", errorParam)
186
+ result := &OAuthResult{
187
+ Error: errorParam,
188
+ }
189
+ s.sendResult(result)
190
+ http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
191
+ return
192
+ }
193
+
194
+ if code == "" {
195
+ log.Error("No authorization code received")
196
+ result := &OAuthResult{
197
+ Error: "no_code",
198
+ }
199
+ s.sendResult(result)
200
+ http.Error(w, "No authorization code received", http.StatusBadRequest)
201
+ return
202
+ }
203
+
204
+ if state == "" {
205
+ log.Error("No state parameter received")
206
+ result := &OAuthResult{
207
+ Error: "no_state",
208
+ }
209
+ s.sendResult(result)
210
+ http.Error(w, "No state parameter received", http.StatusBadRequest)
211
+ return
212
+ }
213
+
214
+ // Send successful result
215
+ result := &OAuthResult{
216
+ Code: code,
217
+ State: state,
218
+ }
219
+ s.sendResult(result)
220
+
221
+ // Redirect to success page
222
+ http.Redirect(w, r, "/success", http.StatusFound)
223
+ }
224
+
225
+ // handleSuccess handles the success page endpoint.
226
+ // It serves a user-friendly HTML page indicating that authentication was successful.
227
+ //
228
+ // Parameters:
229
+ // - w: The HTTP response writer
230
+ // - r: The HTTP request
231
+ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
232
+ log.Debug("Serving success page")
233
+
234
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
235
+ w.WriteHeader(http.StatusOK)
236
+
237
+ // Parse query parameters for customization
238
+ query := r.URL.Query()
239
+ setupRequired := query.Get("setup_required") == "true"
240
+ platformURL := query.Get("platform_url")
241
+ if platformURL == "" {
242
+ platformURL = "https://console.anthropic.com/"
243
+ }
244
+
245
+ // Validate platformURL to prevent XSS - only allow http/https URLs
246
+ if !isValidURL(platformURL) {
247
+ platformURL = "https://console.anthropic.com/"
248
+ }
249
+
250
+ // Generate success page HTML with dynamic content
251
+ successHTML := s.generateSuccessHTML(setupRequired, platformURL)
252
+
253
+ _, err := w.Write([]byte(successHTML))
254
+ if err != nil {
255
+ log.Errorf("Failed to write success page: %v", err)
256
+ }
257
+ }
258
+
259
+ // isValidURL checks if the URL is a valid http/https URL to prevent XSS
260
+ func isValidURL(urlStr string) bool {
261
+ urlStr = strings.TrimSpace(urlStr)
262
+ return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
263
+ }
264
+
265
+ // generateSuccessHTML creates the HTML content for the success page.
266
+ // It customizes the page based on whether additional setup is required
267
+ // and includes a link to the platform.
268
+ //
269
+ // Parameters:
270
+ // - setupRequired: Whether additional setup is required after authentication
271
+ // - platformURL: The URL to the platform for additional setup
272
+ //
273
+ // Returns:
274
+ // - string: The HTML content for the success page
275
+ func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
276
+ html := LoginSuccessHtml
277
+
278
+ // Replace platform URL placeholder
279
+ html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
280
+
281
+ // Add setup notice if required
282
+ if setupRequired {
283
+ setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
284
+ html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
285
+ } else {
286
+ html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
287
+ }
288
+
289
+ return html
290
+ }
291
+
292
+ // sendResult sends the OAuth result to the waiting channel.
293
+ // It ensures that the result is sent without blocking the handler.
294
+ //
295
+ // Parameters:
296
+ // - result: The OAuth result to send
297
+ func (s *OAuthServer) sendResult(result *OAuthResult) {
298
+ select {
299
+ case s.resultChan <- result:
300
+ log.Debug("OAuth result sent to channel")
301
+ default:
302
+ log.Warn("OAuth result channel is full, result dropped")
303
+ }
304
+ }
305
+
306
+ // isPortAvailable checks if the specified port is available.
307
+ // It attempts to listen on the port to determine availability.
308
+ //
309
+ // Returns:
310
+ // - bool: True if the port is available, false otherwise
311
+ func (s *OAuthServer) isPortAvailable() bool {
312
+ addr := fmt.Sprintf(":%d", s.port)
313
+ listener, err := net.Listen("tcp", addr)
314
+ if err != nil {
315
+ return false
316
+ }
317
+ defer func() {
318
+ _ = listener.Close()
319
+ }()
320
+ return true
321
+ }
322
+
323
+ // IsRunning returns whether the server is currently running.
324
+ //
325
+ // Returns:
326
+ // - bool: True if the server is running, false otherwise
327
+ func (s *OAuthServer) IsRunning() bool {
328
+ s.mu.Lock()
329
+ defer s.mu.Unlock()
330
+ return s.running
331
+ }
internal/auth/claude/pkce.go ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides authentication and token management functionality
2
+ // for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
3
+ // and retrieval for maintaining authenticated sessions with the Claude API.
4
+ package claude
5
+
6
+ import (
7
+ "crypto/rand"
8
+ "crypto/sha256"
9
+ "encoding/base64"
10
+ "fmt"
11
+ )
12
+
13
+ // GeneratePKCECodes generates a PKCE code verifier and challenge pair
14
+ // following RFC 7636 specifications for OAuth 2.0 PKCE extension.
15
+ // This provides additional security for the OAuth flow by ensuring that
16
+ // only the client that initiated the request can exchange the authorization code.
17
+ //
18
+ // Returns:
19
+ // - *PKCECodes: A struct containing the code verifier and challenge
20
+ // - error: An error if the generation fails, nil otherwise
21
+ func GeneratePKCECodes() (*PKCECodes, error) {
22
+ // Generate code verifier: 43-128 characters, URL-safe
23
+ codeVerifier, err := generateCodeVerifier()
24
+ if err != nil {
25
+ return nil, fmt.Errorf("failed to generate code verifier: %w", err)
26
+ }
27
+
28
+ // Generate code challenge using S256 method
29
+ codeChallenge := generateCodeChallenge(codeVerifier)
30
+
31
+ return &PKCECodes{
32
+ CodeVerifier: codeVerifier,
33
+ CodeChallenge: codeChallenge,
34
+ }, nil
35
+ }
36
+
37
+ // generateCodeVerifier creates a cryptographically random string
38
+ // of 128 characters using URL-safe base64 encoding
39
+ func generateCodeVerifier() (string, error) {
40
+ // Generate 96 random bytes (will result in 128 base64 characters)
41
+ bytes := make([]byte, 96)
42
+ _, err := rand.Read(bytes)
43
+ if err != nil {
44
+ return "", fmt.Errorf("failed to generate random bytes: %w", err)
45
+ }
46
+
47
+ // Encode to URL-safe base64 without padding
48
+ return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes), nil
49
+ }
50
+
51
+ // generateCodeChallenge creates a SHA256 hash of the code verifier
52
+ // and encodes it using URL-safe base64 encoding without padding
53
+ func generateCodeChallenge(codeVerifier string) string {
54
+ hash := sha256.Sum256([]byte(codeVerifier))
55
+ return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
56
+ }
internal/auth/claude/token.go ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Package claude provides authentication and token management functionality
2
+ // for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
3
+ // and retrieval for maintaining authenticated sessions with the Claude API.
4
+ package claude
5
+
6
+ import (
7
+ "encoding/json"
8
+ "fmt"
9
+ "os"
10
+ "path/filepath"
11
+
12
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
13
+ )
14
+
15
+ // ClaudeTokenStorage stores OAuth2 token information for Anthropic Claude API authentication.
16
+ // It maintains compatibility with the existing auth system while adding Claude-specific fields
17
+ // for managing access tokens, refresh tokens, and user account information.
18
+ type ClaudeTokenStorage struct {
19
+ // IDToken is the JWT ID token containing user claims and identity information.
20
+ IDToken string `json:"id_token"`
21
+
22
+ // AccessToken is the OAuth2 access token used for authenticating API requests.
23
+ AccessToken string `json:"access_token"`
24
+
25
+ // RefreshToken is used to obtain new access tokens when the current one expires.
26
+ RefreshToken string `json:"refresh_token"`
27
+
28
+ // LastRefresh is the timestamp of the last token refresh operation.
29
+ LastRefresh string `json:"last_refresh"`
30
+
31
+ // Email is the Anthropic account email address associated with this token.
32
+ Email string `json:"email"`
33
+
34
+ // Type indicates the authentication provider type, always "claude" for this storage.
35
+ Type string `json:"type"`
36
+
37
+ // Expire is the timestamp when the current access token expires.
38
+ Expire string `json:"expired"`
39
+ }
40
+
41
+ // SaveTokenToFile serializes the Claude token storage to a JSON file.
42
+ // This method creates the necessary directory structure and writes the token
43
+ // data in JSON format to the specified file path for persistent storage.
44
+ //
45
+ // Parameters:
46
+ // - authFilePath: The full path where the token file should be saved
47
+ //
48
+ // Returns:
49
+ // - error: An error if the operation fails, nil otherwise
50
+ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
51
+ misc.LogSavingCredentials(authFilePath)
52
+ ts.Type = "claude"
53
+
54
+ // Create directory structure if it doesn't exist
55
+ if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
56
+ return fmt.Errorf("failed to create directory: %v", err)
57
+ }
58
+
59
+ // Create the token file
60
+ f, err := os.Create(authFilePath)
61
+ if err != nil {
62
+ return fmt.Errorf("failed to create token file: %w", err)
63
+ }
64
+ defer func() {
65
+ _ = f.Close()
66
+ }()
67
+
68
+ // Encode and write the token data as JSON
69
+ if err = json.NewEncoder(f).Encode(ts); err != nil {
70
+ return fmt.Errorf("failed to write token to file: %w", err)
71
+ }
72
+ return nil
73
+ }