kek
commited on
Commit
·
f606b10
0
Parent(s):
Fresh start: Go 1.23 + go-git/v5 compatibility
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +56 -0
- Dockerfile +30 -0
- Dockerfile.hf +27 -0
- LICENSE +22 -0
- README.md +16 -0
- cmd/server/main.go +535 -0
- config.example.yaml +281 -0
- config.yaml +75 -0
- go.mod +78 -0
- go.sum +196 -0
- internal/access/config_access/provider.go +112 -0
- internal/access/reconcile.go +270 -0
- internal/api/handlers/management/api_tools.go +538 -0
- internal/api/handlers/management/auth_files.go +2606 -0
- internal/api/handlers/management/config_basic.go +243 -0
- internal/api/handlers/management/config_lists.go +1090 -0
- internal/api/handlers/management/handler.go +277 -0
- internal/api/handlers/management/logs.go +592 -0
- internal/api/handlers/management/oauth_callback.go +100 -0
- internal/api/handlers/management/oauth_sessions.go +290 -0
- internal/api/handlers/management/quota.go +18 -0
- internal/api/handlers/management/usage.go +79 -0
- internal/api/handlers/management/vertex_import.go +156 -0
- internal/api/middleware/request_logging.go +122 -0
- internal/api/middleware/response_writer.go +382 -0
- internal/api/modules/amp/amp.go +428 -0
- internal/api/modules/amp/amp_test.go +352 -0
- internal/api/modules/amp/fallback_handlers.go +329 -0
- internal/api/modules/amp/fallback_handlers_test.go +73 -0
- internal/api/modules/amp/gemini_bridge.go +59 -0
- internal/api/modules/amp/gemini_bridge_test.go +93 -0
- internal/api/modules/amp/model_mapping.go +147 -0
- internal/api/modules/amp/model_mapping_test.go +283 -0
- internal/api/modules/amp/proxy.go +266 -0
- internal/api/modules/amp/proxy_test.go +657 -0
- internal/api/modules/amp/response_rewriter.go +160 -0
- internal/api/modules/amp/routes.go +334 -0
- internal/api/modules/amp/routes_test.go +381 -0
- internal/api/modules/amp/secret.go +248 -0
- internal/api/modules/amp/secret_test.go +366 -0
- internal/api/modules/modules.go +92 -0
- internal/api/server.go +1056 -0
- internal/api/server_test.go +111 -0
- internal/auth/claude/anthropic.go +32 -0
- internal/auth/claude/anthropic_auth.go +346 -0
- internal/auth/claude/errors.go +167 -0
- internal/auth/claude/html_templates.go +218 -0
- internal/auth/claude/oauth_server.go +331 -0
- internal/auth/claude/pkce.go +56 -0
- internal/auth/claude/token.go +73 -0
.gitignore
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Binaries
|
| 2 |
+
cli-proxy-api
|
| 3 |
+
cliproxy
|
| 4 |
+
*.exe
|
| 5 |
+
|
| 6 |
+
# Configuration
|
| 7 |
+
config.yaml
|
| 8 |
+
.env
|
| 9 |
+
|
| 10 |
+
# Generated content
|
| 11 |
+
bin/*
|
| 12 |
+
logs/*
|
| 13 |
+
conv/*
|
| 14 |
+
temp/*
|
| 15 |
+
refs/*
|
| 16 |
+
|
| 17 |
+
# Storage backends
|
| 18 |
+
pgstore/*
|
| 19 |
+
gitstore/*
|
| 20 |
+
objectstore/*
|
| 21 |
+
|
| 22 |
+
# Static assets
|
| 23 |
+
static/*
|
| 24 |
+
|
| 25 |
+
# Authentication data
|
| 26 |
+
auths/*
|
| 27 |
+
!auths/.gitkeep
|
| 28 |
+
|
| 29 |
+
# Documentation
|
| 30 |
+
docs/*
|
| 31 |
+
AGENTS.md
|
| 32 |
+
CLAUDE.md
|
| 33 |
+
GEMINI.md
|
| 34 |
+
|
| 35 |
+
# Tooling metadata
|
| 36 |
+
.vscode/*
|
| 37 |
+
.codex/*
|
| 38 |
+
.claude/*
|
| 39 |
+
.gemini/*
|
| 40 |
+
.serena/*
|
| 41 |
+
.agent/*
|
| 42 |
+
.agents/*
|
| 43 |
+
.agents/*
|
| 44 |
+
.opencode/*
|
| 45 |
+
.bmad/*
|
| 46 |
+
_bmad/*
|
| 47 |
+
_bmad-output/*
|
| 48 |
+
.mcp/cache/
|
| 49 |
+
|
| 50 |
+
# macOS
|
| 51 |
+
.DS_Store
|
| 52 |
+
._*
|
| 53 |
+
cli-proxy-api-plus
|
| 54 |
+
CLIProxyAPIPlus_*.tar.gz
|
| 55 |
+
cli-proxy-api-plus
|
| 56 |
+
cli-proxy-api-plus
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. Schritt: Build des Go-Proxys
|
| 2 |
+
FROM golang:1.23-alpine AS builder
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
COPY . .
|
| 5 |
+
RUN go mod download
|
| 6 |
+
RUN CGO_ENABLED=0 GOOS=linux go build -o /app/cliproxy ./cmd/server/
|
| 7 |
+
|
| 8 |
+
# 2. Schritt: Schlankes Runtime-Image
|
| 9 |
+
FROM alpine:latest
|
| 10 |
+
RUN apk add --no-cache ca-certificates bash
|
| 11 |
+
|
| 12 |
+
# Arbeitsverzeichnis
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
|
| 15 |
+
# Kopiere den Proxy, die Config und den statischen Web-Ordner
|
| 16 |
+
COPY --from=builder /app/cliproxy /app/cliproxy
|
| 17 |
+
COPY config.yaml /app/config.yaml
|
| 18 |
+
COPY static /app/static
|
| 19 |
+
|
| 20 |
+
# Start-Skript
|
| 21 |
+
RUN echo "#!/bin/bash" > /start.sh && \
|
| 22 |
+
echo "Starting CLI Proxy API on Port 7860..." >> /start.sh && \
|
| 23 |
+
echo "exec /app/cliproxy -config /app/config.yaml" >> /start.sh && \
|
| 24 |
+
chmod +x /start.sh
|
| 25 |
+
|
| 26 |
+
# Port 7860 ist Pflicht für Hugging Face
|
| 27 |
+
EXPOSE 7860
|
| 28 |
+
|
| 29 |
+
# Proxy starten
|
| 30 |
+
ENTRYPOINT ["/start.sh"]
|
Dockerfile.hf
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. Schritt: Wir bauen den Go-Proxy (CLIProxyAPIPlus)
|
| 2 |
+
FROM golang:1.24-alpine AS builder
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
COPY . .
|
| 5 |
+
RUN go mod download
|
| 6 |
+
RUN CGO_ENABLED=0 GOOS=linux go build -o /app/cliproxy ./cmd/server/
|
| 7 |
+
|
| 8 |
+
# 2. Schritt: Wir nehmen Puter (den Desktop)
|
| 9 |
+
FROM heyputer/puter:latest
|
| 10 |
+
|
| 11 |
+
USER root
|
| 12 |
+
|
| 13 |
+
# Proxy und Config kopieren
|
| 14 |
+
COPY --from=builder /app/cliproxy /usr/local/bin/cliproxy
|
| 15 |
+
COPY config.yaml /etc/cliproxy/config.yaml
|
| 16 |
+
|
| 17 |
+
# Start-Skript sauber erstellen (Alles in einer RUN-Anweisung)
|
| 18 |
+
RUN echo "#!/bin/bash" > /start.sh && \
|
| 19 |
+
echo "echo 'Starting CLI Proxy...'" >> /start.sh && \
|
| 20 |
+
echo "/usr/local/bin/cliproxy -config /etc/cliproxy/config.yaml &" >> /start.sh && \
|
| 21 |
+
echo "echo 'Starting Puter on Port 7860...'" >> /start.sh && \
|
| 22 |
+
echo "exec python3 /opt/puter/puter/server.py --port 7860" >> /start.sh && \
|
| 23 |
+
chmod +x /start.sh
|
| 24 |
+
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
ENTRYPOINT ["/start.sh"]
|
LICENSE
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025-2005.9 Luis Pater
|
| 4 |
+
Copyright (c) 2025.9-present Router-For.ME
|
| 5 |
+
|
| 6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 7 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 8 |
+
in the Software without restriction, including without limitation the rights
|
| 9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 10 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 11 |
+
furnished to do so, subject to the following conditions:
|
| 12 |
+
|
| 13 |
+
The above copyright notice and this permission notice shall be included in all
|
| 14 |
+
copies or substantial portions of the Software.
|
| 15 |
+
|
| 16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 22 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Socializer Admin
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# CLIProxyAPI Plus (Socializer Admin)
|
| 11 |
+
|
| 12 |
+
English | [Chinese](README_CN.md)
|
| 13 |
+
|
| 14 |
+
This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project.
|
| 15 |
+
|
| 16 |
+
Running on Hugging Face Spaces with Puter OS.
|
cmd/server/main.go
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package main provides the entry point for the CLI Proxy API server.
|
| 2 |
+
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
|
| 3 |
+
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
|
| 4 |
+
package main
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"context"
|
| 8 |
+
"errors"
|
| 9 |
+
"flag"
|
| 10 |
+
"fmt"
|
| 11 |
+
"io/fs"
|
| 12 |
+
"net/url"
|
| 13 |
+
"os"
|
| 14 |
+
"path/filepath"
|
| 15 |
+
"strings"
|
| 16 |
+
"time"
|
| 17 |
+
|
| 18 |
+
"github.com/joho/godotenv"
|
| 19 |
+
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
| 20 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
| 21 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
| 22 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 23 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
| 24 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
| 25 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
| 26 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
| 27 |
+
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
| 28 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
| 29 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 30 |
+
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
| 31 |
+
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 32 |
+
log "github.com/sirupsen/logrus"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
var (
|
| 36 |
+
Version = "dev"
|
| 37 |
+
Commit = "none"
|
| 38 |
+
BuildDate = "unknown"
|
| 39 |
+
DefaultConfigPath = ""
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
// init initializes the shared logger setup.
|
| 43 |
+
func init() {
|
| 44 |
+
logging.SetupBaseLogger()
|
| 45 |
+
buildinfo.Version = Version
|
| 46 |
+
buildinfo.Commit = Commit
|
| 47 |
+
buildinfo.BuildDate = BuildDate
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
|
| 51 |
+
// Kiro defaults to incognito mode for multi-account support.
|
| 52 |
+
// Users can explicitly override with --incognito or --no-incognito flags.
|
| 53 |
+
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
|
| 54 |
+
if useIncognito {
|
| 55 |
+
cfg.IncognitoBrowser = true
|
| 56 |
+
} else if noIncognito {
|
| 57 |
+
cfg.IncognitoBrowser = false
|
| 58 |
+
} else {
|
| 59 |
+
cfg.IncognitoBrowser = true // Kiro default
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
// main is the entry point of the application.
|
| 64 |
+
// It parses command-line flags, loads configuration, and starts the appropriate
|
| 65 |
+
// service based on the provided flags (login, codex-login, or server mode).
|
| 66 |
+
func main() {
|
| 67 |
+
fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
| 68 |
+
|
| 69 |
+
// Command-line flags to control the application's behavior.
|
| 70 |
+
var login bool
|
| 71 |
+
var codexLogin bool
|
| 72 |
+
var claudeLogin bool
|
| 73 |
+
var qwenLogin bool
|
| 74 |
+
var iflowLogin bool
|
| 75 |
+
var iflowCookie bool
|
| 76 |
+
var noBrowser bool
|
| 77 |
+
var antigravityLogin bool
|
| 78 |
+
var kiroLogin bool
|
| 79 |
+
var kiroGoogleLogin bool
|
| 80 |
+
var kiroAWSLogin bool
|
| 81 |
+
var kiroAWSAuthCode bool
|
| 82 |
+
var kiroImport bool
|
| 83 |
+
var githubCopilotLogin bool
|
| 84 |
+
var projectID string
|
| 85 |
+
var vertexImport string
|
| 86 |
+
var configPath string
|
| 87 |
+
var password string
|
| 88 |
+
var noIncognito bool
|
| 89 |
+
var useIncognito bool
|
| 90 |
+
|
| 91 |
+
// Define command-line flags for different operation modes.
|
| 92 |
+
flag.BoolVar(&login, "login", false, "Login Google Account")
|
| 93 |
+
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
| 94 |
+
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
| 95 |
+
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
| 96 |
+
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
| 97 |
+
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
| 98 |
+
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
| 99 |
+
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
| 100 |
+
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
| 101 |
+
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
| 102 |
+
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
| 103 |
+
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
| 104 |
+
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
| 105 |
+
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
|
| 106 |
+
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
| 107 |
+
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
| 108 |
+
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
| 109 |
+
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
| 110 |
+
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
| 111 |
+
flag.StringVar(&password, "password", "", "")
|
| 112 |
+
|
| 113 |
+
flag.CommandLine.Usage = func() {
|
| 114 |
+
out := flag.CommandLine.Output()
|
| 115 |
+
_, _ = fmt.Fprintf(out, "Usage of %s\n", os.Args[0])
|
| 116 |
+
flag.CommandLine.VisitAll(func(f *flag.Flag) {
|
| 117 |
+
if f.Name == "password" {
|
| 118 |
+
return
|
| 119 |
+
}
|
| 120 |
+
s := fmt.Sprintf(" -%s", f.Name)
|
| 121 |
+
name, unquoteUsage := flag.UnquoteUsage(f)
|
| 122 |
+
if name != "" {
|
| 123 |
+
s += " " + name
|
| 124 |
+
}
|
| 125 |
+
if len(s) <= 4 {
|
| 126 |
+
s += " "
|
| 127 |
+
} else {
|
| 128 |
+
s += "\n "
|
| 129 |
+
}
|
| 130 |
+
if unquoteUsage != "" {
|
| 131 |
+
s += unquoteUsage
|
| 132 |
+
}
|
| 133 |
+
if f.DefValue != "" && f.DefValue != "false" && f.DefValue != "0" {
|
| 134 |
+
s += fmt.Sprintf(" (default %s)", f.DefValue)
|
| 135 |
+
}
|
| 136 |
+
_, _ = fmt.Fprint(out, s+"\n")
|
| 137 |
+
})
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Parse the command-line flags.
|
| 141 |
+
flag.Parse()
|
| 142 |
+
|
| 143 |
+
// Core application variables.
|
| 144 |
+
var err error
|
| 145 |
+
var cfg *config.Config
|
| 146 |
+
var isCloudDeploy bool
|
| 147 |
+
var (
|
| 148 |
+
usePostgresStore bool
|
| 149 |
+
pgStoreDSN string
|
| 150 |
+
pgStoreSchema string
|
| 151 |
+
pgStoreLocalPath string
|
| 152 |
+
pgStoreInst *store.PostgresStore
|
| 153 |
+
useGitStore bool
|
| 154 |
+
gitStoreRemoteURL string
|
| 155 |
+
gitStoreUser string
|
| 156 |
+
gitStorePassword string
|
| 157 |
+
gitStoreLocalPath string
|
| 158 |
+
gitStoreInst *store.GitTokenStore
|
| 159 |
+
gitStoreRoot string
|
| 160 |
+
useObjectStore bool
|
| 161 |
+
objectStoreEndpoint string
|
| 162 |
+
objectStoreAccess string
|
| 163 |
+
objectStoreSecret string
|
| 164 |
+
objectStoreBucket string
|
| 165 |
+
objectStoreLocalPath string
|
| 166 |
+
objectStoreInst *store.ObjectTokenStore
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
wd, err := os.Getwd()
|
| 170 |
+
if err != nil {
|
| 171 |
+
log.Errorf("failed to get working directory: %v", err)
|
| 172 |
+
return
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// Load environment variables from .env if present.
|
| 176 |
+
if errLoad := godotenv.Load(filepath.Join(wd, ".env")); errLoad != nil {
|
| 177 |
+
if !errors.Is(errLoad, os.ErrNotExist) {
|
| 178 |
+
log.WithError(errLoad).Warn("failed to load .env file")
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
lookupEnv := func(keys ...string) (string, bool) {
|
| 183 |
+
for _, key := range keys {
|
| 184 |
+
if value, ok := os.LookupEnv(key); ok {
|
| 185 |
+
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
| 186 |
+
return trimmed, true
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
return "", false
|
| 191 |
+
}
|
| 192 |
+
writableBase := util.WritablePath()
|
| 193 |
+
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
|
| 194 |
+
usePostgresStore = true
|
| 195 |
+
pgStoreDSN = value
|
| 196 |
+
}
|
| 197 |
+
if usePostgresStore {
|
| 198 |
+
if value, ok := lookupEnv("PGSTORE_SCHEMA", "pgstore_schema"); ok {
|
| 199 |
+
pgStoreSchema = value
|
| 200 |
+
}
|
| 201 |
+
if value, ok := lookupEnv("PGSTORE_LOCAL_PATH", "pgstore_local_path"); ok {
|
| 202 |
+
pgStoreLocalPath = value
|
| 203 |
+
}
|
| 204 |
+
if pgStoreLocalPath == "" {
|
| 205 |
+
if writableBase != "" {
|
| 206 |
+
pgStoreLocalPath = writableBase
|
| 207 |
+
} else {
|
| 208 |
+
pgStoreLocalPath = wd
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
useGitStore = false
|
| 212 |
+
}
|
| 213 |
+
if value, ok := lookupEnv("GITSTORE_GIT_URL", "gitstore_git_url"); ok {
|
| 214 |
+
useGitStore = true
|
| 215 |
+
gitStoreRemoteURL = value
|
| 216 |
+
}
|
| 217 |
+
if value, ok := lookupEnv("GITSTORE_GIT_USERNAME", "gitstore_git_username"); ok {
|
| 218 |
+
gitStoreUser = value
|
| 219 |
+
}
|
| 220 |
+
if value, ok := lookupEnv("GITSTORE_GIT_TOKEN", "gitstore_git_token"); ok {
|
| 221 |
+
gitStorePassword = value
|
| 222 |
+
}
|
| 223 |
+
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
|
| 224 |
+
gitStoreLocalPath = value
|
| 225 |
+
}
|
| 226 |
+
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
|
| 227 |
+
useObjectStore = true
|
| 228 |
+
objectStoreEndpoint = value
|
| 229 |
+
}
|
| 230 |
+
if value, ok := lookupEnv("OBJECTSTORE_ACCESS_KEY", "objectstore_access_key"); ok {
|
| 231 |
+
objectStoreAccess = value
|
| 232 |
+
}
|
| 233 |
+
if value, ok := lookupEnv("OBJECTSTORE_SECRET_KEY", "objectstore_secret_key"); ok {
|
| 234 |
+
objectStoreSecret = value
|
| 235 |
+
}
|
| 236 |
+
if value, ok := lookupEnv("OBJECTSTORE_BUCKET", "objectstore_bucket"); ok {
|
| 237 |
+
objectStoreBucket = value
|
| 238 |
+
}
|
| 239 |
+
if value, ok := lookupEnv("OBJECTSTORE_LOCAL_PATH", "objectstore_local_path"); ok {
|
| 240 |
+
objectStoreLocalPath = value
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// Check for cloud deploy mode only on first execution
|
| 244 |
+
// Read env var name in uppercase: DEPLOY
|
| 245 |
+
deployEnv := os.Getenv("DEPLOY")
|
| 246 |
+
if deployEnv == "cloud" {
|
| 247 |
+
isCloudDeploy = true
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
// Determine and load the configuration file.
|
| 251 |
+
// Prefer the Postgres store when configured, otherwise fallback to git or local files.
|
| 252 |
+
var configFilePath string
|
| 253 |
+
if usePostgresStore {
|
| 254 |
+
if pgStoreLocalPath == "" {
|
| 255 |
+
pgStoreLocalPath = wd
|
| 256 |
+
}
|
| 257 |
+
pgStoreLocalPath = filepath.Join(pgStoreLocalPath, "pgstore")
|
| 258 |
+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
| 259 |
+
pgStoreInst, err = store.NewPostgresStore(ctx, store.PostgresStoreConfig{
|
| 260 |
+
DSN: pgStoreDSN,
|
| 261 |
+
Schema: pgStoreSchema,
|
| 262 |
+
SpoolDir: pgStoreLocalPath,
|
| 263 |
+
})
|
| 264 |
+
cancel()
|
| 265 |
+
if err != nil {
|
| 266 |
+
log.Errorf("failed to initialize postgres token store: %v", err)
|
| 267 |
+
return
|
| 268 |
+
}
|
| 269 |
+
examplePath := filepath.Join(wd, "config.example.yaml")
|
| 270 |
+
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
|
| 271 |
+
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
| 272 |
+
cancel()
|
| 273 |
+
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
|
| 274 |
+
return
|
| 275 |
+
}
|
| 276 |
+
cancel()
|
| 277 |
+
configFilePath = pgStoreInst.ConfigPath()
|
| 278 |
+
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
| 279 |
+
if err == nil {
|
| 280 |
+
cfg.AuthDir = pgStoreInst.AuthDir()
|
| 281 |
+
log.Infof("postgres-backed token store enabled, workspace path: %s", pgStoreInst.WorkDir())
|
| 282 |
+
}
|
| 283 |
+
} else if useObjectStore {
|
| 284 |
+
if objectStoreLocalPath == "" {
|
| 285 |
+
if writableBase != "" {
|
| 286 |
+
objectStoreLocalPath = writableBase
|
| 287 |
+
} else {
|
| 288 |
+
objectStoreLocalPath = wd
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
objectStoreRoot := filepath.Join(objectStoreLocalPath, "objectstore")
|
| 292 |
+
resolvedEndpoint := strings.TrimSpace(objectStoreEndpoint)
|
| 293 |
+
useSSL := true
|
| 294 |
+
if strings.Contains(resolvedEndpoint, "://") {
|
| 295 |
+
parsed, errParse := url.Parse(resolvedEndpoint)
|
| 296 |
+
if errParse != nil {
|
| 297 |
+
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
|
| 298 |
+
return
|
| 299 |
+
}
|
| 300 |
+
switch strings.ToLower(parsed.Scheme) {
|
| 301 |
+
case "http":
|
| 302 |
+
useSSL = false
|
| 303 |
+
case "https":
|
| 304 |
+
useSSL = true
|
| 305 |
+
default:
|
| 306 |
+
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
|
| 307 |
+
return
|
| 308 |
+
}
|
| 309 |
+
if parsed.Host == "" {
|
| 310 |
+
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
|
| 311 |
+
return
|
| 312 |
+
}
|
| 313 |
+
resolvedEndpoint = parsed.Host
|
| 314 |
+
if parsed.Path != "" && parsed.Path != "/" {
|
| 315 |
+
resolvedEndpoint = strings.TrimSuffix(parsed.Host+parsed.Path, "/")
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
resolvedEndpoint = strings.TrimRight(resolvedEndpoint, "/")
|
| 319 |
+
objCfg := store.ObjectStoreConfig{
|
| 320 |
+
Endpoint: resolvedEndpoint,
|
| 321 |
+
Bucket: objectStoreBucket,
|
| 322 |
+
AccessKey: objectStoreAccess,
|
| 323 |
+
SecretKey: objectStoreSecret,
|
| 324 |
+
LocalRoot: objectStoreRoot,
|
| 325 |
+
UseSSL: useSSL,
|
| 326 |
+
PathStyle: true,
|
| 327 |
+
}
|
| 328 |
+
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
|
| 329 |
+
if err != nil {
|
| 330 |
+
log.Errorf("failed to initialize object token store: %v", err)
|
| 331 |
+
return
|
| 332 |
+
}
|
| 333 |
+
examplePath := filepath.Join(wd, "config.example.yaml")
|
| 334 |
+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
| 335 |
+
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
|
| 336 |
+
cancel()
|
| 337 |
+
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
|
| 338 |
+
return
|
| 339 |
+
}
|
| 340 |
+
cancel()
|
| 341 |
+
configFilePath = objectStoreInst.ConfigPath()
|
| 342 |
+
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
| 343 |
+
if err == nil {
|
| 344 |
+
if cfg == nil {
|
| 345 |
+
cfg = &config.Config{}
|
| 346 |
+
}
|
| 347 |
+
cfg.AuthDir = objectStoreInst.AuthDir()
|
| 348 |
+
log.Infof("object-backed token store enabled, bucket: %s", objectStoreBucket)
|
| 349 |
+
}
|
| 350 |
+
} else if useGitStore {
|
| 351 |
+
if gitStoreLocalPath == "" {
|
| 352 |
+
if writableBase != "" {
|
| 353 |
+
gitStoreLocalPath = writableBase
|
| 354 |
+
} else {
|
| 355 |
+
gitStoreLocalPath = wd
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
|
| 359 |
+
authDir := filepath.Join(gitStoreRoot, "auths")
|
| 360 |
+
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
|
| 361 |
+
gitStoreInst.SetBaseDir(authDir)
|
| 362 |
+
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
|
| 363 |
+
log.Errorf("failed to prepare git token store: %v", errRepo)
|
| 364 |
+
return
|
| 365 |
+
}
|
| 366 |
+
configFilePath = gitStoreInst.ConfigPath()
|
| 367 |
+
if configFilePath == "" {
|
| 368 |
+
configFilePath = filepath.Join(gitStoreRoot, "config", "config.yaml")
|
| 369 |
+
}
|
| 370 |
+
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
|
| 371 |
+
examplePath := filepath.Join(wd, "config.example.yaml")
|
| 372 |
+
if _, errExample := os.Stat(examplePath); errExample != nil {
|
| 373 |
+
log.Errorf("failed to find template config file: %v", errExample)
|
| 374 |
+
return
|
| 375 |
+
}
|
| 376 |
+
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
|
| 377 |
+
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
|
| 378 |
+
return
|
| 379 |
+
}
|
| 380 |
+
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
|
| 381 |
+
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
|
| 382 |
+
return
|
| 383 |
+
}
|
| 384 |
+
log.Infof("git-backed config initialized from template: %s", configFilePath)
|
| 385 |
+
} else if statErr != nil {
|
| 386 |
+
log.Errorf("failed to inspect git-backed config: %v", statErr)
|
| 387 |
+
return
|
| 388 |
+
}
|
| 389 |
+
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
| 390 |
+
if err == nil {
|
| 391 |
+
cfg.AuthDir = gitStoreInst.AuthDir()
|
| 392 |
+
log.Infof("git-backed token store enabled, repository path: %s", gitStoreRoot)
|
| 393 |
+
}
|
| 394 |
+
} else if configPath != "" {
|
| 395 |
+
configFilePath = configPath
|
| 396 |
+
cfg, err = config.LoadConfigOptional(configPath, isCloudDeploy)
|
| 397 |
+
} else {
|
| 398 |
+
wd, err = os.Getwd()
|
| 399 |
+
if err != nil {
|
| 400 |
+
log.Errorf("failed to get working directory: %v", err)
|
| 401 |
+
return
|
| 402 |
+
}
|
| 403 |
+
configFilePath = filepath.Join(wd, "config.yaml")
|
| 404 |
+
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
|
| 405 |
+
}
|
| 406 |
+
if err != nil {
|
| 407 |
+
log.Errorf("failed to load config: %v", err)
|
| 408 |
+
return
|
| 409 |
+
}
|
| 410 |
+
if cfg == nil {
|
| 411 |
+
cfg = &config.Config{}
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
// In cloud deploy mode, check if we have a valid configuration
|
| 415 |
+
var configFileExists bool
|
| 416 |
+
if isCloudDeploy {
|
| 417 |
+
if info, errStat := os.Stat(configFilePath); errStat != nil {
|
| 418 |
+
// Don't mislead: API server will not start until configuration is provided.
|
| 419 |
+
log.Info("Cloud deploy mode: No configuration file detected; standing by for configuration")
|
| 420 |
+
configFileExists = false
|
| 421 |
+
} else if info.IsDir() {
|
| 422 |
+
log.Info("Cloud deploy mode: Config path is a directory; standing by for configuration")
|
| 423 |
+
configFileExists = false
|
| 424 |
+
} else if cfg.Port == 0 {
|
| 425 |
+
// LoadConfigOptional returns empty config when file is empty or invalid.
|
| 426 |
+
// Config file exists but is empty or invalid; treat as missing config
|
| 427 |
+
log.Info("Cloud deploy mode: Configuration file is empty or invalid; standing by for valid configuration")
|
| 428 |
+
configFileExists = false
|
| 429 |
+
} else {
|
| 430 |
+
log.Info("Cloud deploy mode: Configuration file detected; starting service")
|
| 431 |
+
configFileExists = true
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
| 435 |
+
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
| 436 |
+
|
| 437 |
+
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
| 438 |
+
log.Errorf("failed to configure log output: %v", err)
|
| 439 |
+
return
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
| 443 |
+
|
| 444 |
+
// Set the log level based on the configuration.
|
| 445 |
+
util.SetLogLevel(cfg)
|
| 446 |
+
|
| 447 |
+
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
| 448 |
+
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
| 449 |
+
return
|
| 450 |
+
} else {
|
| 451 |
+
cfg.AuthDir = resolvedAuthDir
|
| 452 |
+
}
|
| 453 |
+
managementasset.SetCurrentConfig(cfg)
|
| 454 |
+
|
| 455 |
+
// Create login options to be used in authentication flows.
|
| 456 |
+
options := &cmd.LoginOptions{
|
| 457 |
+
NoBrowser: noBrowser,
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
// Register the shared token store once so all components use the same persistence backend.
|
| 461 |
+
if usePostgresStore {
|
| 462 |
+
sdkAuth.RegisterTokenStore(pgStoreInst)
|
| 463 |
+
} else if useObjectStore {
|
| 464 |
+
sdkAuth.RegisterTokenStore(objectStoreInst)
|
| 465 |
+
} else if useGitStore {
|
| 466 |
+
sdkAuth.RegisterTokenStore(gitStoreInst)
|
| 467 |
+
} else {
|
| 468 |
+
sdkAuth.RegisterTokenStore(sdkAuth.NewFileTokenStore())
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
// Register built-in access providers before constructing services.
|
| 472 |
+
configaccess.Register()
|
| 473 |
+
|
| 474 |
+
// Handle different command modes based on the provided flags.
|
| 475 |
+
|
| 476 |
+
if vertexImport != "" {
|
| 477 |
+
// Handle Vertex service account import
|
| 478 |
+
cmd.DoVertexImport(cfg, vertexImport)
|
| 479 |
+
} else if login {
|
| 480 |
+
// Handle Google/Gemini login
|
| 481 |
+
cmd.DoLogin(cfg, projectID, options)
|
| 482 |
+
} else if antigravityLogin {
|
| 483 |
+
// Handle Antigravity login
|
| 484 |
+
cmd.DoAntigravityLogin(cfg, options)
|
| 485 |
+
} else if githubCopilotLogin {
|
| 486 |
+
// Handle GitHub Copilot login
|
| 487 |
+
cmd.DoGitHubCopilotLogin(cfg, options)
|
| 488 |
+
} else if codexLogin {
|
| 489 |
+
// Handle Codex login
|
| 490 |
+
cmd.DoCodexLogin(cfg, options)
|
| 491 |
+
} else if claudeLogin {
|
| 492 |
+
// Handle Claude login
|
| 493 |
+
cmd.DoClaudeLogin(cfg, options)
|
| 494 |
+
} else if qwenLogin {
|
| 495 |
+
cmd.DoQwenLogin(cfg, options)
|
| 496 |
+
} else if iflowLogin {
|
| 497 |
+
cmd.DoIFlowLogin(cfg, options)
|
| 498 |
+
} else if iflowCookie {
|
| 499 |
+
cmd.DoIFlowCookieAuth(cfg, options)
|
| 500 |
+
} else if kiroLogin {
|
| 501 |
+
// For Kiro auth, default to incognito mode for multi-account support
|
| 502 |
+
// Users can explicitly override with --no-incognito
|
| 503 |
+
// Note: This config mutation is safe - auth commands exit after completion
|
| 504 |
+
// and don't share config with StartService (which is in the else branch)
|
| 505 |
+
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
| 506 |
+
cmd.DoKiroLogin(cfg, options)
|
| 507 |
+
} else if kiroGoogleLogin {
|
| 508 |
+
// For Kiro auth, default to incognito mode for multi-account support
|
| 509 |
+
// Users can explicitly override with --no-incognito
|
| 510 |
+
// Note: This config mutation is safe - auth commands exit after completion
|
| 511 |
+
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
| 512 |
+
cmd.DoKiroGoogleLogin(cfg, options)
|
| 513 |
+
} else if kiroAWSLogin {
|
| 514 |
+
// For Kiro auth, default to incognito mode for multi-account support
|
| 515 |
+
// Users can explicitly override with --no-incognito
|
| 516 |
+
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
| 517 |
+
cmd.DoKiroAWSLogin(cfg, options)
|
| 518 |
+
} else if kiroAWSAuthCode {
|
| 519 |
+
// For Kiro auth with authorization code flow (better UX)
|
| 520 |
+
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
| 521 |
+
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
|
| 522 |
+
} else if kiroImport {
|
| 523 |
+
cmd.DoKiroImport(cfg, options)
|
| 524 |
+
} else {
|
| 525 |
+
// In cloud deploy mode without config file, just wait for shutdown signals
|
| 526 |
+
if isCloudDeploy && !configFileExists {
|
| 527 |
+
// No config file available, just wait for shutdown
|
| 528 |
+
cmd.WaitForCloudDeploy()
|
| 529 |
+
return
|
| 530 |
+
}
|
| 531 |
+
// Start the main proxy service
|
| 532 |
+
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
| 533 |
+
cmd.StartService(cfg, configFilePath, password)
|
| 534 |
+
}
|
| 535 |
+
}
|
config.example.yaml
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
|
| 2 |
+
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
|
| 3 |
+
host: ""
|
| 4 |
+
|
| 5 |
+
# Server port
|
| 6 |
+
port: 8317
|
| 7 |
+
|
| 8 |
+
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
| 9 |
+
tls:
|
| 10 |
+
enable: false
|
| 11 |
+
cert: ""
|
| 12 |
+
key: ""
|
| 13 |
+
|
| 14 |
+
# Management API settings
|
| 15 |
+
remote-management:
|
| 16 |
+
# Whether to allow remote (non-localhost) management access.
|
| 17 |
+
# When false, only localhost can access management endpoints (a key is still required).
|
| 18 |
+
allow-remote: false
|
| 19 |
+
|
| 20 |
+
# Management key. If a plaintext value is provided here, it will be hashed on startup.
|
| 21 |
+
# All management requests (even from localhost) require this key.
|
| 22 |
+
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
|
| 23 |
+
secret-key: ""
|
| 24 |
+
|
| 25 |
+
# Disable the bundled management control panel asset download and HTTP route when true.
|
| 26 |
+
disable-control-panel: false
|
| 27 |
+
|
| 28 |
+
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
| 29 |
+
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
| 30 |
+
|
| 31 |
+
# Authentication directory (supports ~ for home directory)
|
| 32 |
+
auth-dir: "~/.cli-proxy-api"
|
| 33 |
+
|
| 34 |
+
# API keys for authentication
|
| 35 |
+
api-keys:
|
| 36 |
+
- "your-api-key-1"
|
| 37 |
+
- "your-api-key-2"
|
| 38 |
+
- "your-api-key-3"
|
| 39 |
+
|
| 40 |
+
# Enable debug logging
|
| 41 |
+
debug: false
|
| 42 |
+
|
| 43 |
+
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
| 44 |
+
commercial-mode: false
|
| 45 |
+
|
| 46 |
+
# Open OAuth URLs in incognito/private browser mode.
|
| 47 |
+
# Useful when you want to login with a different account without logging out from your current session.
|
| 48 |
+
# Default: false (but Kiro auth defaults to true for multi-account support)
|
| 49 |
+
incognito-browser: true
|
| 50 |
+
|
| 51 |
+
# When true, write application logs to rotating files instead of stdout
|
| 52 |
+
logging-to-file: false
|
| 53 |
+
|
| 54 |
+
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
| 55 |
+
# files are deleted until within the limit. Set to 0 to disable.
|
| 56 |
+
logs-max-total-size-mb: 0
|
| 57 |
+
|
| 58 |
+
# When false, disable in-memory usage statistics aggregation
|
| 59 |
+
usage-statistics-enabled: false
|
| 60 |
+
|
| 61 |
+
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
| 62 |
+
proxy-url: ""
|
| 63 |
+
|
| 64 |
+
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
| 65 |
+
force-model-prefix: false
|
| 66 |
+
|
| 67 |
+
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
| 68 |
+
request-retry: 3
|
| 69 |
+
|
| 70 |
+
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
| 71 |
+
max-retry-interval: 30
|
| 72 |
+
|
| 73 |
+
# Quota exceeded behavior
|
| 74 |
+
quota-exceeded:
|
| 75 |
+
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
| 76 |
+
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
| 77 |
+
|
| 78 |
+
# Routing strategy for selecting credentials when multiple match.
|
| 79 |
+
routing:
|
| 80 |
+
strategy: "round-robin" # round-robin (default), fill-first
|
| 81 |
+
|
| 82 |
+
# When true, enable authentication for the WebSocket API (/v1/ws).
|
| 83 |
+
ws-auth: false
|
| 84 |
+
|
| 85 |
+
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
| 86 |
+
# streaming:
|
| 87 |
+
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
| 88 |
+
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
| 89 |
+
|
| 90 |
+
# Gemini API keys
|
| 91 |
+
# gemini-api-key:
|
| 92 |
+
# - api-key: "AIzaSy...01"
|
| 93 |
+
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
| 94 |
+
# base-url: "https://generativelanguage.googleapis.com"
|
| 95 |
+
# headers:
|
| 96 |
+
# X-Custom-Header: "custom-value"
|
| 97 |
+
# proxy-url: "socks5://proxy.example.com:1080"
|
| 98 |
+
# models:
|
| 99 |
+
# - name: "gemini-2.5-flash" # upstream model name
|
| 100 |
+
# alias: "gemini-flash" # client alias mapped to the upstream model
|
| 101 |
+
# excluded-models:
|
| 102 |
+
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
| 103 |
+
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
| 104 |
+
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
|
| 105 |
+
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
|
| 106 |
+
# - api-key: "AIzaSy...02"
|
| 107 |
+
|
| 108 |
+
# Codex API keys
|
| 109 |
+
# codex-api-key:
|
| 110 |
+
# - api-key: "sk-atSM..."
|
| 111 |
+
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
| 112 |
+
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
| 113 |
+
# headers:
|
| 114 |
+
# X-Custom-Header: "custom-value"
|
| 115 |
+
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
| 116 |
+
# models:
|
| 117 |
+
# - name: "gpt-5-codex" # upstream model name
|
| 118 |
+
# alias: "codex-latest" # client alias mapped to the upstream model
|
| 119 |
+
# excluded-models:
|
| 120 |
+
# - "gpt-5.1" # exclude specific models (exact match)
|
| 121 |
+
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
| 122 |
+
# - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini)
|
| 123 |
+
# - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low)
|
| 124 |
+
|
| 125 |
+
# Claude API keys
|
| 126 |
+
# claude-api-key:
|
| 127 |
+
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
| 128 |
+
# - api-key: "sk-atSM..."
|
| 129 |
+
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
| 130 |
+
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
| 131 |
+
# headers:
|
| 132 |
+
# X-Custom-Header: "custom-value"
|
| 133 |
+
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
| 134 |
+
# models:
|
| 135 |
+
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
| 136 |
+
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
| 137 |
+
# excluded-models:
|
| 138 |
+
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
| 139 |
+
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
| 140 |
+
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
| 141 |
+
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
| 142 |
+
|
| 143 |
+
# Kiro (AWS CodeWhisperer) configuration
|
| 144 |
+
# Note: Kiro API currently only operates in us-east-1 region
|
| 145 |
+
#kiro:
|
| 146 |
+
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
| 147 |
+
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
| 148 |
+
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
| 149 |
+
# refresh-token: "aorAAAAA..."
|
| 150 |
+
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
| 151 |
+
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
| 152 |
+
|
| 153 |
+
# OpenAI compatibility providers
|
| 154 |
+
# openai-compatibility:
|
| 155 |
+
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
| 156 |
+
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
| 157 |
+
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
| 158 |
+
# headers:
|
| 159 |
+
# X-Custom-Header: "custom-value"
|
| 160 |
+
# api-key-entries:
|
| 161 |
+
# - api-key: "sk-or-v1-...b780"
|
| 162 |
+
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
| 163 |
+
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
| 164 |
+
# models: # The models supported by the provider.
|
| 165 |
+
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
| 166 |
+
# alias: "kimi-k2" # The alias used in the API.
|
| 167 |
+
|
| 168 |
+
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
| 169 |
+
# vertex-api-key:
|
| 170 |
+
# - api-key: "vk-123..." # x-goog-api-key header
|
| 171 |
+
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
| 172 |
+
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
| 173 |
+
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
| 174 |
+
# headers:
|
| 175 |
+
# X-Custom-Header: "custom-value"
|
| 176 |
+
# models: # optional: map aliases to upstream model names
|
| 177 |
+
# - name: "gemini-2.5-flash" # upstream model name
|
| 178 |
+
# alias: "vertex-flash" # client-visible alias
|
| 179 |
+
# - name: "gemini-2.5-pro"
|
| 180 |
+
# alias: "vertex-pro"
|
| 181 |
+
|
| 182 |
+
# Amp Integration
|
| 183 |
+
# ampcode:
|
| 184 |
+
# # Configure upstream URL for Amp CLI OAuth and management features
|
| 185 |
+
# upstream-url: "https://ampcode.com"
|
| 186 |
+
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
| 187 |
+
# upstream-api-key: ""
|
| 188 |
+
# # Per-client upstream API key mapping
|
| 189 |
+
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
| 190 |
+
# # Useful when different clients need to use different Amp accounts/quotas.
|
| 191 |
+
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
| 192 |
+
# upstream-api-keys:
|
| 193 |
+
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
| 194 |
+
# api-keys: # Client keys that use this upstream key
|
| 195 |
+
# - "your-api-key-1"
|
| 196 |
+
# - "your-api-key-2"
|
| 197 |
+
# - upstream-api-key: "amp_key_for_team_b"
|
| 198 |
+
# api-keys:
|
| 199 |
+
# - "your-api-key-3"
|
| 200 |
+
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
| 201 |
+
# restrict-management-to-localhost: false
|
| 202 |
+
# # Force model mappings to run before checking local API keys (default: false)
|
| 203 |
+
# force-model-mappings: false
|
| 204 |
+
# # Amp Model Mappings
|
| 205 |
+
# # Route unavailable Amp models to alternative models available in your local proxy.
|
| 206 |
+
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
| 207 |
+
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
| 208 |
+
# model-mappings:
|
| 209 |
+
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
| 210 |
+
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
| 211 |
+
# - from: "claude-sonnet-4-5-20250929"
|
| 212 |
+
# to: "gemini-claude-sonnet-4-5-thinking"
|
| 213 |
+
# - from: "claude-haiku-4-5-20251001"
|
| 214 |
+
# to: "gemini-2.5-flash"
|
| 215 |
+
|
| 216 |
+
# Global OAuth model name mappings (per channel)
|
| 217 |
+
# These mappings rename model IDs for both model listing and request routing.
|
| 218 |
+
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
| 219 |
+
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
| 220 |
+
# oauth-model-mappings:
|
| 221 |
+
# gemini-cli:
|
| 222 |
+
# - name: "gemini-2.5-pro" # original model name under this channel
|
| 223 |
+
# alias: "g2.5p" # client-visible alias
|
| 224 |
+
# vertex:
|
| 225 |
+
# - name: "gemini-2.5-pro"
|
| 226 |
+
# alias: "g2.5p"
|
| 227 |
+
# aistudio:
|
| 228 |
+
# - name: "gemini-2.5-pro"
|
| 229 |
+
# alias: "g2.5p"
|
| 230 |
+
# antigravity:
|
| 231 |
+
# - name: "gemini-3-pro-preview"
|
| 232 |
+
# alias: "g3p"
|
| 233 |
+
# claude:
|
| 234 |
+
# - name: "claude-sonnet-4-5-20250929"
|
| 235 |
+
# alias: "cs4.5"
|
| 236 |
+
# codex:
|
| 237 |
+
# - name: "gpt-5"
|
| 238 |
+
# alias: "g5"
|
| 239 |
+
# qwen:
|
| 240 |
+
# - name: "qwen3-coder-plus"
|
| 241 |
+
# alias: "qwen-plus"
|
| 242 |
+
# iflow:
|
| 243 |
+
# - name: "glm-4.7"
|
| 244 |
+
# alias: "glm-god"
|
| 245 |
+
|
| 246 |
+
# OAuth provider excluded models
|
| 247 |
+
# oauth-excluded-models:
|
| 248 |
+
# gemini-cli:
|
| 249 |
+
# - "gemini-2.5-pro" # exclude specific models (exact match)
|
| 250 |
+
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
| 251 |
+
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
|
| 252 |
+
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
|
| 253 |
+
# vertex:
|
| 254 |
+
# - "gemini-3-pro-preview"
|
| 255 |
+
# aistudio:
|
| 256 |
+
# - "gemini-3-pro-preview"
|
| 257 |
+
# antigravity:
|
| 258 |
+
# - "gemini-3-pro-preview"
|
| 259 |
+
# claude:
|
| 260 |
+
# - "claude-3-5-haiku-20241022"
|
| 261 |
+
# codex:
|
| 262 |
+
# - "gpt-5-codex-mini"
|
| 263 |
+
# qwen:
|
| 264 |
+
# - "vision-model"
|
| 265 |
+
# iflow:
|
| 266 |
+
# - "tstars2.0"
|
| 267 |
+
|
| 268 |
+
# Optional payload configuration
|
| 269 |
+
# payload:
|
| 270 |
+
# default: # Default rules only set parameters when they are missing in the payload.
|
| 271 |
+
# - models:
|
| 272 |
+
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
| 273 |
+
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
| 274 |
+
# params: # JSON path (gjson/sjson syntax) -> value
|
| 275 |
+
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
| 276 |
+
# override: # Override rules always set parameters, overwriting any existing values.
|
| 277 |
+
# - models:
|
| 278 |
+
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
| 279 |
+
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
| 280 |
+
# params: # JSON path (gjson/sjson syntax) -> value
|
| 281 |
+
# "reasoning.effort": "high"
|
config.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLIProxyAPI Plus - Ultimate Power Config
|
| 2 |
+
host: ""
|
| 3 |
+
port: 7860
|
| 4 |
+
|
| 5 |
+
# TLS settings (disabled for HF)
|
| 6 |
+
tls:
|
| 7 |
+
enable: false
|
| 8 |
+
|
| 9 |
+
# Management API settings (REMOTE ENABLED FOR HF)
|
| 10 |
+
remote-management:
|
| 11 |
+
allow-remote: true
|
| 12 |
+
secret-key: "$2a$10$Yt27TytUvABKw192YdTW2urLkQ5oQkHGuSz6PrzFFlsNJ5TE1EOFe"
|
| 13 |
+
disable-control-panel: false
|
| 14 |
+
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
| 15 |
+
|
| 16 |
+
# Storage
|
| 17 |
+
auth-dir: "./auth"
|
| 18 |
+
|
| 19 |
+
# Client Keys for YOU to access the API
|
| 20 |
+
api-keys:
|
| 21 |
+
- "sk-admin-power-1"
|
| 22 |
+
- "sk-client-key-custom"
|
| 23 |
+
|
| 24 |
+
# Performance & Logging
|
| 25 |
+
debug: false
|
| 26 |
+
commercial-mode: false
|
| 27 |
+
incognito-browser: true
|
| 28 |
+
logging-to-file: false
|
| 29 |
+
usage-statistics-enabled: true
|
| 30 |
+
|
| 31 |
+
# Network
|
| 32 |
+
proxy-url: ""
|
| 33 |
+
|
| 34 |
+
# Smart Routing & Retries
|
| 35 |
+
request-retry: 5
|
| 36 |
+
max-retry-interval: 15
|
| 37 |
+
routing:
|
| 38 |
+
strategy: "round-robin" # Rotates through your 16 keys!
|
| 39 |
+
|
| 40 |
+
quota-exceeded:
|
| 41 |
+
switch-project: true
|
| 42 |
+
switch-preview-model: true
|
| 43 |
+
|
| 44 |
+
# --- PROVIDER SECTIONS (FILL THESE IN THE ADMIN PANEL) ---
|
| 45 |
+
|
| 46 |
+
# 1. Gemini API Keys (Google AI Studio)
|
| 47 |
+
gemini-api-key:
|
| 48 |
+
- api-key: "DEIN_GEMINI_KEY_1"
|
| 49 |
+
- api-key: "DEIN_GEMINI_KEY_2"
|
| 50 |
+
|
| 51 |
+
# 2. Claude API Keys (Anthropic)
|
| 52 |
+
claude-api-key:
|
| 53 |
+
- api-key: "DEIN_CLAUDE_KEY_1"
|
| 54 |
+
|
| 55 |
+
# 3. OpenAI / Compatibility (OpenRouter, DeepSeek, Groq, etc.)
|
| 56 |
+
openai-compatibility:
|
| 57 |
+
- name: "openrouter"
|
| 58 |
+
base-url: "https://openrouter.ai/api/v1"
|
| 59 |
+
api-key-entries:
|
| 60 |
+
- api-key: "DEIN_OPENROUTER_KEY"
|
| 61 |
+
- name: "deepseek"
|
| 62 |
+
base-url: "https://api.deepseek.com/v1"
|
| 63 |
+
api-key-entries:
|
| 64 |
+
- api-key: "DEIN_DEEPSEEK_KEY"
|
| 65 |
+
|
| 66 |
+
# 4. Global Model Mappings (Rename models for easier use)
|
| 67 |
+
oauth-model-mappings:
|
| 68 |
+
gemini-cli:
|
| 69 |
+
- name: "gemini-2.5-pro"
|
| 70 |
+
alias: "g-pro"
|
| 71 |
+
- name: "gemini-2.5-flash"
|
| 72 |
+
alias: "g-flash"
|
| 73 |
+
claude:
|
| 74 |
+
- name: "claude-3-5-sonnet-20241022"
|
| 75 |
+
alias: "sonnet"
|
go.mod
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module github.com/router-for-me/CLIProxyAPI/v6
|
| 2 |
+
|
| 3 |
+
go 1.23
|
| 4 |
+
|
| 5 |
+
require (
|
| 6 |
+
github.com/andybalholm/brotli v1.0.6
|
| 7 |
+
github.com/fsnotify/fsnotify v1.9.0
|
| 8 |
+
github.com/gin-gonic/gin v1.10.1
|
| 9 |
+
github.com/go-git/go-git/v5 v5.12.0
|
| 10 |
+
github.com/google/uuid v1.6.0
|
| 11 |
+
github.com/gorilla/websocket v1.5.3
|
| 12 |
+
github.com/jackc/pgx/v5 v5.7.0
|
| 13 |
+
github.com/joho/godotenv v1.5.1
|
| 14 |
+
github.com/klauspost/compress v1.17.4
|
| 15 |
+
github.com/minio/minio-go/v7 v7.0.66
|
| 16 |
+
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
| 17 |
+
github.com/sirupsen/logrus v1.9.3
|
| 18 |
+
github.com/tidwall/gjson v1.18.0
|
| 19 |
+
github.com/tidwall/sjson v1.2.5
|
| 20 |
+
github.com/tiktoken-go/tokenizer v0.7.0
|
| 21 |
+
golang.org/x/crypto v0.46.0
|
| 22 |
+
golang.org/x/net v0.48.0
|
| 23 |
+
golang.org/x/oauth2 v0.21.0
|
| 24 |
+
golang.org/x/term v0.38.0
|
| 25 |
+
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
| 26 |
+
gopkg.in/yaml.v3 v3.0.1
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
require (
|
| 30 |
+
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
| 31 |
+
github.com/Microsoft/go-winio v0.6.2 // indirect
|
| 32 |
+
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
| 33 |
+
github.com/bytedance/sonic v1.11.6 // indirect
|
| 34 |
+
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
| 35 |
+
github.com/cloudflare/circl v1.6.1 // indirect
|
| 36 |
+
github.com/cloudwego/base64x v0.1.4 // indirect
|
| 37 |
+
github.com/cloudwego/iasm v0.2.0 // indirect
|
| 38 |
+
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
|
| 39 |
+
github.com/dlclark/regexp2 v1.11.5 // indirect
|
| 40 |
+
github.com/dustin/go-humanize v1.0.1 // indirect
|
| 41 |
+
github.com/emirpasic/gods v1.18.1 // indirect
|
| 42 |
+
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
| 43 |
+
github.com/gin-contrib/sse v0.1.0 // indirect
|
| 44 |
+
github.com/go-git/gcfg v2.0.2 // indirect
|
| 45 |
+
github.com/go-playground/locales v0.14.1 // indirect
|
| 46 |
+
github.com/go-playground/universal-translator v0.18.1 // indirect
|
| 47 |
+
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
| 48 |
+
github.com/goccy/go-json v0.10.2 // indirect
|
| 49 |
+
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
| 50 |
+
github.com/google/go-cmp v0.6.0 // indirect
|
| 51 |
+
github.com/jackc/pgpassfile v1.0.0 // indirect
|
| 52 |
+
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
| 53 |
+
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
| 54 |
+
github.com/json-iterator/go v1.1.12 // indirect
|
| 55 |
+
github.com/kevinburke/ssh_config v1.4.0 // indirect
|
| 56 |
+
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
| 57 |
+
github.com/kr/text v0.2.0 // indirect
|
| 58 |
+
github.com/leodido/go-urn v1.4.0 // indirect
|
| 59 |
+
github.com/mattn/go-isatty v0.0.20 // indirect
|
| 60 |
+
github.com/minio/md5-simd v1.1.2 // indirect
|
| 61 |
+
github.com/minio/sha256-simd v1.0.1 // indirect
|
| 62 |
+
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
| 63 |
+
github.com/modern-go/reflect2 v1.0.2 // indirect
|
| 64 |
+
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
| 65 |
+
github.com/pjbgf/sha1cd v0.5.0 // indirect
|
| 66 |
+
github.com/rs/xid v1.5.0 // indirect
|
| 67 |
+
github.com/sergi/go-diff v1.4.0 // indirect
|
| 68 |
+
github.com/tidwall/match v1.1.1 // indirect
|
| 69 |
+
github.com/tidwall/pretty v1.2.0 // indirect
|
| 70 |
+
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
| 71 |
+
github.com/ugorji/go/codec v1.2.12 // indirect
|
| 72 |
+
golang.org/x/arch v0.8.0 // indirect
|
| 73 |
+
golang.org/x/sync v0.19.0 // indirect
|
| 74 |
+
golang.org/x/sys v0.39.0 // indirect
|
| 75 |
+
golang.org/x/text v0.32.0 // indirect
|
| 76 |
+
google.golang.org/protobuf v1.34.1 // indirect
|
| 77 |
+
gopkg.in/ini.v1 v1.67.0 // indirect
|
| 78 |
+
)
|
go.sum
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
| 2 |
+
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
| 3 |
+
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
| 4 |
+
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
| 5 |
+
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
|
| 6 |
+
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
|
| 7 |
+
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
|
| 8 |
+
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
| 9 |
+
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
| 10 |
+
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
| 11 |
+
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
| 12 |
+
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
| 13 |
+
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
| 14 |
+
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
| 15 |
+
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
| 16 |
+
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
| 17 |
+
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
|
| 18 |
+
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
|
| 19 |
+
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
| 20 |
+
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
| 21 |
+
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
| 22 |
+
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
| 23 |
+
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
| 24 |
+
github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE=
|
| 25 |
+
github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc=
|
| 26 |
+
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
| 27 |
+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
| 28 |
+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
| 29 |
+
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
| 30 |
+
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
| 31 |
+
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
| 32 |
+
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
| 33 |
+
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
| 34 |
+
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
| 35 |
+
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
| 36 |
+
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
| 37 |
+
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
| 38 |
+
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
| 39 |
+
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
| 40 |
+
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
| 41 |
+
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
|
| 42 |
+
github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
| 43 |
+
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
|
| 44 |
+
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
|
| 45 |
+
github.com/go-git/gcfg/v2 v2.0.2 h1:MY5SIIfTGGEMhdA7d7JePuVVxtKL7Hp+ApGDJAJ7dpo=
|
| 46 |
+
github.com/go-git/gcfg/v2 v2.0.2/go.mod h1:/lv2NsxvhepuMrldsFilrgct6pxzpGdSRC13ydTLSLs=
|
| 47 |
+
github.com/go-git/go-billy/v6 v6.0.0-20251217170237-e9738f50a3cd h1:Gd/f9cGi/3h1JOPaa6er+CkKUGyGX2DBJdFbDKVO+R0=
|
| 48 |
+
github.com/go-git/go-billy/v6 v6.0.0-20251217170237-e9738f50a3cd/go.mod h1:d3XQcsHu1idnquxt48kAv+h+1MUiYKLH/e7LAzjP+pI=
|
| 49 |
+
github.com/go-git/go-git-fixtures/v5 v5.1.2-0.20251229094738-4b14af179146 h1:xYfxAopYyL44ot6dMBIb1Z1njFM0ZBQ99HdIB99KxLs=
|
| 50 |
+
github.com/go-git/go-git-fixtures/v5 v5.1.2-0.20251229094738-4b14af179146/go.mod h1:QE/75B8tBSLNGyUUbA9tw3EGHoFtYOtypa2h8YJxsWI=
|
| 51 |
+
github.com/go-git/go-git/v6 v6.0.0-20251231065035-29ae690a9f19 h1:0lz2eJScP8v5YZQsrEw+ggWC5jNySjg4bIZo5BIh6iI=
|
| 52 |
+
github.com/go-git/go-git/v6 v6.0.0-20251231065035-29ae690a9f19/go.mod h1:L+Evfcs7EdTqxwv854354cb6+++7TFL3hJn3Wy4g+3w=
|
| 53 |
+
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
| 54 |
+
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
| 55 |
+
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
| 56 |
+
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
| 57 |
+
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
| 58 |
+
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
| 59 |
+
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
| 60 |
+
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
| 61 |
+
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
| 62 |
+
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
| 63 |
+
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
| 64 |
+
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
|
| 65 |
+
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
| 66 |
+
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
| 67 |
+
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
| 68 |
+
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
| 69 |
+
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
| 70 |
+
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
| 71 |
+
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
| 72 |
+
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
| 73 |
+
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
| 74 |
+
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
| 75 |
+
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
| 76 |
+
github.com/jackc/pgx/v5 v5.7.0 h1:FG6VLIdzvAPhnYqP14sQ2xhFLkiUQHCs6ySqO91kF4g=
|
| 77 |
+
github.com/jackc/pgx/v5 v5.7.0/go.mod h1:awP1KNnjylvpxHuHP63gzjhnGkI1iw+PMoIwvoleN/8=
|
| 78 |
+
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
| 79 |
+
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
| 80 |
+
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
| 81 |
+
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
| 82 |
+
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
| 83 |
+
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
| 84 |
+
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
| 85 |
+
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
| 86 |
+
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
| 87 |
+
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
| 88 |
+
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
| 89 |
+
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
| 90 |
+
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
| 91 |
+
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
| 92 |
+
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
| 93 |
+
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
| 94 |
+
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
| 95 |
+
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
| 96 |
+
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
| 97 |
+
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
| 98 |
+
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
| 99 |
+
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
| 100 |
+
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
| 101 |
+
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
| 102 |
+
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
| 103 |
+
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
| 104 |
+
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
| 105 |
+
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
| 106 |
+
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
| 107 |
+
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
|
| 108 |
+
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
| 109 |
+
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
| 110 |
+
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
| 111 |
+
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
| 112 |
+
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
| 113 |
+
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
| 114 |
+
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
| 115 |
+
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
| 116 |
+
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
| 117 |
+
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
| 118 |
+
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
| 119 |
+
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
| 120 |
+
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
| 121 |
+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
| 122 |
+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
| 123 |
+
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
| 124 |
+
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
| 125 |
+
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
| 126 |
+
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
| 127 |
+
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
| 128 |
+
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
| 129 |
+
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
| 130 |
+
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
| 131 |
+
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
| 132 |
+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
| 133 |
+
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
| 134 |
+
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
| 135 |
+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
| 136 |
+
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
| 137 |
+
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
| 138 |
+
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
| 139 |
+
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
| 140 |
+
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
| 141 |
+
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
| 142 |
+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
| 143 |
+
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
| 144 |
+
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
| 145 |
+
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
| 146 |
+
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
| 147 |
+
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
| 148 |
+
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
| 149 |
+
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
| 150 |
+
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
| 151 |
+
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
| 152 |
+
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
| 153 |
+
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
| 154 |
+
github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
|
| 155 |
+
github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
| 156 |
+
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
| 157 |
+
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
| 158 |
+
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
| 159 |
+
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
| 160 |
+
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
| 161 |
+
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
| 162 |
+
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
| 163 |
+
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
| 164 |
+
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
| 165 |
+
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
| 166 |
+
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
| 167 |
+
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
| 168 |
+
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
| 169 |
+
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
| 170 |
+
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
| 171 |
+
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 172 |
+
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 173 |
+
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 174 |
+
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
| 175 |
+
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
| 176 |
+
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
| 177 |
+
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
| 178 |
+
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
| 179 |
+
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
| 180 |
+
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
| 181 |
+
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
| 182 |
+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
| 183 |
+
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
| 184 |
+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
| 185 |
+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
| 186 |
+
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
| 187 |
+
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
| 188 |
+
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
| 189 |
+
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
| 190 |
+
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
| 191 |
+
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
| 192 |
+
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
| 193 |
+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
| 194 |
+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
| 195 |
+
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
| 196 |
+
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
internal/access/config_access/provider.go
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package configaccess
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"net/http"
|
| 6 |
+
"strings"
|
| 7 |
+
"sync"
|
| 8 |
+
|
| 9 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 10 |
+
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
var registerOnce sync.Once
|
| 14 |
+
|
| 15 |
+
// Register ensures the config-access provider is available to the access manager.
|
| 16 |
+
func Register() {
|
| 17 |
+
registerOnce.Do(func() {
|
| 18 |
+
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
| 19 |
+
})
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
type provider struct {
|
| 23 |
+
name string
|
| 24 |
+
keys map[string]struct{}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
| 28 |
+
name := cfg.Name
|
| 29 |
+
if name == "" {
|
| 30 |
+
name = sdkconfig.DefaultAccessProviderName
|
| 31 |
+
}
|
| 32 |
+
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
| 33 |
+
for _, key := range cfg.APIKeys {
|
| 34 |
+
if key == "" {
|
| 35 |
+
continue
|
| 36 |
+
}
|
| 37 |
+
keys[key] = struct{}{}
|
| 38 |
+
}
|
| 39 |
+
return &provider{name: name, keys: keys}, nil
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
func (p *provider) Identifier() string {
|
| 43 |
+
if p == nil || p.name == "" {
|
| 44 |
+
return sdkconfig.DefaultAccessProviderName
|
| 45 |
+
}
|
| 46 |
+
return p.name
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
| 50 |
+
if p == nil {
|
| 51 |
+
return nil, sdkaccess.ErrNotHandled
|
| 52 |
+
}
|
| 53 |
+
if len(p.keys) == 0 {
|
| 54 |
+
return nil, sdkaccess.ErrNotHandled
|
| 55 |
+
}
|
| 56 |
+
authHeader := r.Header.Get("Authorization")
|
| 57 |
+
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
| 58 |
+
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
| 59 |
+
queryKey := ""
|
| 60 |
+
queryAuthToken := ""
|
| 61 |
+
if r.URL != nil {
|
| 62 |
+
queryKey = r.URL.Query().Get("key")
|
| 63 |
+
queryAuthToken = r.URL.Query().Get("auth_token")
|
| 64 |
+
}
|
| 65 |
+
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
| 66 |
+
return nil, sdkaccess.ErrNoCredentials
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
apiKey := extractBearerToken(authHeader)
|
| 70 |
+
|
| 71 |
+
candidates := []struct {
|
| 72 |
+
value string
|
| 73 |
+
source string
|
| 74 |
+
}{
|
| 75 |
+
{apiKey, "authorization"},
|
| 76 |
+
{authHeaderGoogle, "x-goog-api-key"},
|
| 77 |
+
{authHeaderAnthropic, "x-api-key"},
|
| 78 |
+
{queryKey, "query-key"},
|
| 79 |
+
{queryAuthToken, "query-auth-token"},
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
for _, candidate := range candidates {
|
| 83 |
+
if candidate.value == "" {
|
| 84 |
+
continue
|
| 85 |
+
}
|
| 86 |
+
if _, ok := p.keys[candidate.value]; ok {
|
| 87 |
+
return &sdkaccess.Result{
|
| 88 |
+
Provider: p.Identifier(),
|
| 89 |
+
Principal: candidate.value,
|
| 90 |
+
Metadata: map[string]string{
|
| 91 |
+
"source": candidate.source,
|
| 92 |
+
},
|
| 93 |
+
}, nil
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
return nil, sdkaccess.ErrInvalidCredential
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
func extractBearerToken(header string) string {
|
| 101 |
+
if header == "" {
|
| 102 |
+
return ""
|
| 103 |
+
}
|
| 104 |
+
parts := strings.SplitN(header, " ", 2)
|
| 105 |
+
if len(parts) != 2 {
|
| 106 |
+
return header
|
| 107 |
+
}
|
| 108 |
+
if strings.ToLower(parts[0]) != "bearer" {
|
| 109 |
+
return header
|
| 110 |
+
}
|
| 111 |
+
return strings.TrimSpace(parts[1])
|
| 112 |
+
}
|
internal/access/reconcile.go
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package access
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"reflect"
|
| 6 |
+
"sort"
|
| 7 |
+
"strings"
|
| 8 |
+
|
| 9 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 10 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 11 |
+
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
| 12 |
+
log "github.com/sirupsen/logrus"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
// ReconcileProviders builds the desired provider list by reusing existing providers when possible
|
| 16 |
+
// and creating or removing providers only when their configuration changed. It returns the final
|
| 17 |
+
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
| 18 |
+
// removed compared to the previous configuration.
|
| 19 |
+
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
| 20 |
+
if newCfg == nil {
|
| 21 |
+
return nil, nil, nil, nil, nil
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
| 25 |
+
for _, provider := range existing {
|
| 26 |
+
if provider == nil {
|
| 27 |
+
continue
|
| 28 |
+
}
|
| 29 |
+
existingMap[provider.Identifier()] = provider
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
oldCfgMap := accessProviderMap(oldCfg)
|
| 33 |
+
newEntries := collectProviderEntries(newCfg)
|
| 34 |
+
|
| 35 |
+
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
| 36 |
+
finalIDs := make(map[string]struct{}, len(newEntries))
|
| 37 |
+
|
| 38 |
+
isInlineProvider := func(id string) bool {
|
| 39 |
+
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
| 40 |
+
}
|
| 41 |
+
appendChange := func(list *[]string, id string) {
|
| 42 |
+
if isInlineProvider(id) {
|
| 43 |
+
return
|
| 44 |
+
}
|
| 45 |
+
*list = append(*list, id)
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
for _, providerCfg := range newEntries {
|
| 49 |
+
key := providerIdentifier(providerCfg)
|
| 50 |
+
if key == "" {
|
| 51 |
+
continue
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
| 55 |
+
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
| 56 |
+
isAliased := oldCfgProvider == providerCfg
|
| 57 |
+
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
| 58 |
+
if existingProvider, okExisting := existingMap[key]; okExisting {
|
| 59 |
+
result = append(result, existingProvider)
|
| 60 |
+
finalIDs[key] = struct{}{}
|
| 61 |
+
continue
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
| 67 |
+
if buildErr != nil {
|
| 68 |
+
return nil, nil, nil, nil, buildErr
|
| 69 |
+
}
|
| 70 |
+
if _, ok := oldCfgMap[key]; ok {
|
| 71 |
+
if _, existed := existingMap[key]; existed {
|
| 72 |
+
appendChange(&updated, key)
|
| 73 |
+
} else {
|
| 74 |
+
appendChange(&added, key)
|
| 75 |
+
}
|
| 76 |
+
} else {
|
| 77 |
+
appendChange(&added, key)
|
| 78 |
+
}
|
| 79 |
+
result = append(result, provider)
|
| 80 |
+
finalIDs[key] = struct{}{}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
if len(result) == 0 {
|
| 84 |
+
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
| 85 |
+
key := providerIdentifier(inline)
|
| 86 |
+
if key != "" {
|
| 87 |
+
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
| 88 |
+
if providerConfigEqual(oldCfgProvider, inline) {
|
| 89 |
+
if existingProvider, okExisting := existingMap[key]; okExisting {
|
| 90 |
+
result = append(result, existingProvider)
|
| 91 |
+
finalIDs[key] = struct{}{}
|
| 92 |
+
goto inlineDone
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
| 97 |
+
if buildErr != nil {
|
| 98 |
+
return nil, nil, nil, nil, buildErr
|
| 99 |
+
}
|
| 100 |
+
if _, existed := existingMap[key]; existed {
|
| 101 |
+
appendChange(&updated, key)
|
| 102 |
+
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
| 103 |
+
appendChange(&updated, key)
|
| 104 |
+
} else {
|
| 105 |
+
appendChange(&added, key)
|
| 106 |
+
}
|
| 107 |
+
result = append(result, provider)
|
| 108 |
+
finalIDs[key] = struct{}{}
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
inlineDone:
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
removedSet := make(map[string]struct{})
|
| 115 |
+
for id := range existingMap {
|
| 116 |
+
if _, ok := finalIDs[id]; !ok {
|
| 117 |
+
if isInlineProvider(id) {
|
| 118 |
+
continue
|
| 119 |
+
}
|
| 120 |
+
removedSet[id] = struct{}{}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
removed = make([]string, 0, len(removedSet))
|
| 125 |
+
for id := range removedSet {
|
| 126 |
+
removed = append(removed, id)
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
sort.Strings(added)
|
| 130 |
+
sort.Strings(updated)
|
| 131 |
+
sort.Strings(removed)
|
| 132 |
+
|
| 133 |
+
return result, added, updated, removed, nil
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// ApplyAccessProviders reconciles the configured access providers against the
|
| 137 |
+
// currently registered providers and updates the manager. It logs a concise
|
| 138 |
+
// summary of the detected changes and returns whether any provider changed.
|
| 139 |
+
func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Config) (bool, error) {
|
| 140 |
+
if manager == nil || newCfg == nil {
|
| 141 |
+
return false, nil
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
existing := manager.Providers()
|
| 145 |
+
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
| 146 |
+
if err != nil {
|
| 147 |
+
log.Errorf("failed to reconcile request auth providers: %v", err)
|
| 148 |
+
return false, fmt.Errorf("reconciling access providers: %w", err)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
manager.SetProviders(providers)
|
| 152 |
+
|
| 153 |
+
if len(added)+len(updated)+len(removed) > 0 {
|
| 154 |
+
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
| 155 |
+
log.Debugf("auth providers changes details - added=%v updated=%v removed=%v", added, updated, removed)
|
| 156 |
+
return true, nil
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
log.Debug("auth providers unchanged after config update")
|
| 160 |
+
return false, nil
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
| 164 |
+
result := make(map[string]*sdkConfig.AccessProvider)
|
| 165 |
+
if cfg == nil {
|
| 166 |
+
return result
|
| 167 |
+
}
|
| 168 |
+
for i := range cfg.Access.Providers {
|
| 169 |
+
providerCfg := &cfg.Access.Providers[i]
|
| 170 |
+
if providerCfg.Type == "" {
|
| 171 |
+
continue
|
| 172 |
+
}
|
| 173 |
+
key := providerIdentifier(providerCfg)
|
| 174 |
+
if key == "" {
|
| 175 |
+
continue
|
| 176 |
+
}
|
| 177 |
+
result[key] = providerCfg
|
| 178 |
+
}
|
| 179 |
+
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
| 180 |
+
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
| 181 |
+
if key := providerIdentifier(provider); key != "" {
|
| 182 |
+
result[key] = provider
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
return result
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
| 190 |
+
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
| 191 |
+
for i := range cfg.Access.Providers {
|
| 192 |
+
providerCfg := &cfg.Access.Providers[i]
|
| 193 |
+
if providerCfg.Type == "" {
|
| 194 |
+
continue
|
| 195 |
+
}
|
| 196 |
+
if key := providerIdentifier(providerCfg); key != "" {
|
| 197 |
+
entries = append(entries, providerCfg)
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
| 201 |
+
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
| 202 |
+
entries = append(entries, inline)
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
return entries
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
| 209 |
+
if provider == nil {
|
| 210 |
+
return ""
|
| 211 |
+
}
|
| 212 |
+
if name := strings.TrimSpace(provider.Name); name != "" {
|
| 213 |
+
return name
|
| 214 |
+
}
|
| 215 |
+
typ := strings.TrimSpace(provider.Type)
|
| 216 |
+
if typ == "" {
|
| 217 |
+
return ""
|
| 218 |
+
}
|
| 219 |
+
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
| 220 |
+
return sdkConfig.DefaultAccessProviderName
|
| 221 |
+
}
|
| 222 |
+
return typ
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
| 226 |
+
if a == nil || b == nil {
|
| 227 |
+
return a == nil && b == nil
|
| 228 |
+
}
|
| 229 |
+
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
| 230 |
+
return false
|
| 231 |
+
}
|
| 232 |
+
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
| 233 |
+
return false
|
| 234 |
+
}
|
| 235 |
+
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
| 236 |
+
return false
|
| 237 |
+
}
|
| 238 |
+
if len(a.Config) != len(b.Config) {
|
| 239 |
+
return false
|
| 240 |
+
}
|
| 241 |
+
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
| 242 |
+
return false
|
| 243 |
+
}
|
| 244 |
+
return true
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
func stringSetEqual(a, b []string) bool {
|
| 248 |
+
if len(a) != len(b) {
|
| 249 |
+
return false
|
| 250 |
+
}
|
| 251 |
+
if len(a) == 0 {
|
| 252 |
+
return true
|
| 253 |
+
}
|
| 254 |
+
seen := make(map[string]int, len(a))
|
| 255 |
+
for _, val := range a {
|
| 256 |
+
seen[val]++
|
| 257 |
+
}
|
| 258 |
+
for _, val := range b {
|
| 259 |
+
count := seen[val]
|
| 260 |
+
if count == 0 {
|
| 261 |
+
return false
|
| 262 |
+
}
|
| 263 |
+
if count == 1 {
|
| 264 |
+
delete(seen, val)
|
| 265 |
+
} else {
|
| 266 |
+
seen[val] = count - 1
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
return len(seen) == 0
|
| 270 |
+
}
|
internal/api/handlers/management/api_tools.go
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"net"
|
| 9 |
+
"net/http"
|
| 10 |
+
"net/url"
|
| 11 |
+
"strings"
|
| 12 |
+
"time"
|
| 13 |
+
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
| 16 |
+
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 17 |
+
log "github.com/sirupsen/logrus"
|
| 18 |
+
"golang.org/x/net/proxy"
|
| 19 |
+
"golang.org/x/oauth2"
|
| 20 |
+
"golang.org/x/oauth2/google"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
const defaultAPICallTimeout = 60 * time.Second
|
| 24 |
+
|
| 25 |
+
const (
|
| 26 |
+
geminiOAuthClientID = "YOUR_CLIENT_ID"
|
| 27 |
+
geminiOAuthClientSecret = "YOUR_CLIENT_SECRET"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
var geminiOAuthScopes = []string{
|
| 31 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 32 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 33 |
+
"https://www.googleapis.com/auth/userinfo.profile",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
type apiCallRequest struct {
|
| 37 |
+
AuthIndexSnake *string `json:"auth_index"`
|
| 38 |
+
AuthIndexCamel *string `json:"authIndex"`
|
| 39 |
+
AuthIndexPascal *string `json:"AuthIndex"`
|
| 40 |
+
Method string `json:"method"`
|
| 41 |
+
URL string `json:"url"`
|
| 42 |
+
Header map[string]string `json:"header"`
|
| 43 |
+
Data string `json:"data"`
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
type apiCallResponse struct {
|
| 47 |
+
StatusCode int `json:"status_code"`
|
| 48 |
+
Header map[string][]string `json:"header"`
|
| 49 |
+
Body string `json:"body"`
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// APICall makes a generic HTTP request on behalf of the management API caller.
|
| 53 |
+
// It is protected by the management middleware.
|
| 54 |
+
//
|
| 55 |
+
// Endpoint:
|
| 56 |
+
//
|
| 57 |
+
// POST /v0/management/api-call
|
| 58 |
+
//
|
| 59 |
+
// Authentication:
|
| 60 |
+
//
|
| 61 |
+
// Same as other management APIs (requires a management key and remote-management rules).
|
| 62 |
+
// You can provide the key via:
|
| 63 |
+
// - Authorization: Bearer <key>
|
| 64 |
+
// - X-Management-Key: <key>
|
| 65 |
+
//
|
| 66 |
+
// Request JSON:
|
| 67 |
+
// - auth_index / authIndex / AuthIndex (optional):
|
| 68 |
+
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
| 69 |
+
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
| 70 |
+
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
|
| 71 |
+
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
|
| 72 |
+
// - header (optional): Request headers map.
|
| 73 |
+
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
|
| 74 |
+
// 1) metadata.access_token
|
| 75 |
+
// 2) attributes.api_key
|
| 76 |
+
// 3) metadata.token / metadata.id_token / metadata.cookie
|
| 77 |
+
// Example: {"Authorization":"Bearer $TOKEN$"}.
|
| 78 |
+
// Note: if you need to override the HTTP Host header, set header["Host"].
|
| 79 |
+
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
|
| 80 |
+
//
|
| 81 |
+
// Proxy selection (highest priority first):
|
| 82 |
+
// 1. Selected credential proxy_url
|
| 83 |
+
// 2. Global config proxy-url
|
| 84 |
+
// 3. Direct connect (environment proxies are not used)
|
| 85 |
+
//
|
| 86 |
+
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
| 87 |
+
// - status_code: Upstream HTTP status code.
|
| 88 |
+
// - header: Upstream response headers.
|
| 89 |
+
// - body: Upstream response body as string.
|
| 90 |
+
//
|
| 91 |
+
// Example:
|
| 92 |
+
//
|
| 93 |
+
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
| 94 |
+
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
|
| 95 |
+
// -H "Content-Type: application/json" \
|
| 96 |
+
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
|
| 97 |
+
//
|
| 98 |
+
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
|
| 99 |
+
// -H "Authorization: Bearer 831227" \
|
| 100 |
+
// -H "Content-Type: application/json" \
|
| 101 |
+
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
| 102 |
+
func (h *Handler) APICall(c *gin.Context) {
|
| 103 |
+
var body apiCallRequest
|
| 104 |
+
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
| 105 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
| 106 |
+
return
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
| 110 |
+
if method == "" {
|
| 111 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
|
| 112 |
+
return
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
urlStr := strings.TrimSpace(body.URL)
|
| 116 |
+
if urlStr == "" {
|
| 117 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
|
| 118 |
+
return
|
| 119 |
+
}
|
| 120 |
+
parsedURL, errParseURL := url.Parse(urlStr)
|
| 121 |
+
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
| 122 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
| 123 |
+
return
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
|
| 127 |
+
auth := h.authByIndex(authIndex)
|
| 128 |
+
|
| 129 |
+
reqHeaders := body.Header
|
| 130 |
+
if reqHeaders == nil {
|
| 131 |
+
reqHeaders = map[string]string{}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
var hostOverride string
|
| 135 |
+
var token string
|
| 136 |
+
var tokenResolved bool
|
| 137 |
+
var tokenErr error
|
| 138 |
+
for key, value := range reqHeaders {
|
| 139 |
+
if !strings.Contains(value, "$TOKEN$") {
|
| 140 |
+
continue
|
| 141 |
+
}
|
| 142 |
+
if !tokenResolved {
|
| 143 |
+
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
|
| 144 |
+
tokenResolved = true
|
| 145 |
+
}
|
| 146 |
+
if auth != nil && token == "" {
|
| 147 |
+
if tokenErr != nil {
|
| 148 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
|
| 149 |
+
return
|
| 150 |
+
}
|
| 151 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
|
| 152 |
+
return
|
| 153 |
+
}
|
| 154 |
+
if token == "" {
|
| 155 |
+
continue
|
| 156 |
+
}
|
| 157 |
+
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
var requestBody io.Reader
|
| 161 |
+
if body.Data != "" {
|
| 162 |
+
requestBody = strings.NewReader(body.Data)
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
| 166 |
+
if errNewRequest != nil {
|
| 167 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
|
| 168 |
+
return
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
for key, value := range reqHeaders {
|
| 172 |
+
if strings.EqualFold(key, "host") {
|
| 173 |
+
hostOverride = strings.TrimSpace(value)
|
| 174 |
+
continue
|
| 175 |
+
}
|
| 176 |
+
req.Header.Set(key, value)
|
| 177 |
+
}
|
| 178 |
+
if hostOverride != "" {
|
| 179 |
+
req.Host = hostOverride
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
httpClient := &http.Client{
|
| 183 |
+
Timeout: defaultAPICallTimeout,
|
| 184 |
+
}
|
| 185 |
+
httpClient.Transport = h.apiCallTransport(auth)
|
| 186 |
+
|
| 187 |
+
resp, errDo := httpClient.Do(req)
|
| 188 |
+
if errDo != nil {
|
| 189 |
+
log.WithError(errDo).Debug("management APICall request failed")
|
| 190 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
|
| 191 |
+
return
|
| 192 |
+
}
|
| 193 |
+
defer func() {
|
| 194 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 195 |
+
log.Errorf("response body close error: %v", errClose)
|
| 196 |
+
}
|
| 197 |
+
}()
|
| 198 |
+
|
| 199 |
+
respBody, errReadAll := io.ReadAll(resp.Body)
|
| 200 |
+
if errReadAll != nil {
|
| 201 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
|
| 202 |
+
return
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
c.JSON(http.StatusOK, apiCallResponse{
|
| 206 |
+
StatusCode: resp.StatusCode,
|
| 207 |
+
Header: resp.Header,
|
| 208 |
+
Body: string(respBody),
|
| 209 |
+
})
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
func firstNonEmptyString(values ...*string) string {
|
| 213 |
+
for _, v := range values {
|
| 214 |
+
if v == nil {
|
| 215 |
+
continue
|
| 216 |
+
}
|
| 217 |
+
if out := strings.TrimSpace(*v); out != "" {
|
| 218 |
+
return out
|
| 219 |
+
}
|
| 220 |
+
}
|
| 221 |
+
return ""
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
func tokenValueForAuth(auth *coreauth.Auth) string {
|
| 225 |
+
if auth == nil {
|
| 226 |
+
return ""
|
| 227 |
+
}
|
| 228 |
+
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
|
| 229 |
+
return v
|
| 230 |
+
}
|
| 231 |
+
if auth.Attributes != nil {
|
| 232 |
+
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
| 233 |
+
return v
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
| 237 |
+
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
|
| 238 |
+
return v
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
return ""
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
| 245 |
+
if auth == nil {
|
| 246 |
+
return "", nil
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
| 250 |
+
if provider == "gemini-cli" {
|
| 251 |
+
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
|
| 252 |
+
return token, errToken
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
return tokenValueForAuth(auth), nil
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
| 259 |
+
if ctx == nil {
|
| 260 |
+
ctx = context.Background()
|
| 261 |
+
}
|
| 262 |
+
if auth == nil {
|
| 263 |
+
return "", nil
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
metadata, updater := geminiOAuthMetadata(auth)
|
| 267 |
+
if len(metadata) == 0 {
|
| 268 |
+
return "", fmt.Errorf("gemini oauth metadata missing")
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
base := make(map[string]any)
|
| 272 |
+
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
| 273 |
+
base = cloneMap(tokenRaw)
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
var token oauth2.Token
|
| 277 |
+
if len(base) > 0 {
|
| 278 |
+
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
|
| 279 |
+
_ = json.Unmarshal(raw, &token)
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
if token.AccessToken == "" {
|
| 284 |
+
token.AccessToken = stringValue(metadata, "access_token")
|
| 285 |
+
}
|
| 286 |
+
if token.RefreshToken == "" {
|
| 287 |
+
token.RefreshToken = stringValue(metadata, "refresh_token")
|
| 288 |
+
}
|
| 289 |
+
if token.TokenType == "" {
|
| 290 |
+
token.TokenType = stringValue(metadata, "token_type")
|
| 291 |
+
}
|
| 292 |
+
if token.Expiry.IsZero() {
|
| 293 |
+
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
| 294 |
+
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
|
| 295 |
+
token.Expiry = ts
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
conf := &oauth2.Config{
|
| 301 |
+
ClientID: geminiOAuthClientID,
|
| 302 |
+
ClientSecret: geminiOAuthClientSecret,
|
| 303 |
+
Scopes: geminiOAuthScopes,
|
| 304 |
+
Endpoint: google.Endpoint,
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
ctxToken := ctx
|
| 308 |
+
httpClient := &http.Client{
|
| 309 |
+
Timeout: defaultAPICallTimeout,
|
| 310 |
+
Transport: h.apiCallTransport(auth),
|
| 311 |
+
}
|
| 312 |
+
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
|
| 313 |
+
|
| 314 |
+
src := conf.TokenSource(ctxToken, &token)
|
| 315 |
+
currentToken, errToken := src.Token()
|
| 316 |
+
if errToken != nil {
|
| 317 |
+
return "", errToken
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
merged := buildOAuthTokenMap(base, currentToken)
|
| 321 |
+
fields := buildOAuthTokenFields(currentToken, merged)
|
| 322 |
+
if updater != nil {
|
| 323 |
+
updater(fields)
|
| 324 |
+
}
|
| 325 |
+
return strings.TrimSpace(currentToken.AccessToken), nil
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
|
| 329 |
+
if auth == nil {
|
| 330 |
+
return nil, nil
|
| 331 |
+
}
|
| 332 |
+
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
| 333 |
+
snapshot := shared.MetadataSnapshot()
|
| 334 |
+
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
|
| 335 |
+
}
|
| 336 |
+
return auth.Metadata, func(fields map[string]any) {
|
| 337 |
+
if auth.Metadata == nil {
|
| 338 |
+
auth.Metadata = make(map[string]any)
|
| 339 |
+
}
|
| 340 |
+
for k, v := range fields {
|
| 341 |
+
auth.Metadata[k] = v
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
func stringValue(metadata map[string]any, key string) string {
|
| 347 |
+
if len(metadata) == 0 || key == "" {
|
| 348 |
+
return ""
|
| 349 |
+
}
|
| 350 |
+
if v, ok := metadata[key].(string); ok {
|
| 351 |
+
return strings.TrimSpace(v)
|
| 352 |
+
}
|
| 353 |
+
return ""
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
func cloneMap(in map[string]any) map[string]any {
|
| 357 |
+
if len(in) == 0 {
|
| 358 |
+
return nil
|
| 359 |
+
}
|
| 360 |
+
out := make(map[string]any, len(in))
|
| 361 |
+
for k, v := range in {
|
| 362 |
+
out[k] = v
|
| 363 |
+
}
|
| 364 |
+
return out
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
| 368 |
+
merged := cloneMap(base)
|
| 369 |
+
if merged == nil {
|
| 370 |
+
merged = make(map[string]any)
|
| 371 |
+
}
|
| 372 |
+
if tok == nil {
|
| 373 |
+
return merged
|
| 374 |
+
}
|
| 375 |
+
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
|
| 376 |
+
var tokenMap map[string]any
|
| 377 |
+
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
|
| 378 |
+
for k, v := range tokenMap {
|
| 379 |
+
merged[k] = v
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
return merged
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
| 387 |
+
fields := make(map[string]any, 5)
|
| 388 |
+
if tok != nil && tok.AccessToken != "" {
|
| 389 |
+
fields["access_token"] = tok.AccessToken
|
| 390 |
+
}
|
| 391 |
+
if tok != nil && tok.TokenType != "" {
|
| 392 |
+
fields["token_type"] = tok.TokenType
|
| 393 |
+
}
|
| 394 |
+
if tok != nil && tok.RefreshToken != "" {
|
| 395 |
+
fields["refresh_token"] = tok.RefreshToken
|
| 396 |
+
}
|
| 397 |
+
if tok != nil && !tok.Expiry.IsZero() {
|
| 398 |
+
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
| 399 |
+
}
|
| 400 |
+
if len(merged) > 0 {
|
| 401 |
+
fields["token"] = cloneMap(merged)
|
| 402 |
+
}
|
| 403 |
+
return fields
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
func tokenValueFromMetadata(metadata map[string]any) string {
|
| 407 |
+
if len(metadata) == 0 {
|
| 408 |
+
return ""
|
| 409 |
+
}
|
| 410 |
+
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
| 411 |
+
return strings.TrimSpace(v)
|
| 412 |
+
}
|
| 413 |
+
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
| 414 |
+
return strings.TrimSpace(v)
|
| 415 |
+
}
|
| 416 |
+
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
|
| 417 |
+
switch typed := tokenRaw.(type) {
|
| 418 |
+
case string:
|
| 419 |
+
if v := strings.TrimSpace(typed); v != "" {
|
| 420 |
+
return v
|
| 421 |
+
}
|
| 422 |
+
case map[string]any:
|
| 423 |
+
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
| 424 |
+
return strings.TrimSpace(v)
|
| 425 |
+
}
|
| 426 |
+
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
|
| 427 |
+
return strings.TrimSpace(v)
|
| 428 |
+
}
|
| 429 |
+
case map[string]string:
|
| 430 |
+
if v := strings.TrimSpace(typed["access_token"]); v != "" {
|
| 431 |
+
return v
|
| 432 |
+
}
|
| 433 |
+
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
|
| 434 |
+
return v
|
| 435 |
+
}
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
|
| 439 |
+
return strings.TrimSpace(v)
|
| 440 |
+
}
|
| 441 |
+
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
|
| 442 |
+
return strings.TrimSpace(v)
|
| 443 |
+
}
|
| 444 |
+
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
|
| 445 |
+
return strings.TrimSpace(v)
|
| 446 |
+
}
|
| 447 |
+
return ""
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
|
| 451 |
+
authIndex = strings.TrimSpace(authIndex)
|
| 452 |
+
if authIndex == "" || h == nil || h.authManager == nil {
|
| 453 |
+
return nil
|
| 454 |
+
}
|
| 455 |
+
auths := h.authManager.List()
|
| 456 |
+
for _, auth := range auths {
|
| 457 |
+
if auth == nil {
|
| 458 |
+
continue
|
| 459 |
+
}
|
| 460 |
+
auth.EnsureIndex()
|
| 461 |
+
if auth.Index == authIndex {
|
| 462 |
+
return auth
|
| 463 |
+
}
|
| 464 |
+
}
|
| 465 |
+
return nil
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
| 469 |
+
var proxyCandidates []string
|
| 470 |
+
if auth != nil {
|
| 471 |
+
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
|
| 472 |
+
proxyCandidates = append(proxyCandidates, proxyStr)
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
if h != nil && h.cfg != nil {
|
| 476 |
+
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
|
| 477 |
+
proxyCandidates = append(proxyCandidates, proxyStr)
|
| 478 |
+
}
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
for _, proxyStr := range proxyCandidates {
|
| 482 |
+
if transport := buildProxyTransport(proxyStr); transport != nil {
|
| 483 |
+
return transport
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
transport, ok := http.DefaultTransport.(*http.Transport)
|
| 488 |
+
if !ok || transport == nil {
|
| 489 |
+
return &http.Transport{Proxy: nil}
|
| 490 |
+
}
|
| 491 |
+
clone := transport.Clone()
|
| 492 |
+
clone.Proxy = nil
|
| 493 |
+
return clone
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
func buildProxyTransport(proxyStr string) *http.Transport {
|
| 497 |
+
proxyStr = strings.TrimSpace(proxyStr)
|
| 498 |
+
if proxyStr == "" {
|
| 499 |
+
return nil
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
proxyURL, errParse := url.Parse(proxyStr)
|
| 503 |
+
if errParse != nil {
|
| 504 |
+
log.WithError(errParse).Debug("parse proxy URL failed")
|
| 505 |
+
return nil
|
| 506 |
+
}
|
| 507 |
+
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
| 508 |
+
log.Debug("proxy URL missing scheme/host")
|
| 509 |
+
return nil
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
if proxyURL.Scheme == "socks5" {
|
| 513 |
+
var proxyAuth *proxy.Auth
|
| 514 |
+
if proxyURL.User != nil {
|
| 515 |
+
username := proxyURL.User.Username()
|
| 516 |
+
password, _ := proxyURL.User.Password()
|
| 517 |
+
proxyAuth = &proxy.Auth{User: username, Password: password}
|
| 518 |
+
}
|
| 519 |
+
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
| 520 |
+
if errSOCKS5 != nil {
|
| 521 |
+
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
| 522 |
+
return nil
|
| 523 |
+
}
|
| 524 |
+
return &http.Transport{
|
| 525 |
+
Proxy: nil,
|
| 526 |
+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
| 527 |
+
return dialer.Dial(network, addr)
|
| 528 |
+
},
|
| 529 |
+
}
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
| 533 |
+
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
| 537 |
+
return nil
|
| 538 |
+
}
|
internal/api/handlers/management/auth_files.go
ADDED
|
@@ -0,0 +1,2606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"context"
|
| 6 |
+
"crypto/rand"
|
| 7 |
+
"crypto/sha256"
|
| 8 |
+
"encoding/base64"
|
| 9 |
+
"encoding/json"
|
| 10 |
+
"errors"
|
| 11 |
+
"fmt"
|
| 12 |
+
"io"
|
| 13 |
+
"net"
|
| 14 |
+
"net/http"
|
| 15 |
+
"net/url"
|
| 16 |
+
"os"
|
| 17 |
+
"path/filepath"
|
| 18 |
+
"sort"
|
| 19 |
+
"strconv"
|
| 20 |
+
"strings"
|
| 21 |
+
"sync"
|
| 22 |
+
"time"
|
| 23 |
+
|
| 24 |
+
"github.com/gin-gonic/gin"
|
| 25 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
| 26 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
| 27 |
+
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
| 28 |
+
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
| 29 |
+
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
| 30 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
| 31 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
| 32 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
| 33 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
| 34 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 35 |
+
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
| 36 |
+
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 37 |
+
log "github.com/sirupsen/logrus"
|
| 38 |
+
"github.com/tidwall/gjson"
|
| 39 |
+
"golang.org/x/oauth2"
|
| 40 |
+
"golang.org/x/oauth2/google"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
| 44 |
+
|
| 45 |
+
const (
|
| 46 |
+
anthropicCallbackPort = 54545
|
| 47 |
+
geminiCallbackPort = 8085
|
| 48 |
+
codexCallbackPort = 1455
|
| 49 |
+
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
|
| 50 |
+
geminiCLIVersion = "v1internal"
|
| 51 |
+
geminiCLIUserAgent = "google-api-nodejs-client/9.15.1"
|
| 52 |
+
geminiCLIApiClient = "gl-node/22.17.0"
|
| 53 |
+
geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
type callbackForwarder struct {
|
| 57 |
+
provider string
|
| 58 |
+
server *http.Server
|
| 59 |
+
done chan struct{}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
var (
|
| 63 |
+
callbackForwardersMu sync.Mutex
|
| 64 |
+
callbackForwarders = make(map[int]*callbackForwarder)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
|
| 68 |
+
if len(meta) == 0 {
|
| 69 |
+
return time.Time{}, false
|
| 70 |
+
}
|
| 71 |
+
for _, key := range lastRefreshKeys {
|
| 72 |
+
if val, ok := meta[key]; ok {
|
| 73 |
+
if ts, ok1 := parseLastRefreshValue(val); ok1 {
|
| 74 |
+
return ts, true
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
return time.Time{}, false
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
func parseLastRefreshValue(v any) (time.Time, bool) {
|
| 82 |
+
switch val := v.(type) {
|
| 83 |
+
case string:
|
| 84 |
+
s := strings.TrimSpace(val)
|
| 85 |
+
if s == "" {
|
| 86 |
+
return time.Time{}, false
|
| 87 |
+
}
|
| 88 |
+
layouts := []string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z07:00"}
|
| 89 |
+
for _, layout := range layouts {
|
| 90 |
+
if ts, err := time.Parse(layout, s); err == nil {
|
| 91 |
+
return ts.UTC(), true
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
if unix, err := strconv.ParseInt(s, 10, 64); err == nil && unix > 0 {
|
| 95 |
+
return time.Unix(unix, 0).UTC(), true
|
| 96 |
+
}
|
| 97 |
+
case float64:
|
| 98 |
+
if val <= 0 {
|
| 99 |
+
return time.Time{}, false
|
| 100 |
+
}
|
| 101 |
+
return time.Unix(int64(val), 0).UTC(), true
|
| 102 |
+
case int64:
|
| 103 |
+
if val <= 0 {
|
| 104 |
+
return time.Time{}, false
|
| 105 |
+
}
|
| 106 |
+
return time.Unix(val, 0).UTC(), true
|
| 107 |
+
case int:
|
| 108 |
+
if val <= 0 {
|
| 109 |
+
return time.Time{}, false
|
| 110 |
+
}
|
| 111 |
+
return time.Unix(int64(val), 0).UTC(), true
|
| 112 |
+
case json.Number:
|
| 113 |
+
if i, err := val.Int64(); err == nil && i > 0 {
|
| 114 |
+
return time.Unix(i, 0).UTC(), true
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
return time.Time{}, false
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
func isWebUIRequest(c *gin.Context) bool {
|
| 121 |
+
raw := strings.TrimSpace(c.Query("is_webui"))
|
| 122 |
+
if raw == "" {
|
| 123 |
+
return false
|
| 124 |
+
}
|
| 125 |
+
switch strings.ToLower(raw) {
|
| 126 |
+
case "1", "true", "yes", "on":
|
| 127 |
+
return true
|
| 128 |
+
default:
|
| 129 |
+
return false
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
|
| 134 |
+
callbackForwardersMu.Lock()
|
| 135 |
+
prev := callbackForwarders[port]
|
| 136 |
+
if prev != nil {
|
| 137 |
+
delete(callbackForwarders, port)
|
| 138 |
+
}
|
| 139 |
+
callbackForwardersMu.Unlock()
|
| 140 |
+
|
| 141 |
+
if prev != nil {
|
| 142 |
+
stopForwarderInstance(port, prev)
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
| 146 |
+
ln, err := net.Listen("tcp", addr)
|
| 147 |
+
if err != nil {
|
| 148 |
+
return nil, fmt.Errorf("failed to listen on %s: %w", addr, err)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 152 |
+
target := targetBase
|
| 153 |
+
if raw := r.URL.RawQuery; raw != "" {
|
| 154 |
+
if strings.Contains(target, "?") {
|
| 155 |
+
target = target + "&" + raw
|
| 156 |
+
} else {
|
| 157 |
+
target = target + "?" + raw
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
w.Header().Set("Cache-Control", "no-store")
|
| 161 |
+
http.Redirect(w, r, target, http.StatusFound)
|
| 162 |
+
})
|
| 163 |
+
|
| 164 |
+
srv := &http.Server{
|
| 165 |
+
Handler: handler,
|
| 166 |
+
ReadHeaderTimeout: 5 * time.Second,
|
| 167 |
+
WriteTimeout: 5 * time.Second,
|
| 168 |
+
}
|
| 169 |
+
done := make(chan struct{})
|
| 170 |
+
|
| 171 |
+
go func() {
|
| 172 |
+
if errServe := srv.Serve(ln); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
| 173 |
+
log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
|
| 174 |
+
}
|
| 175 |
+
close(done)
|
| 176 |
+
}()
|
| 177 |
+
|
| 178 |
+
forwarder := &callbackForwarder{
|
| 179 |
+
provider: provider,
|
| 180 |
+
server: srv,
|
| 181 |
+
done: done,
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
callbackForwardersMu.Lock()
|
| 185 |
+
callbackForwarders[port] = forwarder
|
| 186 |
+
callbackForwardersMu.Unlock()
|
| 187 |
+
|
| 188 |
+
log.Infof("callback forwarder for %s listening on %s", provider, addr)
|
| 189 |
+
|
| 190 |
+
return forwarder, nil
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
func stopCallbackForwarder(port int) {
|
| 194 |
+
callbackForwardersMu.Lock()
|
| 195 |
+
forwarder := callbackForwarders[port]
|
| 196 |
+
if forwarder != nil {
|
| 197 |
+
delete(callbackForwarders, port)
|
| 198 |
+
}
|
| 199 |
+
callbackForwardersMu.Unlock()
|
| 200 |
+
|
| 201 |
+
stopForwarderInstance(port, forwarder)
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
| 205 |
+
if forwarder == nil {
|
| 206 |
+
return
|
| 207 |
+
}
|
| 208 |
+
callbackForwardersMu.Lock()
|
| 209 |
+
if current := callbackForwarders[port]; current == forwarder {
|
| 210 |
+
delete(callbackForwarders, port)
|
| 211 |
+
}
|
| 212 |
+
callbackForwardersMu.Unlock()
|
| 213 |
+
|
| 214 |
+
stopForwarderInstance(port, forwarder)
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
| 218 |
+
if forwarder == nil || forwarder.server == nil {
|
| 219 |
+
return
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
| 223 |
+
defer cancel()
|
| 224 |
+
|
| 225 |
+
if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
| 226 |
+
log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
select {
|
| 230 |
+
case <-forwarder.done:
|
| 231 |
+
case <-time.After(2 * time.Second):
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
log.Infof("callback forwarder on port %d stopped", port)
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
func sanitizeAntigravityFileName(email string) string {
|
| 238 |
+
if strings.TrimSpace(email) == "" {
|
| 239 |
+
return "antigravity.json"
|
| 240 |
+
}
|
| 241 |
+
replacer := strings.NewReplacer("@", "_", ".", "_")
|
| 242 |
+
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
| 246 |
+
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
| 247 |
+
return "", fmt.Errorf("server port is not configured")
|
| 248 |
+
}
|
| 249 |
+
if !strings.HasPrefix(path, "/") {
|
| 250 |
+
path = "/" + path
|
| 251 |
+
}
|
| 252 |
+
scheme := "http"
|
| 253 |
+
if h.cfg.TLS.Enable {
|
| 254 |
+
scheme = "https"
|
| 255 |
+
}
|
| 256 |
+
return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
| 260 |
+
if h == nil {
|
| 261 |
+
c.JSON(500, gin.H{"error": "handler not initialized"})
|
| 262 |
+
return
|
| 263 |
+
}
|
| 264 |
+
if h.authManager == nil {
|
| 265 |
+
h.listAuthFilesFromDisk(c)
|
| 266 |
+
return
|
| 267 |
+
}
|
| 268 |
+
auths := h.authManager.List()
|
| 269 |
+
files := make([]gin.H, 0, len(auths))
|
| 270 |
+
for _, auth := range auths {
|
| 271 |
+
if entry := h.buildAuthFileEntry(auth); entry != nil {
|
| 272 |
+
files = append(files, entry)
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
sort.Slice(files, func(i, j int) bool {
|
| 276 |
+
nameI, _ := files[i]["name"].(string)
|
| 277 |
+
nameJ, _ := files[j]["name"].(string)
|
| 278 |
+
return strings.ToLower(nameI) < strings.ToLower(nameJ)
|
| 279 |
+
})
|
| 280 |
+
c.JSON(200, gin.H{"files": files})
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// GetAuthFileModels returns the models supported by a specific auth file
|
| 284 |
+
func (h *Handler) GetAuthFileModels(c *gin.Context) {
|
| 285 |
+
name := c.Query("name")
|
| 286 |
+
if name == "" {
|
| 287 |
+
c.JSON(400, gin.H{"error": "name is required"})
|
| 288 |
+
return
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
// Try to find auth ID via authManager
|
| 292 |
+
var authID string
|
| 293 |
+
if h.authManager != nil {
|
| 294 |
+
auths := h.authManager.List()
|
| 295 |
+
for _, auth := range auths {
|
| 296 |
+
if auth.FileName == name || auth.ID == name {
|
| 297 |
+
authID = auth.ID
|
| 298 |
+
break
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
if authID == "" {
|
| 304 |
+
authID = name // fallback to filename as ID
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
// Get models from registry
|
| 308 |
+
reg := registry.GetGlobalRegistry()
|
| 309 |
+
models := reg.GetModelsForClient(authID)
|
| 310 |
+
|
| 311 |
+
result := make([]gin.H, 0, len(models))
|
| 312 |
+
for _, m := range models {
|
| 313 |
+
entry := gin.H{
|
| 314 |
+
"id": m.ID,
|
| 315 |
+
}
|
| 316 |
+
if m.DisplayName != "" {
|
| 317 |
+
entry["display_name"] = m.DisplayName
|
| 318 |
+
}
|
| 319 |
+
if m.Type != "" {
|
| 320 |
+
entry["type"] = m.Type
|
| 321 |
+
}
|
| 322 |
+
if m.OwnedBy != "" {
|
| 323 |
+
entry["owned_by"] = m.OwnedBy
|
| 324 |
+
}
|
| 325 |
+
result = append(result, entry)
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
c.JSON(200, gin.H{"models": result})
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// List auth files from disk when the auth manager is unavailable.
|
| 332 |
+
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
| 333 |
+
entries, err := os.ReadDir(h.cfg.AuthDir)
|
| 334 |
+
if err != nil {
|
| 335 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
|
| 336 |
+
return
|
| 337 |
+
}
|
| 338 |
+
files := make([]gin.H, 0)
|
| 339 |
+
for _, e := range entries {
|
| 340 |
+
if e.IsDir() {
|
| 341 |
+
continue
|
| 342 |
+
}
|
| 343 |
+
name := e.Name()
|
| 344 |
+
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
| 345 |
+
continue
|
| 346 |
+
}
|
| 347 |
+
if info, errInfo := e.Info(); errInfo == nil {
|
| 348 |
+
fileData := gin.H{"name": name, "size": info.Size(), "modtime": info.ModTime()}
|
| 349 |
+
|
| 350 |
+
// Read file to get type field
|
| 351 |
+
full := filepath.Join(h.cfg.AuthDir, name)
|
| 352 |
+
if data, errRead := os.ReadFile(full); errRead == nil {
|
| 353 |
+
typeValue := gjson.GetBytes(data, "type").String()
|
| 354 |
+
emailValue := gjson.GetBytes(data, "email").String()
|
| 355 |
+
fileData["type"] = typeValue
|
| 356 |
+
fileData["email"] = emailValue
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
files = append(files, fileData)
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
c.JSON(200, gin.H{"files": files})
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
| 366 |
+
if auth == nil {
|
| 367 |
+
return nil
|
| 368 |
+
}
|
| 369 |
+
auth.EnsureIndex()
|
| 370 |
+
runtimeOnly := isRuntimeOnlyAuth(auth)
|
| 371 |
+
if runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled) {
|
| 372 |
+
return nil
|
| 373 |
+
}
|
| 374 |
+
path := strings.TrimSpace(authAttribute(auth, "path"))
|
| 375 |
+
if path == "" && !runtimeOnly {
|
| 376 |
+
return nil
|
| 377 |
+
}
|
| 378 |
+
name := strings.TrimSpace(auth.FileName)
|
| 379 |
+
if name == "" {
|
| 380 |
+
name = auth.ID
|
| 381 |
+
}
|
| 382 |
+
entry := gin.H{
|
| 383 |
+
"id": auth.ID,
|
| 384 |
+
"auth_index": auth.Index,
|
| 385 |
+
"name": name,
|
| 386 |
+
"type": strings.TrimSpace(auth.Provider),
|
| 387 |
+
"provider": strings.TrimSpace(auth.Provider),
|
| 388 |
+
"label": auth.Label,
|
| 389 |
+
"status": auth.Status,
|
| 390 |
+
"status_message": auth.StatusMessage,
|
| 391 |
+
"disabled": auth.Disabled,
|
| 392 |
+
"unavailable": auth.Unavailable,
|
| 393 |
+
"runtime_only": runtimeOnly,
|
| 394 |
+
"source": "memory",
|
| 395 |
+
"size": int64(0),
|
| 396 |
+
}
|
| 397 |
+
if email := authEmail(auth); email != "" {
|
| 398 |
+
entry["email"] = email
|
| 399 |
+
}
|
| 400 |
+
if accountType, account := auth.AccountInfo(); accountType != "" || account != "" {
|
| 401 |
+
if accountType != "" {
|
| 402 |
+
entry["account_type"] = accountType
|
| 403 |
+
}
|
| 404 |
+
if account != "" {
|
| 405 |
+
entry["account"] = account
|
| 406 |
+
}
|
| 407 |
+
}
|
| 408 |
+
if !auth.CreatedAt.IsZero() {
|
| 409 |
+
entry["created_at"] = auth.CreatedAt
|
| 410 |
+
}
|
| 411 |
+
if !auth.UpdatedAt.IsZero() {
|
| 412 |
+
entry["modtime"] = auth.UpdatedAt
|
| 413 |
+
entry["updated_at"] = auth.UpdatedAt
|
| 414 |
+
}
|
| 415 |
+
if !auth.LastRefreshedAt.IsZero() {
|
| 416 |
+
entry["last_refresh"] = auth.LastRefreshedAt
|
| 417 |
+
}
|
| 418 |
+
if path != "" {
|
| 419 |
+
entry["path"] = path
|
| 420 |
+
entry["source"] = "file"
|
| 421 |
+
if info, err := os.Stat(path); err == nil {
|
| 422 |
+
entry["size"] = info.Size()
|
| 423 |
+
entry["modtime"] = info.ModTime()
|
| 424 |
+
} else if os.IsNotExist(err) {
|
| 425 |
+
// Hide credentials removed from disk but still lingering in memory.
|
| 426 |
+
if !runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled || strings.EqualFold(strings.TrimSpace(auth.StatusMessage), "removed via management api")) {
|
| 427 |
+
return nil
|
| 428 |
+
}
|
| 429 |
+
entry["source"] = "memory"
|
| 430 |
+
} else {
|
| 431 |
+
log.WithError(err).Warnf("failed to stat auth file %s", path)
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
if claims := extractCodexIDTokenClaims(auth); claims != nil {
|
| 435 |
+
entry["id_token"] = claims
|
| 436 |
+
}
|
| 437 |
+
return entry
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
|
| 441 |
+
if auth == nil || auth.Metadata == nil {
|
| 442 |
+
return nil
|
| 443 |
+
}
|
| 444 |
+
if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
| 445 |
+
return nil
|
| 446 |
+
}
|
| 447 |
+
idTokenRaw, ok := auth.Metadata["id_token"].(string)
|
| 448 |
+
if !ok {
|
| 449 |
+
return nil
|
| 450 |
+
}
|
| 451 |
+
idToken := strings.TrimSpace(idTokenRaw)
|
| 452 |
+
if idToken == "" {
|
| 453 |
+
return nil
|
| 454 |
+
}
|
| 455 |
+
claims, err := codex.ParseJWTToken(idToken)
|
| 456 |
+
if err != nil || claims == nil {
|
| 457 |
+
return nil
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
result := gin.H{}
|
| 461 |
+
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" {
|
| 462 |
+
result["chatgpt_account_id"] = v
|
| 463 |
+
}
|
| 464 |
+
if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" {
|
| 465 |
+
result["plan_type"] = v
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
if len(result) == 0 {
|
| 469 |
+
return nil
|
| 470 |
+
}
|
| 471 |
+
return result
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
func authEmail(auth *coreauth.Auth) string {
|
| 475 |
+
if auth == nil {
|
| 476 |
+
return ""
|
| 477 |
+
}
|
| 478 |
+
if auth.Metadata != nil {
|
| 479 |
+
if v, ok := auth.Metadata["email"].(string); ok {
|
| 480 |
+
return strings.TrimSpace(v)
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
if auth.Attributes != nil {
|
| 484 |
+
if v := strings.TrimSpace(auth.Attributes["email"]); v != "" {
|
| 485 |
+
return v
|
| 486 |
+
}
|
| 487 |
+
if v := strings.TrimSpace(auth.Attributes["account_email"]); v != "" {
|
| 488 |
+
return v
|
| 489 |
+
}
|
| 490 |
+
}
|
| 491 |
+
return ""
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
func authAttribute(auth *coreauth.Auth, key string) string {
|
| 495 |
+
if auth == nil || len(auth.Attributes) == 0 {
|
| 496 |
+
return ""
|
| 497 |
+
}
|
| 498 |
+
return auth.Attributes[key]
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
|
| 502 |
+
if auth == nil || len(auth.Attributes) == 0 {
|
| 503 |
+
return false
|
| 504 |
+
}
|
| 505 |
+
return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
// Download single auth file by name
|
| 509 |
+
func (h *Handler) DownloadAuthFile(c *gin.Context) {
|
| 510 |
+
name := c.Query("name")
|
| 511 |
+
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
| 512 |
+
c.JSON(400, gin.H{"error": "invalid name"})
|
| 513 |
+
return
|
| 514 |
+
}
|
| 515 |
+
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
| 516 |
+
c.JSON(400, gin.H{"error": "name must end with .json"})
|
| 517 |
+
return
|
| 518 |
+
}
|
| 519 |
+
full := filepath.Join(h.cfg.AuthDir, name)
|
| 520 |
+
data, err := os.ReadFile(full)
|
| 521 |
+
if err != nil {
|
| 522 |
+
if os.IsNotExist(err) {
|
| 523 |
+
c.JSON(404, gin.H{"error": "file not found"})
|
| 524 |
+
} else {
|
| 525 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
| 526 |
+
}
|
| 527 |
+
return
|
| 528 |
+
}
|
| 529 |
+
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", name))
|
| 530 |
+
c.Data(200, "application/json", data)
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
// Upload auth file: multipart or raw JSON with ?name=
|
| 534 |
+
func (h *Handler) UploadAuthFile(c *gin.Context) {
|
| 535 |
+
if h.authManager == nil {
|
| 536 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
| 537 |
+
return
|
| 538 |
+
}
|
| 539 |
+
ctx := c.Request.Context()
|
| 540 |
+
if file, err := c.FormFile("file"); err == nil && file != nil {
|
| 541 |
+
name := filepath.Base(file.Filename)
|
| 542 |
+
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
| 543 |
+
c.JSON(400, gin.H{"error": "file must be .json"})
|
| 544 |
+
return
|
| 545 |
+
}
|
| 546 |
+
dst := filepath.Join(h.cfg.AuthDir, name)
|
| 547 |
+
if !filepath.IsAbs(dst) {
|
| 548 |
+
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
| 549 |
+
dst = abs
|
| 550 |
+
}
|
| 551 |
+
}
|
| 552 |
+
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
|
| 553 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
|
| 554 |
+
return
|
| 555 |
+
}
|
| 556 |
+
data, errRead := os.ReadFile(dst)
|
| 557 |
+
if errRead != nil {
|
| 558 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
|
| 559 |
+
return
|
| 560 |
+
}
|
| 561 |
+
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
|
| 562 |
+
c.JSON(500, gin.H{"error": errReg.Error()})
|
| 563 |
+
return
|
| 564 |
+
}
|
| 565 |
+
c.JSON(200, gin.H{"status": "ok"})
|
| 566 |
+
return
|
| 567 |
+
}
|
| 568 |
+
name := c.Query("name")
|
| 569 |
+
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
| 570 |
+
c.JSON(400, gin.H{"error": "invalid name"})
|
| 571 |
+
return
|
| 572 |
+
}
|
| 573 |
+
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
| 574 |
+
c.JSON(400, gin.H{"error": "name must end with .json"})
|
| 575 |
+
return
|
| 576 |
+
}
|
| 577 |
+
data, err := io.ReadAll(c.Request.Body)
|
| 578 |
+
if err != nil {
|
| 579 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 580 |
+
return
|
| 581 |
+
}
|
| 582 |
+
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
| 583 |
+
if !filepath.IsAbs(dst) {
|
| 584 |
+
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
|
| 585 |
+
dst = abs
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
|
| 589 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
|
| 590 |
+
return
|
| 591 |
+
}
|
| 592 |
+
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
|
| 593 |
+
c.JSON(500, gin.H{"error": err.Error()})
|
| 594 |
+
return
|
| 595 |
+
}
|
| 596 |
+
c.JSON(200, gin.H{"status": "ok"})
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
// Delete auth files: single by name or all
|
| 600 |
+
func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
| 601 |
+
if h.authManager == nil {
|
| 602 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
| 603 |
+
return
|
| 604 |
+
}
|
| 605 |
+
ctx := c.Request.Context()
|
| 606 |
+
if all := c.Query("all"); all == "true" || all == "1" || all == "*" {
|
| 607 |
+
entries, err := os.ReadDir(h.cfg.AuthDir)
|
| 608 |
+
if err != nil {
|
| 609 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
|
| 610 |
+
return
|
| 611 |
+
}
|
| 612 |
+
deleted := 0
|
| 613 |
+
for _, e := range entries {
|
| 614 |
+
if e.IsDir() {
|
| 615 |
+
continue
|
| 616 |
+
}
|
| 617 |
+
name := e.Name()
|
| 618 |
+
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
| 619 |
+
continue
|
| 620 |
+
}
|
| 621 |
+
full := filepath.Join(h.cfg.AuthDir, name)
|
| 622 |
+
if !filepath.IsAbs(full) {
|
| 623 |
+
if abs, errAbs := filepath.Abs(full); errAbs == nil {
|
| 624 |
+
full = abs
|
| 625 |
+
}
|
| 626 |
+
}
|
| 627 |
+
if err = os.Remove(full); err == nil {
|
| 628 |
+
if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
|
| 629 |
+
c.JSON(500, gin.H{"error": errDel.Error()})
|
| 630 |
+
return
|
| 631 |
+
}
|
| 632 |
+
deleted++
|
| 633 |
+
h.disableAuth(ctx, full)
|
| 634 |
+
}
|
| 635 |
+
}
|
| 636 |
+
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
|
| 637 |
+
return
|
| 638 |
+
}
|
| 639 |
+
name := c.Query("name")
|
| 640 |
+
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
|
| 641 |
+
c.JSON(400, gin.H{"error": "invalid name"})
|
| 642 |
+
return
|
| 643 |
+
}
|
| 644 |
+
full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
|
| 645 |
+
if !filepath.IsAbs(full) {
|
| 646 |
+
if abs, errAbs := filepath.Abs(full); errAbs == nil {
|
| 647 |
+
full = abs
|
| 648 |
+
}
|
| 649 |
+
}
|
| 650 |
+
if err := os.Remove(full); err != nil {
|
| 651 |
+
if os.IsNotExist(err) {
|
| 652 |
+
c.JSON(404, gin.H{"error": "file not found"})
|
| 653 |
+
} else {
|
| 654 |
+
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
|
| 655 |
+
}
|
| 656 |
+
return
|
| 657 |
+
}
|
| 658 |
+
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
| 659 |
+
c.JSON(500, gin.H{"error": err.Error()})
|
| 660 |
+
return
|
| 661 |
+
}
|
| 662 |
+
h.disableAuth(ctx, full)
|
| 663 |
+
c.JSON(200, gin.H{"status": "ok"})
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
func (h *Handler) authIDForPath(path string) string {
|
| 667 |
+
path = strings.TrimSpace(path)
|
| 668 |
+
if path == "" {
|
| 669 |
+
return ""
|
| 670 |
+
}
|
| 671 |
+
if h == nil || h.cfg == nil {
|
| 672 |
+
return path
|
| 673 |
+
}
|
| 674 |
+
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
| 675 |
+
if authDir == "" {
|
| 676 |
+
return path
|
| 677 |
+
}
|
| 678 |
+
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
| 679 |
+
return rel
|
| 680 |
+
}
|
| 681 |
+
return path
|
| 682 |
+
}
|
| 683 |
+
|
| 684 |
+
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
| 685 |
+
if h.authManager == nil {
|
| 686 |
+
return nil
|
| 687 |
+
}
|
| 688 |
+
if path == "" {
|
| 689 |
+
return fmt.Errorf("auth path is empty")
|
| 690 |
+
}
|
| 691 |
+
if data == nil {
|
| 692 |
+
var err error
|
| 693 |
+
data, err = os.ReadFile(path)
|
| 694 |
+
if err != nil {
|
| 695 |
+
return fmt.Errorf("failed to read auth file: %w", err)
|
| 696 |
+
}
|
| 697 |
+
}
|
| 698 |
+
metadata := make(map[string]any)
|
| 699 |
+
if err := json.Unmarshal(data, &metadata); err != nil {
|
| 700 |
+
return fmt.Errorf("invalid auth file: %w", err)
|
| 701 |
+
}
|
| 702 |
+
provider, _ := metadata["type"].(string)
|
| 703 |
+
if provider == "" {
|
| 704 |
+
provider = "unknown"
|
| 705 |
+
}
|
| 706 |
+
label := provider
|
| 707 |
+
if email, ok := metadata["email"].(string); ok && email != "" {
|
| 708 |
+
label = email
|
| 709 |
+
}
|
| 710 |
+
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
|
| 711 |
+
|
| 712 |
+
authID := h.authIDForPath(path)
|
| 713 |
+
if authID == "" {
|
| 714 |
+
authID = path
|
| 715 |
+
}
|
| 716 |
+
attr := map[string]string{
|
| 717 |
+
"path": path,
|
| 718 |
+
"source": path,
|
| 719 |
+
}
|
| 720 |
+
auth := &coreauth.Auth{
|
| 721 |
+
ID: authID,
|
| 722 |
+
Provider: provider,
|
| 723 |
+
FileName: filepath.Base(path),
|
| 724 |
+
Label: label,
|
| 725 |
+
Status: coreauth.StatusActive,
|
| 726 |
+
Attributes: attr,
|
| 727 |
+
Metadata: metadata,
|
| 728 |
+
CreatedAt: time.Now(),
|
| 729 |
+
UpdatedAt: time.Now(),
|
| 730 |
+
}
|
| 731 |
+
if hasLastRefresh {
|
| 732 |
+
auth.LastRefreshedAt = lastRefresh
|
| 733 |
+
}
|
| 734 |
+
if existing, ok := h.authManager.GetByID(authID); ok {
|
| 735 |
+
auth.CreatedAt = existing.CreatedAt
|
| 736 |
+
if !hasLastRefresh {
|
| 737 |
+
auth.LastRefreshedAt = existing.LastRefreshedAt
|
| 738 |
+
}
|
| 739 |
+
auth.NextRefreshAfter = existing.NextRefreshAfter
|
| 740 |
+
auth.Runtime = existing.Runtime
|
| 741 |
+
_, err := h.authManager.Update(ctx, auth)
|
| 742 |
+
return err
|
| 743 |
+
}
|
| 744 |
+
_, err := h.authManager.Register(ctx, auth)
|
| 745 |
+
return err
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
| 749 |
+
if h == nil || h.authManager == nil {
|
| 750 |
+
return
|
| 751 |
+
}
|
| 752 |
+
authID := h.authIDForPath(id)
|
| 753 |
+
if authID == "" {
|
| 754 |
+
authID = strings.TrimSpace(id)
|
| 755 |
+
}
|
| 756 |
+
if authID == "" {
|
| 757 |
+
return
|
| 758 |
+
}
|
| 759 |
+
if auth, ok := h.authManager.GetByID(authID); ok {
|
| 760 |
+
auth.Disabled = true
|
| 761 |
+
auth.Status = coreauth.StatusDisabled
|
| 762 |
+
auth.StatusMessage = "removed via management API"
|
| 763 |
+
auth.UpdatedAt = time.Now()
|
| 764 |
+
_, _ = h.authManager.Update(ctx, auth)
|
| 765 |
+
}
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
|
| 769 |
+
if strings.TrimSpace(path) == "" {
|
| 770 |
+
return fmt.Errorf("auth path is empty")
|
| 771 |
+
}
|
| 772 |
+
store := h.tokenStoreWithBaseDir()
|
| 773 |
+
if store == nil {
|
| 774 |
+
return fmt.Errorf("token store unavailable")
|
| 775 |
+
}
|
| 776 |
+
return store.Delete(ctx, path)
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
|
| 780 |
+
if h == nil {
|
| 781 |
+
return nil
|
| 782 |
+
}
|
| 783 |
+
store := h.tokenStore
|
| 784 |
+
if store == nil {
|
| 785 |
+
store = sdkAuth.GetTokenStore()
|
| 786 |
+
h.tokenStore = store
|
| 787 |
+
}
|
| 788 |
+
if h.cfg != nil {
|
| 789 |
+
if dirSetter, ok := store.(interface{ SetBaseDir(string) }); ok {
|
| 790 |
+
dirSetter.SetBaseDir(h.cfg.AuthDir)
|
| 791 |
+
}
|
| 792 |
+
}
|
| 793 |
+
return store
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
|
| 797 |
+
if record == nil {
|
| 798 |
+
return "", fmt.Errorf("token record is nil")
|
| 799 |
+
}
|
| 800 |
+
store := h.tokenStoreWithBaseDir()
|
| 801 |
+
if store == nil {
|
| 802 |
+
return "", fmt.Errorf("token store unavailable")
|
| 803 |
+
}
|
| 804 |
+
return store.Save(ctx, record)
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
| 808 |
+
ctx := context.Background()
|
| 809 |
+
|
| 810 |
+
fmt.Println("Initializing Claude authentication...")
|
| 811 |
+
|
| 812 |
+
// Generate PKCE codes
|
| 813 |
+
pkceCodes, err := claude.GeneratePKCECodes()
|
| 814 |
+
if err != nil {
|
| 815 |
+
log.Errorf("Failed to generate PKCE codes: %v", err)
|
| 816 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
| 817 |
+
return
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
// Generate random state parameter
|
| 821 |
+
state, err := misc.GenerateRandomState()
|
| 822 |
+
if err != nil {
|
| 823 |
+
log.Errorf("Failed to generate state parameter: %v", err)
|
| 824 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
| 825 |
+
return
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
// Initialize Claude auth service
|
| 829 |
+
anthropicAuth := claude.NewClaudeAuth(h.cfg)
|
| 830 |
+
|
| 831 |
+
// Generate authorization URL (then override redirect_uri to reuse server port)
|
| 832 |
+
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
|
| 833 |
+
if err != nil {
|
| 834 |
+
log.Errorf("Failed to generate authorization URL: %v", err)
|
| 835 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
| 836 |
+
return
|
| 837 |
+
}
|
| 838 |
+
|
| 839 |
+
RegisterOAuthSession(state, "anthropic")
|
| 840 |
+
|
| 841 |
+
isWebUI := isWebUIRequest(c)
|
| 842 |
+
var forwarder *callbackForwarder
|
| 843 |
+
if isWebUI {
|
| 844 |
+
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
| 845 |
+
if errTarget != nil {
|
| 846 |
+
log.WithError(errTarget).Error("failed to compute anthropic callback target")
|
| 847 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
| 848 |
+
return
|
| 849 |
+
}
|
| 850 |
+
var errStart error
|
| 851 |
+
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
| 852 |
+
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
|
| 853 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
| 854 |
+
return
|
| 855 |
+
}
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
go func() {
|
| 859 |
+
if isWebUI {
|
| 860 |
+
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
// Helper: wait for callback file
|
| 864 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
|
| 865 |
+
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
| 866 |
+
deadline := time.Now().Add(timeout)
|
| 867 |
+
for {
|
| 868 |
+
if !IsOAuthSessionPending(state, "anthropic") {
|
| 869 |
+
return nil, errOAuthSessionNotPending
|
| 870 |
+
}
|
| 871 |
+
if time.Now().After(deadline) {
|
| 872 |
+
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
| 873 |
+
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
| 874 |
+
}
|
| 875 |
+
data, errRead := os.ReadFile(path)
|
| 876 |
+
if errRead == nil {
|
| 877 |
+
var m map[string]string
|
| 878 |
+
_ = json.Unmarshal(data, &m)
|
| 879 |
+
_ = os.Remove(path)
|
| 880 |
+
return m, nil
|
| 881 |
+
}
|
| 882 |
+
time.Sleep(500 * time.Millisecond)
|
| 883 |
+
}
|
| 884 |
+
}
|
| 885 |
+
|
| 886 |
+
fmt.Println("Waiting for authentication callback...")
|
| 887 |
+
// Wait up to 5 minutes
|
| 888 |
+
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
|
| 889 |
+
if errWait != nil {
|
| 890 |
+
if errors.Is(errWait, errOAuthSessionNotPending) {
|
| 891 |
+
return
|
| 892 |
+
}
|
| 893 |
+
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
|
| 894 |
+
log.Error(claude.GetUserFriendlyMessage(authErr))
|
| 895 |
+
return
|
| 896 |
+
}
|
| 897 |
+
if errStr := resultMap["error"]; errStr != "" {
|
| 898 |
+
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
| 899 |
+
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
| 900 |
+
SetOAuthSessionError(state, "Bad request")
|
| 901 |
+
return
|
| 902 |
+
}
|
| 903 |
+
if resultMap["state"] != state {
|
| 904 |
+
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
| 905 |
+
log.Error(claude.GetUserFriendlyMessage(authErr))
|
| 906 |
+
SetOAuthSessionError(state, "State code error")
|
| 907 |
+
return
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
// Parse code (Claude may append state after '#')
|
| 911 |
+
rawCode := resultMap["code"]
|
| 912 |
+
code := strings.Split(rawCode, "#")[0]
|
| 913 |
+
|
| 914 |
+
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
| 915 |
+
// Extract client_id from the modified auth URL
|
| 916 |
+
clientID := ""
|
| 917 |
+
if u2, errP := url.Parse(authURL); errP == nil {
|
| 918 |
+
clientID = u2.Query().Get("client_id")
|
| 919 |
+
}
|
| 920 |
+
// Build request
|
| 921 |
+
bodyMap := map[string]any{
|
| 922 |
+
"code": code,
|
| 923 |
+
"state": state,
|
| 924 |
+
"grant_type": "authorization_code",
|
| 925 |
+
"client_id": clientID,
|
| 926 |
+
"redirect_uri": "http://localhost:54545/callback",
|
| 927 |
+
"code_verifier": pkceCodes.CodeVerifier,
|
| 928 |
+
}
|
| 929 |
+
bodyJSON, _ := json.Marshal(bodyMap)
|
| 930 |
+
|
| 931 |
+
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
| 932 |
+
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
| 933 |
+
req.Header.Set("Content-Type", "application/json")
|
| 934 |
+
req.Header.Set("Accept", "application/json")
|
| 935 |
+
resp, errDo := httpClient.Do(req)
|
| 936 |
+
if errDo != nil {
|
| 937 |
+
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
| 938 |
+
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
| 939 |
+
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
| 940 |
+
return
|
| 941 |
+
}
|
| 942 |
+
defer func() {
|
| 943 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 944 |
+
log.Errorf("failed to close response body: %v", errClose)
|
| 945 |
+
}
|
| 946 |
+
}()
|
| 947 |
+
respBody, _ := io.ReadAll(resp.Body)
|
| 948 |
+
if resp.StatusCode != http.StatusOK {
|
| 949 |
+
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
| 950 |
+
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
| 951 |
+
return
|
| 952 |
+
}
|
| 953 |
+
var tResp struct {
|
| 954 |
+
AccessToken string `json:"access_token"`
|
| 955 |
+
RefreshToken string `json:"refresh_token"`
|
| 956 |
+
ExpiresIn int `json:"expires_in"`
|
| 957 |
+
Account struct {
|
| 958 |
+
EmailAddress string `json:"email_address"`
|
| 959 |
+
} `json:"account"`
|
| 960 |
+
}
|
| 961 |
+
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
| 962 |
+
log.Errorf("failed to parse token response: %v", errU)
|
| 963 |
+
SetOAuthSessionError(state, "Failed to parse token response")
|
| 964 |
+
return
|
| 965 |
+
}
|
| 966 |
+
bundle := &claude.ClaudeAuthBundle{
|
| 967 |
+
TokenData: claude.ClaudeTokenData{
|
| 968 |
+
AccessToken: tResp.AccessToken,
|
| 969 |
+
RefreshToken: tResp.RefreshToken,
|
| 970 |
+
Email: tResp.Account.EmailAddress,
|
| 971 |
+
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
| 972 |
+
},
|
| 973 |
+
LastRefresh: time.Now().Format(time.RFC3339),
|
| 974 |
+
}
|
| 975 |
+
|
| 976 |
+
// Create token storage
|
| 977 |
+
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
| 978 |
+
record := &coreauth.Auth{
|
| 979 |
+
ID: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
|
| 980 |
+
Provider: "claude",
|
| 981 |
+
FileName: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
|
| 982 |
+
Storage: tokenStorage,
|
| 983 |
+
Metadata: map[string]any{"email": tokenStorage.Email},
|
| 984 |
+
}
|
| 985 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 986 |
+
if errSave != nil {
|
| 987 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 988 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 989 |
+
return
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 993 |
+
if bundle.APIKey != "" {
|
| 994 |
+
fmt.Println("API key obtained and saved")
|
| 995 |
+
}
|
| 996 |
+
fmt.Println("You can now use Claude services through this CLI")
|
| 997 |
+
CompleteOAuthSession(state)
|
| 998 |
+
CompleteOAuthSessionsByProvider("anthropic")
|
| 999 |
+
}()
|
| 1000 |
+
|
| 1001 |
+
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
| 1005 |
+
ctx := context.Background()
|
| 1006 |
+
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
| 1007 |
+
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
| 1008 |
+
|
| 1009 |
+
// Optional project ID from query
|
| 1010 |
+
projectID := c.Query("project_id")
|
| 1011 |
+
|
| 1012 |
+
fmt.Println("Initializing Google authentication...")
|
| 1013 |
+
|
| 1014 |
+
// OAuth2 configuration (mirrors internal/auth/gemini)
|
| 1015 |
+
conf := &oauth2.Config{
|
| 1016 |
+
ClientID: "YOUR_CLIENT_ID",
|
| 1017 |
+
ClientSecret: "YOUR_CLIENT_SECRET",
|
| 1018 |
+
RedirectURL: "http://localhost:8085/oauth2callback",
|
| 1019 |
+
Scopes: []string{
|
| 1020 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 1021 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 1022 |
+
"https://www.googleapis.com/auth/userinfo.profile",
|
| 1023 |
+
},
|
| 1024 |
+
Endpoint: google.Endpoint,
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
// Build authorization URL and return it immediately
|
| 1028 |
+
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
| 1029 |
+
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
| 1030 |
+
|
| 1031 |
+
RegisterOAuthSession(state, "gemini")
|
| 1032 |
+
|
| 1033 |
+
isWebUI := isWebUIRequest(c)
|
| 1034 |
+
var forwarder *callbackForwarder
|
| 1035 |
+
if isWebUI {
|
| 1036 |
+
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
| 1037 |
+
if errTarget != nil {
|
| 1038 |
+
log.WithError(errTarget).Error("failed to compute gemini callback target")
|
| 1039 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
| 1040 |
+
return
|
| 1041 |
+
}
|
| 1042 |
+
var errStart error
|
| 1043 |
+
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
| 1044 |
+
log.WithError(errStart).Error("failed to start gemini callback forwarder")
|
| 1045 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
| 1046 |
+
return
|
| 1047 |
+
}
|
| 1048 |
+
}
|
| 1049 |
+
|
| 1050 |
+
go func() {
|
| 1051 |
+
if isWebUI {
|
| 1052 |
+
defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
+
// Wait for callback file written by server route
|
| 1056 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
|
| 1057 |
+
fmt.Println("Waiting for authentication callback...")
|
| 1058 |
+
deadline := time.Now().Add(5 * time.Minute)
|
| 1059 |
+
var authCode string
|
| 1060 |
+
for {
|
| 1061 |
+
if !IsOAuthSessionPending(state, "gemini") {
|
| 1062 |
+
return
|
| 1063 |
+
}
|
| 1064 |
+
if time.Now().After(deadline) {
|
| 1065 |
+
log.Error("oauth flow timed out")
|
| 1066 |
+
SetOAuthSessionError(state, "OAuth flow timed out")
|
| 1067 |
+
return
|
| 1068 |
+
}
|
| 1069 |
+
if data, errR := os.ReadFile(waitFile); errR == nil {
|
| 1070 |
+
var m map[string]string
|
| 1071 |
+
_ = json.Unmarshal(data, &m)
|
| 1072 |
+
_ = os.Remove(waitFile)
|
| 1073 |
+
if errStr := m["error"]; errStr != "" {
|
| 1074 |
+
log.Errorf("Authentication failed: %s", errStr)
|
| 1075 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1076 |
+
return
|
| 1077 |
+
}
|
| 1078 |
+
authCode = m["code"]
|
| 1079 |
+
if authCode == "" {
|
| 1080 |
+
log.Errorf("Authentication failed: code not found")
|
| 1081 |
+
SetOAuthSessionError(state, "Authentication failed: code not found")
|
| 1082 |
+
return
|
| 1083 |
+
}
|
| 1084 |
+
break
|
| 1085 |
+
}
|
| 1086 |
+
time.Sleep(500 * time.Millisecond)
|
| 1087 |
+
}
|
| 1088 |
+
|
| 1089 |
+
// Exchange authorization code for token
|
| 1090 |
+
token, err := conf.Exchange(ctx, authCode)
|
| 1091 |
+
if err != nil {
|
| 1092 |
+
log.Errorf("Failed to exchange token: %v", err)
|
| 1093 |
+
SetOAuthSessionError(state, "Failed to exchange token")
|
| 1094 |
+
return
|
| 1095 |
+
}
|
| 1096 |
+
|
| 1097 |
+
requestedProjectID := strings.TrimSpace(projectID)
|
| 1098 |
+
|
| 1099 |
+
// Create token storage (mirrors internal/auth/gemini createTokenStorage)
|
| 1100 |
+
authHTTPClient := conf.Client(ctx, token)
|
| 1101 |
+
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
| 1102 |
+
if errNewRequest != nil {
|
| 1103 |
+
log.Errorf("Could not get user info: %v", errNewRequest)
|
| 1104 |
+
SetOAuthSessionError(state, "Could not get user info")
|
| 1105 |
+
return
|
| 1106 |
+
}
|
| 1107 |
+
req.Header.Set("Content-Type", "application/json")
|
| 1108 |
+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
| 1109 |
+
|
| 1110 |
+
resp, errDo := authHTTPClient.Do(req)
|
| 1111 |
+
if errDo != nil {
|
| 1112 |
+
log.Errorf("Failed to execute request: %v", errDo)
|
| 1113 |
+
SetOAuthSessionError(state, "Failed to execute request")
|
| 1114 |
+
return
|
| 1115 |
+
}
|
| 1116 |
+
defer func() {
|
| 1117 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 1118 |
+
log.Printf("warn: failed to close response body: %v", errClose)
|
| 1119 |
+
}
|
| 1120 |
+
}()
|
| 1121 |
+
|
| 1122 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 1123 |
+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
| 1124 |
+
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
| 1125 |
+
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
| 1126 |
+
return
|
| 1127 |
+
}
|
| 1128 |
+
|
| 1129 |
+
email := gjson.GetBytes(bodyBytes, "email").String()
|
| 1130 |
+
if email != "" {
|
| 1131 |
+
fmt.Printf("Authenticated user email: %s\n", email)
|
| 1132 |
+
} else {
|
| 1133 |
+
fmt.Println("Failed to get user email from token")
|
| 1134 |
+
}
|
| 1135 |
+
|
| 1136 |
+
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
| 1137 |
+
var ifToken map[string]any
|
| 1138 |
+
jsonData, _ := json.Marshal(token)
|
| 1139 |
+
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
| 1140 |
+
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
| 1141 |
+
SetOAuthSessionError(state, "Failed to unmarshal token")
|
| 1142 |
+
return
|
| 1143 |
+
}
|
| 1144 |
+
|
| 1145 |
+
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
| 1146 |
+
ifToken["client_id"] = "YOUR_CLIENT_ID"
|
| 1147 |
+
ifToken["client_secret"] = "YOUR_CLIENT_SECRET"
|
| 1148 |
+
ifToken["scopes"] = []string{
|
| 1149 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 1150 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 1151 |
+
"https://www.googleapis.com/auth/userinfo.profile",
|
| 1152 |
+
}
|
| 1153 |
+
ifToken["universe_domain"] = "googleapis.com"
|
| 1154 |
+
|
| 1155 |
+
ts := geminiAuth.GeminiTokenStorage{
|
| 1156 |
+
Token: ifToken,
|
| 1157 |
+
ProjectID: requestedProjectID,
|
| 1158 |
+
Email: email,
|
| 1159 |
+
Auto: requestedProjectID == "",
|
| 1160 |
+
}
|
| 1161 |
+
|
| 1162 |
+
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
|
| 1163 |
+
gemAuth := geminiAuth.NewGeminiAuth()
|
| 1164 |
+
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
|
| 1165 |
+
NoBrowser: true,
|
| 1166 |
+
})
|
| 1167 |
+
if errGetClient != nil {
|
| 1168 |
+
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
| 1169 |
+
SetOAuthSessionError(state, "Failed to get authenticated client")
|
| 1170 |
+
return
|
| 1171 |
+
}
|
| 1172 |
+
fmt.Println("Authentication successful.")
|
| 1173 |
+
|
| 1174 |
+
if strings.EqualFold(requestedProjectID, "ALL") {
|
| 1175 |
+
ts.Auto = false
|
| 1176 |
+
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
| 1177 |
+
if errAll != nil {
|
| 1178 |
+
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
| 1179 |
+
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
| 1180 |
+
return
|
| 1181 |
+
}
|
| 1182 |
+
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
| 1183 |
+
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
| 1184 |
+
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
| 1185 |
+
return
|
| 1186 |
+
}
|
| 1187 |
+
ts.ProjectID = strings.Join(projects, ",")
|
| 1188 |
+
ts.Checked = true
|
| 1189 |
+
} else {
|
| 1190 |
+
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
| 1191 |
+
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
| 1192 |
+
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
| 1193 |
+
return
|
| 1194 |
+
}
|
| 1195 |
+
|
| 1196 |
+
if strings.TrimSpace(ts.ProjectID) == "" {
|
| 1197 |
+
log.Error("Onboarding did not return a project ID")
|
| 1198 |
+
SetOAuthSessionError(state, "Failed to resolve project ID")
|
| 1199 |
+
return
|
| 1200 |
+
}
|
| 1201 |
+
|
| 1202 |
+
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
| 1203 |
+
if errCheck != nil {
|
| 1204 |
+
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
| 1205 |
+
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
| 1206 |
+
return
|
| 1207 |
+
}
|
| 1208 |
+
ts.Checked = isChecked
|
| 1209 |
+
if !isChecked {
|
| 1210 |
+
log.Error("Cloud AI API is not enabled for the selected project")
|
| 1211 |
+
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
| 1212 |
+
return
|
| 1213 |
+
}
|
| 1214 |
+
}
|
| 1215 |
+
|
| 1216 |
+
recordMetadata := map[string]any{
|
| 1217 |
+
"email": ts.Email,
|
| 1218 |
+
"project_id": ts.ProjectID,
|
| 1219 |
+
"auto": ts.Auto,
|
| 1220 |
+
"checked": ts.Checked,
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
fileName := geminiAuth.CredentialFileName(ts.Email, ts.ProjectID, true)
|
| 1224 |
+
record := &coreauth.Auth{
|
| 1225 |
+
ID: fileName,
|
| 1226 |
+
Provider: "gemini",
|
| 1227 |
+
FileName: fileName,
|
| 1228 |
+
Storage: &ts,
|
| 1229 |
+
Metadata: recordMetadata,
|
| 1230 |
+
}
|
| 1231 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1232 |
+
if errSave != nil {
|
| 1233 |
+
log.Errorf("Failed to save token to file: %v", errSave)
|
| 1234 |
+
SetOAuthSessionError(state, "Failed to save token to file")
|
| 1235 |
+
return
|
| 1236 |
+
}
|
| 1237 |
+
|
| 1238 |
+
CompleteOAuthSession(state)
|
| 1239 |
+
CompleteOAuthSessionsByProvider("gemini")
|
| 1240 |
+
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
| 1241 |
+
}()
|
| 1242 |
+
|
| 1243 |
+
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1244 |
+
}
|
| 1245 |
+
|
| 1246 |
+
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
| 1247 |
+
ctx := context.Background()
|
| 1248 |
+
|
| 1249 |
+
fmt.Println("Initializing Codex authentication...")
|
| 1250 |
+
|
| 1251 |
+
// Generate PKCE codes
|
| 1252 |
+
pkceCodes, err := codex.GeneratePKCECodes()
|
| 1253 |
+
if err != nil {
|
| 1254 |
+
log.Errorf("Failed to generate PKCE codes: %v", err)
|
| 1255 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
| 1256 |
+
return
|
| 1257 |
+
}
|
| 1258 |
+
|
| 1259 |
+
// Generate random state parameter
|
| 1260 |
+
state, err := misc.GenerateRandomState()
|
| 1261 |
+
if err != nil {
|
| 1262 |
+
log.Errorf("Failed to generate state parameter: %v", err)
|
| 1263 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
| 1264 |
+
return
|
| 1265 |
+
}
|
| 1266 |
+
|
| 1267 |
+
// Initialize Codex auth service
|
| 1268 |
+
openaiAuth := codex.NewCodexAuth(h.cfg)
|
| 1269 |
+
|
| 1270 |
+
// Generate authorization URL
|
| 1271 |
+
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
|
| 1272 |
+
if err != nil {
|
| 1273 |
+
log.Errorf("Failed to generate authorization URL: %v", err)
|
| 1274 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
| 1275 |
+
return
|
| 1276 |
+
}
|
| 1277 |
+
|
| 1278 |
+
RegisterOAuthSession(state, "codex")
|
| 1279 |
+
|
| 1280 |
+
isWebUI := isWebUIRequest(c)
|
| 1281 |
+
var forwarder *callbackForwarder
|
| 1282 |
+
if isWebUI {
|
| 1283 |
+
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
| 1284 |
+
if errTarget != nil {
|
| 1285 |
+
log.WithError(errTarget).Error("failed to compute codex callback target")
|
| 1286 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
| 1287 |
+
return
|
| 1288 |
+
}
|
| 1289 |
+
var errStart error
|
| 1290 |
+
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
| 1291 |
+
log.WithError(errStart).Error("failed to start codex callback forwarder")
|
| 1292 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
| 1293 |
+
return
|
| 1294 |
+
}
|
| 1295 |
+
}
|
| 1296 |
+
|
| 1297 |
+
go func() {
|
| 1298 |
+
if isWebUI {
|
| 1299 |
+
defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
|
| 1300 |
+
}
|
| 1301 |
+
|
| 1302 |
+
// Wait for callback file
|
| 1303 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
|
| 1304 |
+
deadline := time.Now().Add(5 * time.Minute)
|
| 1305 |
+
var code string
|
| 1306 |
+
for {
|
| 1307 |
+
if !IsOAuthSessionPending(state, "codex") {
|
| 1308 |
+
return
|
| 1309 |
+
}
|
| 1310 |
+
if time.Now().After(deadline) {
|
| 1311 |
+
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
| 1312 |
+
log.Error(codex.GetUserFriendlyMessage(authErr))
|
| 1313 |
+
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
| 1314 |
+
return
|
| 1315 |
+
}
|
| 1316 |
+
if data, errR := os.ReadFile(waitFile); errR == nil {
|
| 1317 |
+
var m map[string]string
|
| 1318 |
+
_ = json.Unmarshal(data, &m)
|
| 1319 |
+
_ = os.Remove(waitFile)
|
| 1320 |
+
if errStr := m["error"]; errStr != "" {
|
| 1321 |
+
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
| 1322 |
+
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
| 1323 |
+
SetOAuthSessionError(state, "Bad Request")
|
| 1324 |
+
return
|
| 1325 |
+
}
|
| 1326 |
+
if m["state"] != state {
|
| 1327 |
+
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
| 1328 |
+
SetOAuthSessionError(state, "State code error")
|
| 1329 |
+
log.Error(codex.GetUserFriendlyMessage(authErr))
|
| 1330 |
+
return
|
| 1331 |
+
}
|
| 1332 |
+
code = m["code"]
|
| 1333 |
+
break
|
| 1334 |
+
}
|
| 1335 |
+
time.Sleep(500 * time.Millisecond)
|
| 1336 |
+
}
|
| 1337 |
+
|
| 1338 |
+
log.Debug("Authorization code received, exchanging for tokens...")
|
| 1339 |
+
// Extract client_id from authURL
|
| 1340 |
+
clientID := ""
|
| 1341 |
+
if u2, errP := url.Parse(authURL); errP == nil {
|
| 1342 |
+
clientID = u2.Query().Get("client_id")
|
| 1343 |
+
}
|
| 1344 |
+
// Exchange code for tokens with redirect equal to mgmtRedirect
|
| 1345 |
+
form := url.Values{
|
| 1346 |
+
"grant_type": {"authorization_code"},
|
| 1347 |
+
"client_id": {clientID},
|
| 1348 |
+
"code": {code},
|
| 1349 |
+
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
| 1350 |
+
"code_verifier": {pkceCodes.CodeVerifier},
|
| 1351 |
+
}
|
| 1352 |
+
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
| 1353 |
+
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
| 1354 |
+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
| 1355 |
+
req.Header.Set("Accept", "application/json")
|
| 1356 |
+
resp, errDo := httpClient.Do(req)
|
| 1357 |
+
if errDo != nil {
|
| 1358 |
+
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
| 1359 |
+
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
| 1360 |
+
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
| 1361 |
+
return
|
| 1362 |
+
}
|
| 1363 |
+
defer func() { _ = resp.Body.Close() }()
|
| 1364 |
+
respBody, _ := io.ReadAll(resp.Body)
|
| 1365 |
+
if resp.StatusCode != http.StatusOK {
|
| 1366 |
+
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
| 1367 |
+
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
| 1368 |
+
return
|
| 1369 |
+
}
|
| 1370 |
+
var tokenResp struct {
|
| 1371 |
+
AccessToken string `json:"access_token"`
|
| 1372 |
+
RefreshToken string `json:"refresh_token"`
|
| 1373 |
+
IDToken string `json:"id_token"`
|
| 1374 |
+
ExpiresIn int `json:"expires_in"`
|
| 1375 |
+
}
|
| 1376 |
+
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
| 1377 |
+
SetOAuthSessionError(state, "Failed to parse token response")
|
| 1378 |
+
log.Errorf("failed to parse token response: %v", errU)
|
| 1379 |
+
return
|
| 1380 |
+
}
|
| 1381 |
+
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
| 1382 |
+
email := ""
|
| 1383 |
+
accountID := ""
|
| 1384 |
+
if claims != nil {
|
| 1385 |
+
email = claims.GetUserEmail()
|
| 1386 |
+
accountID = claims.GetAccountID()
|
| 1387 |
+
}
|
| 1388 |
+
// Build bundle compatible with existing storage
|
| 1389 |
+
bundle := &codex.CodexAuthBundle{
|
| 1390 |
+
TokenData: codex.CodexTokenData{
|
| 1391 |
+
IDToken: tokenResp.IDToken,
|
| 1392 |
+
AccessToken: tokenResp.AccessToken,
|
| 1393 |
+
RefreshToken: tokenResp.RefreshToken,
|
| 1394 |
+
AccountID: accountID,
|
| 1395 |
+
Email: email,
|
| 1396 |
+
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
| 1397 |
+
},
|
| 1398 |
+
LastRefresh: time.Now().Format(time.RFC3339),
|
| 1399 |
+
}
|
| 1400 |
+
|
| 1401 |
+
// Create token storage and persist
|
| 1402 |
+
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
| 1403 |
+
record := &coreauth.Auth{
|
| 1404 |
+
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
| 1405 |
+
Provider: "codex",
|
| 1406 |
+
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
| 1407 |
+
Storage: tokenStorage,
|
| 1408 |
+
Metadata: map[string]any{
|
| 1409 |
+
"email": tokenStorage.Email,
|
| 1410 |
+
"account_id": tokenStorage.AccountID,
|
| 1411 |
+
},
|
| 1412 |
+
}
|
| 1413 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1414 |
+
if errSave != nil {
|
| 1415 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 1416 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 1417 |
+
return
|
| 1418 |
+
}
|
| 1419 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 1420 |
+
if bundle.APIKey != "" {
|
| 1421 |
+
fmt.Println("API key obtained and saved")
|
| 1422 |
+
}
|
| 1423 |
+
fmt.Println("You can now use Codex services through this CLI")
|
| 1424 |
+
CompleteOAuthSession(state)
|
| 1425 |
+
CompleteOAuthSessionsByProvider("codex")
|
| 1426 |
+
}()
|
| 1427 |
+
|
| 1428 |
+
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1429 |
+
}
|
| 1430 |
+
|
| 1431 |
+
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
| 1432 |
+
const (
|
| 1433 |
+
antigravityCallbackPort = 51121
|
| 1434 |
+
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
| 1435 |
+
antigravityClientSecret = "YOUR_ANTIGRAVITY_CLIENT_SECRET"
|
| 1436 |
+
)
|
| 1437 |
+
var antigravityScopes = []string{
|
| 1438 |
+
"https://www.googleapis.com/auth/cloud-platform",
|
| 1439 |
+
"https://www.googleapis.com/auth/userinfo.email",
|
| 1440 |
+
"https://www.googleapis.com/auth/userinfo.profile",
|
| 1441 |
+
"https://www.googleapis.com/auth/cclog",
|
| 1442 |
+
"https://www.googleapis.com/auth/experimentsandconfigs",
|
| 1443 |
+
}
|
| 1444 |
+
|
| 1445 |
+
ctx := context.Background()
|
| 1446 |
+
|
| 1447 |
+
fmt.Println("Initializing Antigravity authentication...")
|
| 1448 |
+
|
| 1449 |
+
state, errState := misc.GenerateRandomState()
|
| 1450 |
+
if errState != nil {
|
| 1451 |
+
log.Errorf("Failed to generate state parameter: %v", errState)
|
| 1452 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
| 1453 |
+
return
|
| 1454 |
+
}
|
| 1455 |
+
|
| 1456 |
+
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
|
| 1457 |
+
|
| 1458 |
+
params := url.Values{}
|
| 1459 |
+
params.Set("access_type", "offline")
|
| 1460 |
+
params.Set("client_id", antigravityClientID)
|
| 1461 |
+
params.Set("prompt", "consent")
|
| 1462 |
+
params.Set("redirect_uri", redirectURI)
|
| 1463 |
+
params.Set("response_type", "code")
|
| 1464 |
+
params.Set("scope", strings.Join(antigravityScopes, " "))
|
| 1465 |
+
params.Set("state", state)
|
| 1466 |
+
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
| 1467 |
+
|
| 1468 |
+
RegisterOAuthSession(state, "antigravity")
|
| 1469 |
+
|
| 1470 |
+
isWebUI := isWebUIRequest(c)
|
| 1471 |
+
var forwarder *callbackForwarder
|
| 1472 |
+
if isWebUI {
|
| 1473 |
+
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
| 1474 |
+
if errTarget != nil {
|
| 1475 |
+
log.WithError(errTarget).Error("failed to compute antigravity callback target")
|
| 1476 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
| 1477 |
+
return
|
| 1478 |
+
}
|
| 1479 |
+
var errStart error
|
| 1480 |
+
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
| 1481 |
+
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
| 1482 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
| 1483 |
+
return
|
| 1484 |
+
}
|
| 1485 |
+
}
|
| 1486 |
+
|
| 1487 |
+
go func() {
|
| 1488 |
+
if isWebUI {
|
| 1489 |
+
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
| 1490 |
+
}
|
| 1491 |
+
|
| 1492 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
| 1493 |
+
deadline := time.Now().Add(5 * time.Minute)
|
| 1494 |
+
var authCode string
|
| 1495 |
+
for {
|
| 1496 |
+
if !IsOAuthSessionPending(state, "antigravity") {
|
| 1497 |
+
return
|
| 1498 |
+
}
|
| 1499 |
+
if time.Now().After(deadline) {
|
| 1500 |
+
log.Error("oauth flow timed out")
|
| 1501 |
+
SetOAuthSessionError(state, "OAuth flow timed out")
|
| 1502 |
+
return
|
| 1503 |
+
}
|
| 1504 |
+
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
| 1505 |
+
var payload map[string]string
|
| 1506 |
+
_ = json.Unmarshal(data, &payload)
|
| 1507 |
+
_ = os.Remove(waitFile)
|
| 1508 |
+
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
| 1509 |
+
log.Errorf("Authentication failed: %s", errStr)
|
| 1510 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1511 |
+
return
|
| 1512 |
+
}
|
| 1513 |
+
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
| 1514 |
+
log.Errorf("Authentication failed: state mismatch")
|
| 1515 |
+
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
| 1516 |
+
return
|
| 1517 |
+
}
|
| 1518 |
+
authCode = strings.TrimSpace(payload["code"])
|
| 1519 |
+
if authCode == "" {
|
| 1520 |
+
log.Error("Authentication failed: code not found")
|
| 1521 |
+
SetOAuthSessionError(state, "Authentication failed: code not found")
|
| 1522 |
+
return
|
| 1523 |
+
}
|
| 1524 |
+
break
|
| 1525 |
+
}
|
| 1526 |
+
time.Sleep(500 * time.Millisecond)
|
| 1527 |
+
}
|
| 1528 |
+
|
| 1529 |
+
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
| 1530 |
+
form := url.Values{}
|
| 1531 |
+
form.Set("code", authCode)
|
| 1532 |
+
form.Set("client_id", antigravityClientID)
|
| 1533 |
+
form.Set("client_secret", antigravityClientSecret)
|
| 1534 |
+
form.Set("redirect_uri", redirectURI)
|
| 1535 |
+
form.Set("grant_type", "authorization_code")
|
| 1536 |
+
|
| 1537 |
+
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
| 1538 |
+
if errNewRequest != nil {
|
| 1539 |
+
log.Errorf("Failed to build token request: %v", errNewRequest)
|
| 1540 |
+
SetOAuthSessionError(state, "Failed to build token request")
|
| 1541 |
+
return
|
| 1542 |
+
}
|
| 1543 |
+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
| 1544 |
+
|
| 1545 |
+
resp, errDo := httpClient.Do(req)
|
| 1546 |
+
if errDo != nil {
|
| 1547 |
+
log.Errorf("Failed to execute token request: %v", errDo)
|
| 1548 |
+
SetOAuthSessionError(state, "Failed to exchange token")
|
| 1549 |
+
return
|
| 1550 |
+
}
|
| 1551 |
+
defer func() {
|
| 1552 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 1553 |
+
log.Errorf("antigravity token exchange close error: %v", errClose)
|
| 1554 |
+
}
|
| 1555 |
+
}()
|
| 1556 |
+
|
| 1557 |
+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
| 1558 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 1559 |
+
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
| 1560 |
+
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
| 1561 |
+
return
|
| 1562 |
+
}
|
| 1563 |
+
|
| 1564 |
+
var tokenResp struct {
|
| 1565 |
+
AccessToken string `json:"access_token"`
|
| 1566 |
+
RefreshToken string `json:"refresh_token"`
|
| 1567 |
+
ExpiresIn int64 `json:"expires_in"`
|
| 1568 |
+
TokenType string `json:"token_type"`
|
| 1569 |
+
}
|
| 1570 |
+
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
| 1571 |
+
log.Errorf("Failed to parse token response: %v", errDecode)
|
| 1572 |
+
SetOAuthSessionError(state, "Failed to parse token response")
|
| 1573 |
+
return
|
| 1574 |
+
}
|
| 1575 |
+
|
| 1576 |
+
email := ""
|
| 1577 |
+
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
| 1578 |
+
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
| 1579 |
+
if errInfoReq != nil {
|
| 1580 |
+
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
| 1581 |
+
SetOAuthSessionError(state, "Failed to build user info request")
|
| 1582 |
+
return
|
| 1583 |
+
}
|
| 1584 |
+
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
| 1585 |
+
|
| 1586 |
+
infoResp, errInfo := httpClient.Do(infoReq)
|
| 1587 |
+
if errInfo != nil {
|
| 1588 |
+
log.Errorf("Failed to execute user info request: %v", errInfo)
|
| 1589 |
+
SetOAuthSessionError(state, "Failed to execute user info request")
|
| 1590 |
+
return
|
| 1591 |
+
}
|
| 1592 |
+
defer func() {
|
| 1593 |
+
if errClose := infoResp.Body.Close(); errClose != nil {
|
| 1594 |
+
log.Errorf("antigravity user info close error: %v", errClose)
|
| 1595 |
+
}
|
| 1596 |
+
}()
|
| 1597 |
+
|
| 1598 |
+
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
|
| 1599 |
+
var infoPayload struct {
|
| 1600 |
+
Email string `json:"email"`
|
| 1601 |
+
}
|
| 1602 |
+
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
|
| 1603 |
+
email = strings.TrimSpace(infoPayload.Email)
|
| 1604 |
+
}
|
| 1605 |
+
} else {
|
| 1606 |
+
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
| 1607 |
+
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
| 1608 |
+
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
| 1609 |
+
return
|
| 1610 |
+
}
|
| 1611 |
+
}
|
| 1612 |
+
|
| 1613 |
+
projectID := ""
|
| 1614 |
+
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
| 1615 |
+
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
| 1616 |
+
if errProject != nil {
|
| 1617 |
+
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
| 1618 |
+
} else {
|
| 1619 |
+
projectID = fetchedProjectID
|
| 1620 |
+
log.Infof("antigravity: obtained project ID %s", projectID)
|
| 1621 |
+
}
|
| 1622 |
+
}
|
| 1623 |
+
|
| 1624 |
+
now := time.Now()
|
| 1625 |
+
metadata := map[string]any{
|
| 1626 |
+
"type": "antigravity",
|
| 1627 |
+
"access_token": tokenResp.AccessToken,
|
| 1628 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 1629 |
+
"expires_in": tokenResp.ExpiresIn,
|
| 1630 |
+
"timestamp": now.UnixMilli(),
|
| 1631 |
+
"expired": now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
| 1632 |
+
}
|
| 1633 |
+
if email != "" {
|
| 1634 |
+
metadata["email"] = email
|
| 1635 |
+
}
|
| 1636 |
+
if projectID != "" {
|
| 1637 |
+
metadata["project_id"] = projectID
|
| 1638 |
+
}
|
| 1639 |
+
|
| 1640 |
+
fileName := sanitizeAntigravityFileName(email)
|
| 1641 |
+
label := strings.TrimSpace(email)
|
| 1642 |
+
if label == "" {
|
| 1643 |
+
label = "antigravity"
|
| 1644 |
+
}
|
| 1645 |
+
|
| 1646 |
+
record := &coreauth.Auth{
|
| 1647 |
+
ID: fileName,
|
| 1648 |
+
Provider: "antigravity",
|
| 1649 |
+
FileName: fileName,
|
| 1650 |
+
Label: label,
|
| 1651 |
+
Metadata: metadata,
|
| 1652 |
+
}
|
| 1653 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1654 |
+
if errSave != nil {
|
| 1655 |
+
log.Errorf("Failed to save token to file: %v", errSave)
|
| 1656 |
+
SetOAuthSessionError(state, "Failed to save token to file")
|
| 1657 |
+
return
|
| 1658 |
+
}
|
| 1659 |
+
|
| 1660 |
+
CompleteOAuthSession(state)
|
| 1661 |
+
CompleteOAuthSessionsByProvider("antigravity")
|
| 1662 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 1663 |
+
if projectID != "" {
|
| 1664 |
+
fmt.Printf("Using GCP project: %s\n", projectID)
|
| 1665 |
+
}
|
| 1666 |
+
fmt.Println("You can now use Antigravity services through this CLI")
|
| 1667 |
+
}()
|
| 1668 |
+
|
| 1669 |
+
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1670 |
+
}
|
| 1671 |
+
|
| 1672 |
+
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
| 1673 |
+
ctx := context.Background()
|
| 1674 |
+
|
| 1675 |
+
fmt.Println("Initializing Qwen authentication...")
|
| 1676 |
+
|
| 1677 |
+
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
| 1678 |
+
// Initialize Qwen auth service
|
| 1679 |
+
qwenAuth := qwen.NewQwenAuth(h.cfg)
|
| 1680 |
+
|
| 1681 |
+
// Generate authorization URL
|
| 1682 |
+
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
|
| 1683 |
+
if err != nil {
|
| 1684 |
+
log.Errorf("Failed to generate authorization URL: %v", err)
|
| 1685 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
| 1686 |
+
return
|
| 1687 |
+
}
|
| 1688 |
+
authURL := deviceFlow.VerificationURIComplete
|
| 1689 |
+
|
| 1690 |
+
RegisterOAuthSession(state, "qwen")
|
| 1691 |
+
|
| 1692 |
+
go func() {
|
| 1693 |
+
fmt.Println("Waiting for authentication...")
|
| 1694 |
+
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
| 1695 |
+
if errPollForToken != nil {
|
| 1696 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1697 |
+
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
| 1698 |
+
return
|
| 1699 |
+
}
|
| 1700 |
+
|
| 1701 |
+
// Create token storage
|
| 1702 |
+
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
|
| 1703 |
+
|
| 1704 |
+
tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
|
| 1705 |
+
record := &coreauth.Auth{
|
| 1706 |
+
ID: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
| 1707 |
+
Provider: "qwen",
|
| 1708 |
+
FileName: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
|
| 1709 |
+
Storage: tokenStorage,
|
| 1710 |
+
Metadata: map[string]any{"email": tokenStorage.Email},
|
| 1711 |
+
}
|
| 1712 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1713 |
+
if errSave != nil {
|
| 1714 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 1715 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 1716 |
+
return
|
| 1717 |
+
}
|
| 1718 |
+
|
| 1719 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 1720 |
+
fmt.Println("You can now use Qwen services through this CLI")
|
| 1721 |
+
CompleteOAuthSession(state)
|
| 1722 |
+
}()
|
| 1723 |
+
|
| 1724 |
+
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1725 |
+
}
|
| 1726 |
+
|
| 1727 |
+
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
| 1728 |
+
ctx := context.Background()
|
| 1729 |
+
|
| 1730 |
+
fmt.Println("Initializing iFlow authentication...")
|
| 1731 |
+
|
| 1732 |
+
state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
|
| 1733 |
+
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
| 1734 |
+
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
| 1735 |
+
|
| 1736 |
+
RegisterOAuthSession(state, "iflow")
|
| 1737 |
+
|
| 1738 |
+
isWebUI := isWebUIRequest(c)
|
| 1739 |
+
var forwarder *callbackForwarder
|
| 1740 |
+
if isWebUI {
|
| 1741 |
+
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
| 1742 |
+
if errTarget != nil {
|
| 1743 |
+
log.WithError(errTarget).Error("failed to compute iflow callback target")
|
| 1744 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
|
| 1745 |
+
return
|
| 1746 |
+
}
|
| 1747 |
+
var errStart error
|
| 1748 |
+
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
| 1749 |
+
log.WithError(errStart).Error("failed to start iflow callback forwarder")
|
| 1750 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
|
| 1751 |
+
return
|
| 1752 |
+
}
|
| 1753 |
+
}
|
| 1754 |
+
|
| 1755 |
+
go func() {
|
| 1756 |
+
if isWebUI {
|
| 1757 |
+
defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
|
| 1758 |
+
}
|
| 1759 |
+
fmt.Println("Waiting for authentication...")
|
| 1760 |
+
|
| 1761 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
|
| 1762 |
+
deadline := time.Now().Add(5 * time.Minute)
|
| 1763 |
+
var resultMap map[string]string
|
| 1764 |
+
for {
|
| 1765 |
+
if !IsOAuthSessionPending(state, "iflow") {
|
| 1766 |
+
return
|
| 1767 |
+
}
|
| 1768 |
+
if time.Now().After(deadline) {
|
| 1769 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1770 |
+
fmt.Println("Authentication failed: timeout waiting for callback")
|
| 1771 |
+
return
|
| 1772 |
+
}
|
| 1773 |
+
if data, errR := os.ReadFile(waitFile); errR == nil {
|
| 1774 |
+
_ = os.Remove(waitFile)
|
| 1775 |
+
_ = json.Unmarshal(data, &resultMap)
|
| 1776 |
+
break
|
| 1777 |
+
}
|
| 1778 |
+
time.Sleep(500 * time.Millisecond)
|
| 1779 |
+
}
|
| 1780 |
+
|
| 1781 |
+
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
| 1782 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1783 |
+
fmt.Printf("Authentication failed: %s\n", errStr)
|
| 1784 |
+
return
|
| 1785 |
+
}
|
| 1786 |
+
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
| 1787 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1788 |
+
fmt.Println("Authentication failed: state mismatch")
|
| 1789 |
+
return
|
| 1790 |
+
}
|
| 1791 |
+
|
| 1792 |
+
code := strings.TrimSpace(resultMap["code"])
|
| 1793 |
+
if code == "" {
|
| 1794 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1795 |
+
fmt.Println("Authentication failed: code missing")
|
| 1796 |
+
return
|
| 1797 |
+
}
|
| 1798 |
+
|
| 1799 |
+
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
| 1800 |
+
if errExchange != nil {
|
| 1801 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 1802 |
+
fmt.Printf("Authentication failed: %v\n", errExchange)
|
| 1803 |
+
return
|
| 1804 |
+
}
|
| 1805 |
+
|
| 1806 |
+
tokenStorage := authSvc.CreateTokenStorage(tokenData)
|
| 1807 |
+
identifier := strings.TrimSpace(tokenStorage.Email)
|
| 1808 |
+
if identifier == "" {
|
| 1809 |
+
identifier = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
| 1810 |
+
tokenStorage.Email = identifier
|
| 1811 |
+
}
|
| 1812 |
+
record := &coreauth.Auth{
|
| 1813 |
+
ID: fmt.Sprintf("iflow-%s.json", identifier),
|
| 1814 |
+
Provider: "iflow",
|
| 1815 |
+
FileName: fmt.Sprintf("iflow-%s.json", identifier),
|
| 1816 |
+
Storage: tokenStorage,
|
| 1817 |
+
Metadata: map[string]any{"email": identifier, "api_key": tokenStorage.APIKey},
|
| 1818 |
+
Attributes: map[string]string{"api_key": tokenStorage.APIKey},
|
| 1819 |
+
}
|
| 1820 |
+
|
| 1821 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1822 |
+
if errSave != nil {
|
| 1823 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 1824 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 1825 |
+
return
|
| 1826 |
+
}
|
| 1827 |
+
|
| 1828 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 1829 |
+
if tokenStorage.APIKey != "" {
|
| 1830 |
+
fmt.Println("API key obtained and saved")
|
| 1831 |
+
}
|
| 1832 |
+
fmt.Println("You can now use iFlow services through this CLI")
|
| 1833 |
+
CompleteOAuthSession(state)
|
| 1834 |
+
CompleteOAuthSessionsByProvider("iflow")
|
| 1835 |
+
}()
|
| 1836 |
+
|
| 1837 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
| 1838 |
+
}
|
| 1839 |
+
|
| 1840 |
+
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
|
| 1841 |
+
ctx := context.Background()
|
| 1842 |
+
|
| 1843 |
+
var payload struct {
|
| 1844 |
+
Cookie string `json:"cookie"`
|
| 1845 |
+
}
|
| 1846 |
+
if err := c.ShouldBindJSON(&payload); err != nil {
|
| 1847 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
|
| 1848 |
+
return
|
| 1849 |
+
}
|
| 1850 |
+
|
| 1851 |
+
cookieValue := strings.TrimSpace(payload.Cookie)
|
| 1852 |
+
|
| 1853 |
+
if cookieValue == "" {
|
| 1854 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
|
| 1855 |
+
return
|
| 1856 |
+
}
|
| 1857 |
+
|
| 1858 |
+
cookieValue, errNormalize := iflowauth.NormalizeCookie(cookieValue)
|
| 1859 |
+
if errNormalize != nil {
|
| 1860 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errNormalize.Error()})
|
| 1861 |
+
return
|
| 1862 |
+
}
|
| 1863 |
+
|
| 1864 |
+
// Check for duplicate BXAuth before authentication
|
| 1865 |
+
bxAuth := iflowauth.ExtractBXAuth(cookieValue)
|
| 1866 |
+
if existingFile, err := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth); err != nil {
|
| 1867 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
|
| 1868 |
+
return
|
| 1869 |
+
} else if existingFile != "" {
|
| 1870 |
+
existingFileName := filepath.Base(existingFile)
|
| 1871 |
+
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
|
| 1872 |
+
return
|
| 1873 |
+
}
|
| 1874 |
+
|
| 1875 |
+
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
| 1876 |
+
tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
|
| 1877 |
+
if errAuth != nil {
|
| 1878 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errAuth.Error()})
|
| 1879 |
+
return
|
| 1880 |
+
}
|
| 1881 |
+
|
| 1882 |
+
tokenData.Cookie = cookieValue
|
| 1883 |
+
|
| 1884 |
+
tokenStorage := authSvc.CreateCookieTokenStorage(tokenData)
|
| 1885 |
+
email := strings.TrimSpace(tokenStorage.Email)
|
| 1886 |
+
if email == "" {
|
| 1887 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "failed to extract email from token"})
|
| 1888 |
+
return
|
| 1889 |
+
}
|
| 1890 |
+
|
| 1891 |
+
fileName := iflowauth.SanitizeIFlowFileName(email)
|
| 1892 |
+
if fileName == "" {
|
| 1893 |
+
fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
|
| 1894 |
+
}
|
| 1895 |
+
|
| 1896 |
+
tokenStorage.Email = email
|
| 1897 |
+
timestamp := time.Now().Unix()
|
| 1898 |
+
|
| 1899 |
+
record := &coreauth.Auth{
|
| 1900 |
+
ID: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
| 1901 |
+
Provider: "iflow",
|
| 1902 |
+
FileName: fmt.Sprintf("iflow-%s-%d.json", fileName, timestamp),
|
| 1903 |
+
Storage: tokenStorage,
|
| 1904 |
+
Metadata: map[string]any{
|
| 1905 |
+
"email": email,
|
| 1906 |
+
"api_key": tokenStorage.APIKey,
|
| 1907 |
+
"expired": tokenStorage.Expire,
|
| 1908 |
+
"cookie": tokenStorage.Cookie,
|
| 1909 |
+
"type": tokenStorage.Type,
|
| 1910 |
+
"last_refresh": tokenStorage.LastRefresh,
|
| 1911 |
+
},
|
| 1912 |
+
Attributes: map[string]string{
|
| 1913 |
+
"api_key": tokenStorage.APIKey,
|
| 1914 |
+
},
|
| 1915 |
+
}
|
| 1916 |
+
|
| 1917 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 1918 |
+
if errSave != nil {
|
| 1919 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
|
| 1920 |
+
return
|
| 1921 |
+
}
|
| 1922 |
+
|
| 1923 |
+
fmt.Printf("iFlow cookie authentication successful. Token saved to %s\n", savedPath)
|
| 1924 |
+
c.JSON(http.StatusOK, gin.H{
|
| 1925 |
+
"status": "ok",
|
| 1926 |
+
"saved_path": savedPath,
|
| 1927 |
+
"email": email,
|
| 1928 |
+
"expired": tokenStorage.Expire,
|
| 1929 |
+
"type": tokenStorage.Type,
|
| 1930 |
+
})
|
| 1931 |
+
}
|
| 1932 |
+
|
| 1933 |
+
type projectSelectionRequiredError struct{}
|
| 1934 |
+
|
| 1935 |
+
func (e *projectSelectionRequiredError) Error() string {
|
| 1936 |
+
return "gemini cli: project selection required"
|
| 1937 |
+
}
|
| 1938 |
+
|
| 1939 |
+
func ensureGeminiProjectAndOnboard(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
|
| 1940 |
+
if storage == nil {
|
| 1941 |
+
return fmt.Errorf("gemini storage is nil")
|
| 1942 |
+
}
|
| 1943 |
+
|
| 1944 |
+
trimmedRequest := strings.TrimSpace(requestedProject)
|
| 1945 |
+
if trimmedRequest == "" {
|
| 1946 |
+
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
| 1947 |
+
if errProjects != nil {
|
| 1948 |
+
return fmt.Errorf("fetch project list: %w", errProjects)
|
| 1949 |
+
}
|
| 1950 |
+
if len(projects) == 0 {
|
| 1951 |
+
return fmt.Errorf("no Google Cloud projects available for this account")
|
| 1952 |
+
}
|
| 1953 |
+
trimmedRequest = strings.TrimSpace(projects[0].ProjectID)
|
| 1954 |
+
if trimmedRequest == "" {
|
| 1955 |
+
return fmt.Errorf("resolved project id is empty")
|
| 1956 |
+
}
|
| 1957 |
+
storage.Auto = true
|
| 1958 |
+
} else {
|
| 1959 |
+
storage.Auto = false
|
| 1960 |
+
}
|
| 1961 |
+
|
| 1962 |
+
if err := performGeminiCLISetup(ctx, httpClient, storage, trimmedRequest); err != nil {
|
| 1963 |
+
return err
|
| 1964 |
+
}
|
| 1965 |
+
|
| 1966 |
+
if strings.TrimSpace(storage.ProjectID) == "" {
|
| 1967 |
+
storage.ProjectID = trimmedRequest
|
| 1968 |
+
}
|
| 1969 |
+
|
| 1970 |
+
return nil
|
| 1971 |
+
}
|
| 1972 |
+
|
| 1973 |
+
func onboardAllGeminiProjects(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage) ([]string, error) {
|
| 1974 |
+
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
| 1975 |
+
if errProjects != nil {
|
| 1976 |
+
return nil, fmt.Errorf("fetch project list: %w", errProjects)
|
| 1977 |
+
}
|
| 1978 |
+
if len(projects) == 0 {
|
| 1979 |
+
return nil, fmt.Errorf("no Google Cloud projects available for this account")
|
| 1980 |
+
}
|
| 1981 |
+
activated := make([]string, 0, len(projects))
|
| 1982 |
+
seen := make(map[string]struct{}, len(projects))
|
| 1983 |
+
for _, project := range projects {
|
| 1984 |
+
candidate := strings.TrimSpace(project.ProjectID)
|
| 1985 |
+
if candidate == "" {
|
| 1986 |
+
continue
|
| 1987 |
+
}
|
| 1988 |
+
if _, dup := seen[candidate]; dup {
|
| 1989 |
+
continue
|
| 1990 |
+
}
|
| 1991 |
+
if err := performGeminiCLISetup(ctx, httpClient, storage, candidate); err != nil {
|
| 1992 |
+
return nil, fmt.Errorf("onboard project %s: %w", candidate, err)
|
| 1993 |
+
}
|
| 1994 |
+
finalID := strings.TrimSpace(storage.ProjectID)
|
| 1995 |
+
if finalID == "" {
|
| 1996 |
+
finalID = candidate
|
| 1997 |
+
}
|
| 1998 |
+
activated = append(activated, finalID)
|
| 1999 |
+
seen[candidate] = struct{}{}
|
| 2000 |
+
}
|
| 2001 |
+
if len(activated) == 0 {
|
| 2002 |
+
return nil, fmt.Errorf("no Google Cloud projects available for this account")
|
| 2003 |
+
}
|
| 2004 |
+
return activated, nil
|
| 2005 |
+
}
|
| 2006 |
+
|
| 2007 |
+
func ensureGeminiProjectsEnabled(ctx context.Context, httpClient *http.Client, projectIDs []string) error {
|
| 2008 |
+
for _, pid := range projectIDs {
|
| 2009 |
+
trimmed := strings.TrimSpace(pid)
|
| 2010 |
+
if trimmed == "" {
|
| 2011 |
+
continue
|
| 2012 |
+
}
|
| 2013 |
+
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, trimmed)
|
| 2014 |
+
if errCheck != nil {
|
| 2015 |
+
return fmt.Errorf("project %s: %w", trimmed, errCheck)
|
| 2016 |
+
}
|
| 2017 |
+
if !isChecked {
|
| 2018 |
+
return fmt.Errorf("project %s: Cloud AI API not enabled", trimmed)
|
| 2019 |
+
}
|
| 2020 |
+
}
|
| 2021 |
+
return nil
|
| 2022 |
+
}
|
| 2023 |
+
|
| 2024 |
+
func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
|
| 2025 |
+
metadata := map[string]string{
|
| 2026 |
+
"ideType": "IDE_UNSPECIFIED",
|
| 2027 |
+
"platform": "PLATFORM_UNSPECIFIED",
|
| 2028 |
+
"pluginType": "GEMINI",
|
| 2029 |
+
}
|
| 2030 |
+
|
| 2031 |
+
trimmedRequest := strings.TrimSpace(requestedProject)
|
| 2032 |
+
explicitProject := trimmedRequest != ""
|
| 2033 |
+
|
| 2034 |
+
loadReqBody := map[string]any{
|
| 2035 |
+
"metadata": metadata,
|
| 2036 |
+
}
|
| 2037 |
+
if explicitProject {
|
| 2038 |
+
loadReqBody["cloudaicompanionProject"] = trimmedRequest
|
| 2039 |
+
}
|
| 2040 |
+
|
| 2041 |
+
var loadResp map[string]any
|
| 2042 |
+
if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil {
|
| 2043 |
+
return fmt.Errorf("load code assist: %w", errLoad)
|
| 2044 |
+
}
|
| 2045 |
+
|
| 2046 |
+
tierID := "legacy-tier"
|
| 2047 |
+
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
| 2048 |
+
for _, rawTier := range tiers {
|
| 2049 |
+
tier, okTier := rawTier.(map[string]any)
|
| 2050 |
+
if !okTier {
|
| 2051 |
+
continue
|
| 2052 |
+
}
|
| 2053 |
+
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
| 2054 |
+
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
| 2055 |
+
tierID = strings.TrimSpace(id)
|
| 2056 |
+
break
|
| 2057 |
+
}
|
| 2058 |
+
}
|
| 2059 |
+
}
|
| 2060 |
+
}
|
| 2061 |
+
|
| 2062 |
+
projectID := trimmedRequest
|
| 2063 |
+
if projectID == "" {
|
| 2064 |
+
if id, okProject := loadResp["cloudaicompanionProject"].(string); okProject {
|
| 2065 |
+
projectID = strings.TrimSpace(id)
|
| 2066 |
+
}
|
| 2067 |
+
if projectID == "" {
|
| 2068 |
+
if projectMap, okProject := loadResp["cloudaicompanionProject"].(map[string]any); okProject {
|
| 2069 |
+
if id, okID := projectMap["id"].(string); okID {
|
| 2070 |
+
projectID = strings.TrimSpace(id)
|
| 2071 |
+
}
|
| 2072 |
+
}
|
| 2073 |
+
}
|
| 2074 |
+
}
|
| 2075 |
+
if projectID == "" {
|
| 2076 |
+
return &projectSelectionRequiredError{}
|
| 2077 |
+
}
|
| 2078 |
+
|
| 2079 |
+
onboardReqBody := map[string]any{
|
| 2080 |
+
"tierId": tierID,
|
| 2081 |
+
"metadata": metadata,
|
| 2082 |
+
"cloudaicompanionProject": projectID,
|
| 2083 |
+
}
|
| 2084 |
+
|
| 2085 |
+
storage.ProjectID = projectID
|
| 2086 |
+
|
| 2087 |
+
for {
|
| 2088 |
+
var onboardResp map[string]any
|
| 2089 |
+
if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil {
|
| 2090 |
+
return fmt.Errorf("onboard user: %w", errOnboard)
|
| 2091 |
+
}
|
| 2092 |
+
|
| 2093 |
+
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
| 2094 |
+
responseProjectID := ""
|
| 2095 |
+
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
| 2096 |
+
switch projectValue := resp["cloudaicompanionProject"].(type) {
|
| 2097 |
+
case map[string]any:
|
| 2098 |
+
if id, okID := projectValue["id"].(string); okID {
|
| 2099 |
+
responseProjectID = strings.TrimSpace(id)
|
| 2100 |
+
}
|
| 2101 |
+
case string:
|
| 2102 |
+
responseProjectID = strings.TrimSpace(projectValue)
|
| 2103 |
+
}
|
| 2104 |
+
}
|
| 2105 |
+
|
| 2106 |
+
finalProjectID := projectID
|
| 2107 |
+
if responseProjectID != "" {
|
| 2108 |
+
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
| 2109 |
+
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
| 2110 |
+
} else {
|
| 2111 |
+
finalProjectID = responseProjectID
|
| 2112 |
+
}
|
| 2113 |
+
}
|
| 2114 |
+
|
| 2115 |
+
storage.ProjectID = strings.TrimSpace(finalProjectID)
|
| 2116 |
+
if storage.ProjectID == "" {
|
| 2117 |
+
storage.ProjectID = strings.TrimSpace(projectID)
|
| 2118 |
+
}
|
| 2119 |
+
if storage.ProjectID == "" {
|
| 2120 |
+
return fmt.Errorf("onboard user completed without project id")
|
| 2121 |
+
}
|
| 2122 |
+
log.Infof("Onboarding complete. Using Project ID: %s", storage.ProjectID)
|
| 2123 |
+
return nil
|
| 2124 |
+
}
|
| 2125 |
+
|
| 2126 |
+
log.Println("Onboarding in progress, waiting 5 seconds...")
|
| 2127 |
+
time.Sleep(5 * time.Second)
|
| 2128 |
+
}
|
| 2129 |
+
}
|
| 2130 |
+
|
| 2131 |
+
func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error {
|
| 2132 |
+
endPointURL := fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint)
|
| 2133 |
+
if strings.HasPrefix(endpoint, "operations/") {
|
| 2134 |
+
endPointURL = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint)
|
| 2135 |
+
}
|
| 2136 |
+
|
| 2137 |
+
var reader io.Reader
|
| 2138 |
+
if body != nil {
|
| 2139 |
+
rawBody, errMarshal := json.Marshal(body)
|
| 2140 |
+
if errMarshal != nil {
|
| 2141 |
+
return fmt.Errorf("marshal request body: %w", errMarshal)
|
| 2142 |
+
}
|
| 2143 |
+
reader = bytes.NewReader(rawBody)
|
| 2144 |
+
}
|
| 2145 |
+
|
| 2146 |
+
req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, endPointURL, reader)
|
| 2147 |
+
if errRequest != nil {
|
| 2148 |
+
return fmt.Errorf("create request: %w", errRequest)
|
| 2149 |
+
}
|
| 2150 |
+
req.Header.Set("Content-Type", "application/json")
|
| 2151 |
+
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
| 2152 |
+
req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient)
|
| 2153 |
+
req.Header.Set("Client-Metadata", geminiCLIClientMetadata)
|
| 2154 |
+
|
| 2155 |
+
resp, errDo := httpClient.Do(req)
|
| 2156 |
+
if errDo != nil {
|
| 2157 |
+
return fmt.Errorf("execute request: %w", errDo)
|
| 2158 |
+
}
|
| 2159 |
+
defer func() {
|
| 2160 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 2161 |
+
log.Errorf("response body close error: %v", errClose)
|
| 2162 |
+
}
|
| 2163 |
+
}()
|
| 2164 |
+
|
| 2165 |
+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
| 2166 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 2167 |
+
return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
| 2168 |
+
}
|
| 2169 |
+
|
| 2170 |
+
if result == nil {
|
| 2171 |
+
_, _ = io.Copy(io.Discard, resp.Body)
|
| 2172 |
+
return nil
|
| 2173 |
+
}
|
| 2174 |
+
|
| 2175 |
+
if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil {
|
| 2176 |
+
return fmt.Errorf("decode response body: %w", errDecode)
|
| 2177 |
+
}
|
| 2178 |
+
|
| 2179 |
+
return nil
|
| 2180 |
+
}
|
| 2181 |
+
|
| 2182 |
+
func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) {
|
| 2183 |
+
req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
| 2184 |
+
if errRequest != nil {
|
| 2185 |
+
return nil, fmt.Errorf("could not create project list request: %w", errRequest)
|
| 2186 |
+
}
|
| 2187 |
+
|
| 2188 |
+
resp, errDo := httpClient.Do(req)
|
| 2189 |
+
if errDo != nil {
|
| 2190 |
+
return nil, fmt.Errorf("failed to execute project list request: %w", errDo)
|
| 2191 |
+
}
|
| 2192 |
+
defer func() {
|
| 2193 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 2194 |
+
log.Errorf("response body close error: %v", errClose)
|
| 2195 |
+
}
|
| 2196 |
+
}()
|
| 2197 |
+
|
| 2198 |
+
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
| 2199 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 2200 |
+
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
| 2201 |
+
}
|
| 2202 |
+
|
| 2203 |
+
var projects interfaces.GCPProject
|
| 2204 |
+
if errDecode := json.NewDecoder(resp.Body).Decode(&projects); errDecode != nil {
|
| 2205 |
+
return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode)
|
| 2206 |
+
}
|
| 2207 |
+
|
| 2208 |
+
return projects.Projects, nil
|
| 2209 |
+
}
|
| 2210 |
+
|
| 2211 |
+
func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) {
|
| 2212 |
+
serviceUsageURL := "https://serviceusage.googleapis.com"
|
| 2213 |
+
requiredServices := []string{
|
| 2214 |
+
"cloudaicompanion.googleapis.com",
|
| 2215 |
+
}
|
| 2216 |
+
for _, service := range requiredServices {
|
| 2217 |
+
checkURL := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service)
|
| 2218 |
+
req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, checkURL, nil)
|
| 2219 |
+
if errRequest != nil {
|
| 2220 |
+
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
| 2221 |
+
}
|
| 2222 |
+
req.Header.Set("Content-Type", "application/json")
|
| 2223 |
+
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
| 2224 |
+
resp, errDo := httpClient.Do(req)
|
| 2225 |
+
if errDo != nil {
|
| 2226 |
+
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
| 2227 |
+
}
|
| 2228 |
+
|
| 2229 |
+
if resp.StatusCode == http.StatusOK {
|
| 2230 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 2231 |
+
if gjson.GetBytes(bodyBytes, "state").String() == "ENABLED" {
|
| 2232 |
+
_ = resp.Body.Close()
|
| 2233 |
+
continue
|
| 2234 |
+
}
|
| 2235 |
+
}
|
| 2236 |
+
_ = resp.Body.Close()
|
| 2237 |
+
|
| 2238 |
+
enableURL := fmt.Sprintf("%s/v1/projects/%s/services/%s:enable", serviceUsageURL, projectID, service)
|
| 2239 |
+
req, errRequest = http.NewRequestWithContext(ctx, http.MethodPost, enableURL, strings.NewReader("{}"))
|
| 2240 |
+
if errRequest != nil {
|
| 2241 |
+
return false, fmt.Errorf("failed to create request: %w", errRequest)
|
| 2242 |
+
}
|
| 2243 |
+
req.Header.Set("Content-Type", "application/json")
|
| 2244 |
+
req.Header.Set("User-Agent", geminiCLIUserAgent)
|
| 2245 |
+
resp, errDo = httpClient.Do(req)
|
| 2246 |
+
if errDo != nil {
|
| 2247 |
+
return false, fmt.Errorf("failed to execute request: %w", errDo)
|
| 2248 |
+
}
|
| 2249 |
+
|
| 2250 |
+
bodyBytes, _ := io.ReadAll(resp.Body)
|
| 2251 |
+
errMessage := string(bodyBytes)
|
| 2252 |
+
errMessageResult := gjson.GetBytes(bodyBytes, "error.message")
|
| 2253 |
+
if errMessageResult.Exists() {
|
| 2254 |
+
errMessage = errMessageResult.String()
|
| 2255 |
+
}
|
| 2256 |
+
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
|
| 2257 |
+
_ = resp.Body.Close()
|
| 2258 |
+
continue
|
| 2259 |
+
} else if resp.StatusCode == http.StatusBadRequest {
|
| 2260 |
+
_ = resp.Body.Close()
|
| 2261 |
+
if strings.Contains(strings.ToLower(errMessage), "already enabled") {
|
| 2262 |
+
continue
|
| 2263 |
+
}
|
| 2264 |
+
}
|
| 2265 |
+
_ = resp.Body.Close()
|
| 2266 |
+
return false, fmt.Errorf("project activation required: %s", errMessage)
|
| 2267 |
+
}
|
| 2268 |
+
return true, nil
|
| 2269 |
+
}
|
| 2270 |
+
|
| 2271 |
+
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
| 2272 |
+
state := strings.TrimSpace(c.Query("state"))
|
| 2273 |
+
if state == "" {
|
| 2274 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 2275 |
+
return
|
| 2276 |
+
}
|
| 2277 |
+
if err := ValidateOAuthState(state); err != nil {
|
| 2278 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
| 2279 |
+
return
|
| 2280 |
+
}
|
| 2281 |
+
|
| 2282 |
+
_, status, ok := GetOAuthSession(state)
|
| 2283 |
+
if !ok {
|
| 2284 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 2285 |
+
return
|
| 2286 |
+
}
|
| 2287 |
+
if status != "" {
|
| 2288 |
+
if strings.HasPrefix(status, "device_code|") {
|
| 2289 |
+
parts := strings.SplitN(status, "|", 3)
|
| 2290 |
+
if len(parts) == 3 {
|
| 2291 |
+
c.JSON(http.StatusOK, gin.H{
|
| 2292 |
+
"status": "device_code",
|
| 2293 |
+
"verification_url": parts[1],
|
| 2294 |
+
"user_code": parts[2],
|
| 2295 |
+
})
|
| 2296 |
+
return
|
| 2297 |
+
}
|
| 2298 |
+
}
|
| 2299 |
+
if strings.HasPrefix(status, "auth_url|") {
|
| 2300 |
+
authURL := strings.TrimPrefix(status, "auth_url|")
|
| 2301 |
+
c.JSON(http.StatusOK, gin.H{
|
| 2302 |
+
"status": "auth_url",
|
| 2303 |
+
"url": authURL,
|
| 2304 |
+
})
|
| 2305 |
+
return
|
| 2306 |
+
}
|
| 2307 |
+
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
| 2308 |
+
return
|
| 2309 |
+
}
|
| 2310 |
+
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
| 2311 |
+
}
|
| 2312 |
+
|
| 2313 |
+
const kiroCallbackPort = 9876
|
| 2314 |
+
|
| 2315 |
+
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
| 2316 |
+
ctx := context.Background()
|
| 2317 |
+
|
| 2318 |
+
// Get the login method from query parameter (default: aws for device code flow)
|
| 2319 |
+
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
| 2320 |
+
if method == "" {
|
| 2321 |
+
method = "aws"
|
| 2322 |
+
}
|
| 2323 |
+
|
| 2324 |
+
fmt.Println("Initializing Kiro authentication...")
|
| 2325 |
+
|
| 2326 |
+
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
| 2327 |
+
|
| 2328 |
+
switch method {
|
| 2329 |
+
case "aws", "builder-id":
|
| 2330 |
+
RegisterOAuthSession(state, "kiro")
|
| 2331 |
+
|
| 2332 |
+
// AWS Builder ID uses device code flow (no callback needed)
|
| 2333 |
+
go func() {
|
| 2334 |
+
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
| 2335 |
+
|
| 2336 |
+
// Step 1: Register client
|
| 2337 |
+
fmt.Println("Registering client...")
|
| 2338 |
+
regResp, errRegister := ssoClient.RegisterClient(ctx)
|
| 2339 |
+
if errRegister != nil {
|
| 2340 |
+
log.Errorf("Failed to register client: %v", errRegister)
|
| 2341 |
+
SetOAuthSessionError(state, "Failed to register client")
|
| 2342 |
+
return
|
| 2343 |
+
}
|
| 2344 |
+
|
| 2345 |
+
// Step 2: Start device authorization
|
| 2346 |
+
fmt.Println("Starting device authorization...")
|
| 2347 |
+
authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
| 2348 |
+
if errAuth != nil {
|
| 2349 |
+
log.Errorf("Failed to start device auth: %v", errAuth)
|
| 2350 |
+
SetOAuthSessionError(state, "Failed to start device authorization")
|
| 2351 |
+
return
|
| 2352 |
+
}
|
| 2353 |
+
|
| 2354 |
+
// Store the verification URL for the frontend to display.
|
| 2355 |
+
// Using "|" as separator because URLs contain ":".
|
| 2356 |
+
SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
| 2357 |
+
|
| 2358 |
+
// Step 3: Poll for token
|
| 2359 |
+
fmt.Println("Waiting for authorization...")
|
| 2360 |
+
interval := 5 * time.Second
|
| 2361 |
+
if authResp.Interval > 0 {
|
| 2362 |
+
interval = time.Duration(authResp.Interval) * time.Second
|
| 2363 |
+
}
|
| 2364 |
+
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
| 2365 |
+
|
| 2366 |
+
for time.Now().Before(deadline) {
|
| 2367 |
+
select {
|
| 2368 |
+
case <-ctx.Done():
|
| 2369 |
+
SetOAuthSessionError(state, "Authorization cancelled")
|
| 2370 |
+
return
|
| 2371 |
+
case <-time.After(interval):
|
| 2372 |
+
tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
| 2373 |
+
if errToken != nil {
|
| 2374 |
+
errStr := errToken.Error()
|
| 2375 |
+
if strings.Contains(errStr, "authorization_pending") {
|
| 2376 |
+
continue
|
| 2377 |
+
}
|
| 2378 |
+
if strings.Contains(errStr, "slow_down") {
|
| 2379 |
+
interval += 5 * time.Second
|
| 2380 |
+
continue
|
| 2381 |
+
}
|
| 2382 |
+
log.Errorf("Token creation failed: %v", errToken)
|
| 2383 |
+
SetOAuthSessionError(state, "Token creation failed")
|
| 2384 |
+
return
|
| 2385 |
+
}
|
| 2386 |
+
|
| 2387 |
+
// Success! Save the token
|
| 2388 |
+
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
| 2389 |
+
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
| 2390 |
+
|
| 2391 |
+
idPart := kiroauth.SanitizeEmailForFilename(email)
|
| 2392 |
+
if idPart == "" {
|
| 2393 |
+
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
| 2394 |
+
}
|
| 2395 |
+
|
| 2396 |
+
now := time.Now()
|
| 2397 |
+
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
| 2398 |
+
|
| 2399 |
+
record := &coreauth.Auth{
|
| 2400 |
+
ID: fileName,
|
| 2401 |
+
Provider: "kiro",
|
| 2402 |
+
FileName: fileName,
|
| 2403 |
+
Metadata: map[string]any{
|
| 2404 |
+
"type": "kiro",
|
| 2405 |
+
"access_token": tokenResp.AccessToken,
|
| 2406 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 2407 |
+
"expires_at": expiresAt.Format(time.RFC3339),
|
| 2408 |
+
"auth_method": "builder-id",
|
| 2409 |
+
"provider": "AWS",
|
| 2410 |
+
"client_id": regResp.ClientID,
|
| 2411 |
+
"client_secret": regResp.ClientSecret,
|
| 2412 |
+
"email": email,
|
| 2413 |
+
"last_refresh": now.Format(time.RFC3339),
|
| 2414 |
+
},
|
| 2415 |
+
}
|
| 2416 |
+
|
| 2417 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 2418 |
+
if errSave != nil {
|
| 2419 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 2420 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 2421 |
+
return
|
| 2422 |
+
}
|
| 2423 |
+
|
| 2424 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 2425 |
+
if email != "" {
|
| 2426 |
+
fmt.Printf("Authenticated as: %s\n", email)
|
| 2427 |
+
}
|
| 2428 |
+
CompleteOAuthSession(state)
|
| 2429 |
+
return
|
| 2430 |
+
}
|
| 2431 |
+
}
|
| 2432 |
+
|
| 2433 |
+
SetOAuthSessionError(state, "Authorization timed out")
|
| 2434 |
+
}()
|
| 2435 |
+
|
| 2436 |
+
// Return immediately with the state for polling
|
| 2437 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
| 2438 |
+
|
| 2439 |
+
case "google", "github":
|
| 2440 |
+
RegisterOAuthSession(state, "kiro")
|
| 2441 |
+
|
| 2442 |
+
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
| 2443 |
+
provider := "Google"
|
| 2444 |
+
if method == "github" {
|
| 2445 |
+
provider = "Github"
|
| 2446 |
+
}
|
| 2447 |
+
|
| 2448 |
+
isWebUI := isWebUIRequest(c)
|
| 2449 |
+
if isWebUI {
|
| 2450 |
+
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
| 2451 |
+
if errTarget != nil {
|
| 2452 |
+
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
| 2453 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
| 2454 |
+
return
|
| 2455 |
+
}
|
| 2456 |
+
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
| 2457 |
+
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
| 2458 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
| 2459 |
+
return
|
| 2460 |
+
}
|
| 2461 |
+
}
|
| 2462 |
+
|
| 2463 |
+
go func() {
|
| 2464 |
+
if isWebUI {
|
| 2465 |
+
defer stopCallbackForwarder(kiroCallbackPort)
|
| 2466 |
+
}
|
| 2467 |
+
|
| 2468 |
+
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
| 2469 |
+
|
| 2470 |
+
// Generate PKCE codes
|
| 2471 |
+
codeVerifier, codeChallenge, errPKCE := generateKiroPKCE()
|
| 2472 |
+
if errPKCE != nil {
|
| 2473 |
+
log.Errorf("Failed to generate PKCE: %v", errPKCE)
|
| 2474 |
+
SetOAuthSessionError(state, "Failed to generate PKCE")
|
| 2475 |
+
return
|
| 2476 |
+
}
|
| 2477 |
+
|
| 2478 |
+
// Build login URL
|
| 2479 |
+
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
| 2480 |
+
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
| 2481 |
+
provider,
|
| 2482 |
+
url.QueryEscape(kiroauth.KiroRedirectURI),
|
| 2483 |
+
codeChallenge,
|
| 2484 |
+
state,
|
| 2485 |
+
)
|
| 2486 |
+
|
| 2487 |
+
// Store auth URL for frontend.
|
| 2488 |
+
// Using "|" as separator because URLs contain ":".
|
| 2489 |
+
SetOAuthSessionError(state, "auth_url|"+authURL)
|
| 2490 |
+
|
| 2491 |
+
// Wait for callback file
|
| 2492 |
+
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
| 2493 |
+
deadline := time.Now().Add(5 * time.Minute)
|
| 2494 |
+
|
| 2495 |
+
for {
|
| 2496 |
+
if time.Now().After(deadline) {
|
| 2497 |
+
log.Error("oauth flow timed out")
|
| 2498 |
+
SetOAuthSessionError(state, "OAuth flow timed out")
|
| 2499 |
+
return
|
| 2500 |
+
}
|
| 2501 |
+
if data, errRead := os.ReadFile(waitFile); errRead == nil {
|
| 2502 |
+
var m map[string]string
|
| 2503 |
+
_ = json.Unmarshal(data, &m)
|
| 2504 |
+
_ = os.Remove(waitFile)
|
| 2505 |
+
if errStr := m["error"]; errStr != "" {
|
| 2506 |
+
log.Errorf("Authentication failed: %s", errStr)
|
| 2507 |
+
SetOAuthSessionError(state, "Authentication failed")
|
| 2508 |
+
return
|
| 2509 |
+
}
|
| 2510 |
+
if m["state"] != state {
|
| 2511 |
+
log.Errorf("State mismatch")
|
| 2512 |
+
SetOAuthSessionError(state, "State mismatch")
|
| 2513 |
+
return
|
| 2514 |
+
}
|
| 2515 |
+
code := m["code"]
|
| 2516 |
+
if code == "" {
|
| 2517 |
+
log.Error("No authorization code received")
|
| 2518 |
+
SetOAuthSessionError(state, "No authorization code received")
|
| 2519 |
+
return
|
| 2520 |
+
}
|
| 2521 |
+
|
| 2522 |
+
// Exchange code for tokens
|
| 2523 |
+
tokenReq := &kiroauth.CreateTokenRequest{
|
| 2524 |
+
Code: code,
|
| 2525 |
+
CodeVerifier: codeVerifier,
|
| 2526 |
+
RedirectURI: kiroauth.KiroRedirectURI,
|
| 2527 |
+
}
|
| 2528 |
+
|
| 2529 |
+
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
| 2530 |
+
if errToken != nil {
|
| 2531 |
+
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
| 2532 |
+
SetOAuthSessionError(state, "Failed to exchange code for tokens")
|
| 2533 |
+
return
|
| 2534 |
+
}
|
| 2535 |
+
|
| 2536 |
+
// Save the token
|
| 2537 |
+
expiresIn := tokenResp.ExpiresIn
|
| 2538 |
+
if expiresIn <= 0 {
|
| 2539 |
+
expiresIn = 3600
|
| 2540 |
+
}
|
| 2541 |
+
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
| 2542 |
+
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
| 2543 |
+
|
| 2544 |
+
idPart := kiroauth.SanitizeEmailForFilename(email)
|
| 2545 |
+
if idPart == "" {
|
| 2546 |
+
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
| 2547 |
+
}
|
| 2548 |
+
|
| 2549 |
+
now := time.Now()
|
| 2550 |
+
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
| 2551 |
+
|
| 2552 |
+
record := &coreauth.Auth{
|
| 2553 |
+
ID: fileName,
|
| 2554 |
+
Provider: "kiro",
|
| 2555 |
+
FileName: fileName,
|
| 2556 |
+
Metadata: map[string]any{
|
| 2557 |
+
"type": "kiro",
|
| 2558 |
+
"access_token": tokenResp.AccessToken,
|
| 2559 |
+
"refresh_token": tokenResp.RefreshToken,
|
| 2560 |
+
"profile_arn": tokenResp.ProfileArn,
|
| 2561 |
+
"expires_at": expiresAt.Format(time.RFC3339),
|
| 2562 |
+
"auth_method": "social",
|
| 2563 |
+
"provider": provider,
|
| 2564 |
+
"email": email,
|
| 2565 |
+
"last_refresh": now.Format(time.RFC3339),
|
| 2566 |
+
},
|
| 2567 |
+
}
|
| 2568 |
+
|
| 2569 |
+
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
| 2570 |
+
if errSave != nil {
|
| 2571 |
+
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
| 2572 |
+
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
| 2573 |
+
return
|
| 2574 |
+
}
|
| 2575 |
+
|
| 2576 |
+
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
| 2577 |
+
if email != "" {
|
| 2578 |
+
fmt.Printf("Authenticated as: %s\n", email)
|
| 2579 |
+
}
|
| 2580 |
+
CompleteOAuthSession(state)
|
| 2581 |
+
return
|
| 2582 |
+
}
|
| 2583 |
+
time.Sleep(500 * time.Millisecond)
|
| 2584 |
+
}
|
| 2585 |
+
}()
|
| 2586 |
+
|
| 2587 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"})
|
| 2588 |
+
|
| 2589 |
+
default:
|
| 2590 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
| 2591 |
+
}
|
| 2592 |
+
}
|
| 2593 |
+
|
| 2594 |
+
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
|
| 2595 |
+
func generateKiroPKCE() (verifier, challenge string, err error) {
|
| 2596 |
+
b := make([]byte, 32)
|
| 2597 |
+
if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil {
|
| 2598 |
+
return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead)
|
| 2599 |
+
}
|
| 2600 |
+
verifier = base64.RawURLEncoding.EncodeToString(b)
|
| 2601 |
+
|
| 2602 |
+
h := sha256.Sum256([]byte(verifier))
|
| 2603 |
+
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
| 2604 |
+
|
| 2605 |
+
return verifier, challenge, nil
|
| 2606 |
+
}
|
internal/api/handlers/management/config_basic.go
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"io"
|
| 7 |
+
"net/http"
|
| 8 |
+
"os"
|
| 9 |
+
"path/filepath"
|
| 10 |
+
"strings"
|
| 11 |
+
"time"
|
| 12 |
+
|
| 13 |
+
"github.com/gin-gonic/gin"
|
| 14 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 15 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 16 |
+
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
| 17 |
+
log "github.com/sirupsen/logrus"
|
| 18 |
+
"gopkg.in/yaml.v3"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
const (
|
| 22 |
+
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
|
| 23 |
+
latestReleaseUserAgent = "CLIProxyAPIPlus"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
func (h *Handler) GetConfig(c *gin.Context) {
|
| 27 |
+
if h == nil || h.cfg == nil {
|
| 28 |
+
c.JSON(200, gin.H{})
|
| 29 |
+
return
|
| 30 |
+
}
|
| 31 |
+
cfgCopy := *h.cfg
|
| 32 |
+
c.JSON(200, &cfgCopy)
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
type releaseInfo struct {
|
| 36 |
+
TagName string `json:"tag_name"`
|
| 37 |
+
Name string `json:"name"`
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
|
| 41 |
+
func (h *Handler) GetLatestVersion(c *gin.Context) {
|
| 42 |
+
client := &http.Client{Timeout: 10 * time.Second}
|
| 43 |
+
proxyURL := ""
|
| 44 |
+
if h != nil && h.cfg != nil {
|
| 45 |
+
proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
|
| 46 |
+
}
|
| 47 |
+
if proxyURL != "" {
|
| 48 |
+
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
|
| 49 |
+
util.SetProxy(sdkCfg, client)
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
|
| 53 |
+
if err != nil {
|
| 54 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
|
| 55 |
+
return
|
| 56 |
+
}
|
| 57 |
+
req.Header.Set("Accept", "application/vnd.github+json")
|
| 58 |
+
req.Header.Set("User-Agent", latestReleaseUserAgent)
|
| 59 |
+
|
| 60 |
+
resp, err := client.Do(req)
|
| 61 |
+
if err != nil {
|
| 62 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
|
| 63 |
+
return
|
| 64 |
+
}
|
| 65 |
+
defer func() {
|
| 66 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 67 |
+
log.WithError(errClose).Debug("failed to close latest version response body")
|
| 68 |
+
}
|
| 69 |
+
}()
|
| 70 |
+
|
| 71 |
+
if resp.StatusCode != http.StatusOK {
|
| 72 |
+
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
| 73 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
|
| 74 |
+
return
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
var info releaseInfo
|
| 78 |
+
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
| 79 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
|
| 80 |
+
return
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
version := strings.TrimSpace(info.TagName)
|
| 84 |
+
if version == "" {
|
| 85 |
+
version = strings.TrimSpace(info.Name)
|
| 86 |
+
}
|
| 87 |
+
if version == "" {
|
| 88 |
+
c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
|
| 89 |
+
return
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
c.JSON(http.StatusOK, gin.H{"latest-version": version})
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
func WriteConfig(path string, data []byte) error {
|
| 96 |
+
data = config.NormalizeCommentIndentation(data)
|
| 97 |
+
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
| 98 |
+
if err != nil {
|
| 99 |
+
return err
|
| 100 |
+
}
|
| 101 |
+
if _, errWrite := f.Write(data); errWrite != nil {
|
| 102 |
+
_ = f.Close()
|
| 103 |
+
return errWrite
|
| 104 |
+
}
|
| 105 |
+
if errSync := f.Sync(); errSync != nil {
|
| 106 |
+
_ = f.Close()
|
| 107 |
+
return errSync
|
| 108 |
+
}
|
| 109 |
+
return f.Close()
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
func (h *Handler) PutConfigYAML(c *gin.Context) {
|
| 113 |
+
body, err := io.ReadAll(c.Request.Body)
|
| 114 |
+
if err != nil {
|
| 115 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": "cannot read request body"})
|
| 116 |
+
return
|
| 117 |
+
}
|
| 118 |
+
var cfg config.Config
|
| 119 |
+
if err = yaml.Unmarshal(body, &cfg); err != nil {
|
| 120 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
|
| 121 |
+
return
|
| 122 |
+
}
|
| 123 |
+
// Validate config using LoadConfigOptional with optional=false to enforce parsing
|
| 124 |
+
tmpDir := filepath.Dir(h.configFilePath)
|
| 125 |
+
tmpFile, err := os.CreateTemp(tmpDir, "config-validate-*.yaml")
|
| 126 |
+
if err != nil {
|
| 127 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
| 128 |
+
return
|
| 129 |
+
}
|
| 130 |
+
tempFile := tmpFile.Name()
|
| 131 |
+
if _, errWrite := tmpFile.Write(body); errWrite != nil {
|
| 132 |
+
_ = tmpFile.Close()
|
| 133 |
+
_ = os.Remove(tempFile)
|
| 134 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
|
| 135 |
+
return
|
| 136 |
+
}
|
| 137 |
+
if errClose := tmpFile.Close(); errClose != nil {
|
| 138 |
+
_ = os.Remove(tempFile)
|
| 139 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
|
| 140 |
+
return
|
| 141 |
+
}
|
| 142 |
+
defer func() {
|
| 143 |
+
_ = os.Remove(tempFile)
|
| 144 |
+
}()
|
| 145 |
+
_, err = config.LoadConfigOptional(tempFile, false)
|
| 146 |
+
if err != nil {
|
| 147 |
+
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
|
| 148 |
+
return
|
| 149 |
+
}
|
| 150 |
+
h.mu.Lock()
|
| 151 |
+
defer h.mu.Unlock()
|
| 152 |
+
if WriteConfig(h.configFilePath, body) != nil {
|
| 153 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": "failed to write config"})
|
| 154 |
+
return
|
| 155 |
+
}
|
| 156 |
+
// Reload into handler to keep memory in sync
|
| 157 |
+
newCfg, err := config.LoadConfig(h.configFilePath)
|
| 158 |
+
if err != nil {
|
| 159 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": err.Error()})
|
| 160 |
+
return
|
| 161 |
+
}
|
| 162 |
+
h.cfg = newCfg
|
| 163 |
+
c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}})
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// GetConfigYAML returns the raw config.yaml file bytes without re-encoding.
|
| 167 |
+
// It preserves comments and original formatting/styles.
|
| 168 |
+
func (h *Handler) GetConfigYAML(c *gin.Context) {
|
| 169 |
+
data, err := os.ReadFile(h.configFilePath)
|
| 170 |
+
if err != nil {
|
| 171 |
+
if os.IsNotExist(err) {
|
| 172 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "not_found", "message": "config file not found"})
|
| 173 |
+
return
|
| 174 |
+
}
|
| 175 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": err.Error()})
|
| 176 |
+
return
|
| 177 |
+
}
|
| 178 |
+
c.Header("Content-Type", "application/yaml; charset=utf-8")
|
| 179 |
+
c.Header("Cache-Control", "no-store")
|
| 180 |
+
c.Header("X-Content-Type-Options", "nosniff")
|
| 181 |
+
// Write raw bytes as-is
|
| 182 |
+
_, _ = c.Writer.Write(data)
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
// Debug
|
| 186 |
+
func (h *Handler) GetDebug(c *gin.Context) { c.JSON(200, gin.H{"debug": h.cfg.Debug}) }
|
| 187 |
+
func (h *Handler) PutDebug(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.Debug = v }) }
|
| 188 |
+
|
| 189 |
+
// UsageStatisticsEnabled
|
| 190 |
+
func (h *Handler) GetUsageStatisticsEnabled(c *gin.Context) {
|
| 191 |
+
c.JSON(200, gin.H{"usage-statistics-enabled": h.cfg.UsageStatisticsEnabled})
|
| 192 |
+
}
|
| 193 |
+
func (h *Handler) PutUsageStatisticsEnabled(c *gin.Context) {
|
| 194 |
+
h.updateBoolField(c, func(v bool) { h.cfg.UsageStatisticsEnabled = v })
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// UsageStatisticsEnabled
|
| 198 |
+
func (h *Handler) GetLoggingToFile(c *gin.Context) {
|
| 199 |
+
c.JSON(200, gin.H{"logging-to-file": h.cfg.LoggingToFile})
|
| 200 |
+
}
|
| 201 |
+
func (h *Handler) PutLoggingToFile(c *gin.Context) {
|
| 202 |
+
h.updateBoolField(c, func(v bool) { h.cfg.LoggingToFile = v })
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Request log
|
| 206 |
+
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
| 207 |
+
func (h *Handler) PutRequestLog(c *gin.Context) {
|
| 208 |
+
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
// Websocket auth
|
| 212 |
+
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
|
| 213 |
+
c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
|
| 214 |
+
}
|
| 215 |
+
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
|
| 216 |
+
h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
// Request retry
|
| 220 |
+
func (h *Handler) GetRequestRetry(c *gin.Context) {
|
| 221 |
+
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
|
| 222 |
+
}
|
| 223 |
+
func (h *Handler) PutRequestRetry(c *gin.Context) {
|
| 224 |
+
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
// Max retry interval
|
| 228 |
+
func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
|
| 229 |
+
c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
|
| 230 |
+
}
|
| 231 |
+
func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
| 232 |
+
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// Proxy URL
|
| 236 |
+
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
| 237 |
+
func (h *Handler) PutProxyURL(c *gin.Context) {
|
| 238 |
+
h.updateStringField(c, func(v string) { h.cfg.ProxyURL = v })
|
| 239 |
+
}
|
| 240 |
+
func (h *Handler) DeleteProxyURL(c *gin.Context) {
|
| 241 |
+
h.cfg.ProxyURL = ""
|
| 242 |
+
h.persist(c)
|
| 243 |
+
}
|
internal/api/handlers/management/config_lists.go
ADDED
|
@@ -0,0 +1,1090 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"strings"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
// Generic helpers for list[string]
|
| 13 |
+
func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) {
|
| 14 |
+
data, err := c.GetRawData()
|
| 15 |
+
if err != nil {
|
| 16 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 17 |
+
return
|
| 18 |
+
}
|
| 19 |
+
var arr []string
|
| 20 |
+
if err = json.Unmarshal(data, &arr); err != nil {
|
| 21 |
+
var obj struct {
|
| 22 |
+
Items []string `json:"items"`
|
| 23 |
+
}
|
| 24 |
+
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
| 25 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 26 |
+
return
|
| 27 |
+
}
|
| 28 |
+
arr = obj.Items
|
| 29 |
+
}
|
| 30 |
+
set(arr)
|
| 31 |
+
if after != nil {
|
| 32 |
+
after()
|
| 33 |
+
}
|
| 34 |
+
h.persist(c)
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) {
|
| 38 |
+
var body struct {
|
| 39 |
+
Old *string `json:"old"`
|
| 40 |
+
New *string `json:"new"`
|
| 41 |
+
Index *int `json:"index"`
|
| 42 |
+
Value *string `json:"value"`
|
| 43 |
+
}
|
| 44 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 45 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 46 |
+
return
|
| 47 |
+
}
|
| 48 |
+
if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target) {
|
| 49 |
+
(*target)[*body.Index] = *body.Value
|
| 50 |
+
if after != nil {
|
| 51 |
+
after()
|
| 52 |
+
}
|
| 53 |
+
h.persist(c)
|
| 54 |
+
return
|
| 55 |
+
}
|
| 56 |
+
if body.Old != nil && body.New != nil {
|
| 57 |
+
for i := range *target {
|
| 58 |
+
if (*target)[i] == *body.Old {
|
| 59 |
+
(*target)[i] = *body.New
|
| 60 |
+
if after != nil {
|
| 61 |
+
after()
|
| 62 |
+
}
|
| 63 |
+
h.persist(c)
|
| 64 |
+
return
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
*target = append(*target, *body.New)
|
| 68 |
+
if after != nil {
|
| 69 |
+
after()
|
| 70 |
+
}
|
| 71 |
+
h.persist(c)
|
| 72 |
+
return
|
| 73 |
+
}
|
| 74 |
+
c.JSON(400, gin.H{"error": "missing fields"})
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) {
|
| 78 |
+
if idxStr := c.Query("index"); idxStr != "" {
|
| 79 |
+
var idx int
|
| 80 |
+
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
| 81 |
+
if err == nil && idx >= 0 && idx < len(*target) {
|
| 82 |
+
*target = append((*target)[:idx], (*target)[idx+1:]...)
|
| 83 |
+
if after != nil {
|
| 84 |
+
after()
|
| 85 |
+
}
|
| 86 |
+
h.persist(c)
|
| 87 |
+
return
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
if val := strings.TrimSpace(c.Query("value")); val != "" {
|
| 91 |
+
out := make([]string, 0, len(*target))
|
| 92 |
+
for _, v := range *target {
|
| 93 |
+
if strings.TrimSpace(v) != val {
|
| 94 |
+
out = append(out, v)
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
*target = out
|
| 98 |
+
if after != nil {
|
| 99 |
+
after()
|
| 100 |
+
}
|
| 101 |
+
h.persist(c)
|
| 102 |
+
return
|
| 103 |
+
}
|
| 104 |
+
c.JSON(400, gin.H{"error": "missing index or value"})
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// api-keys
|
| 108 |
+
func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys}) }
|
| 109 |
+
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
| 110 |
+
h.putStringList(c, func(v []string) {
|
| 111 |
+
h.cfg.APIKeys = append([]string(nil), v...)
|
| 112 |
+
h.cfg.Access.Providers = nil
|
| 113 |
+
}, nil)
|
| 114 |
+
}
|
| 115 |
+
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
| 116 |
+
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
| 117 |
+
}
|
| 118 |
+
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
| 119 |
+
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// gemini-api-key: []GeminiKey
|
| 123 |
+
func (h *Handler) GetGeminiKeys(c *gin.Context) {
|
| 124 |
+
c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey})
|
| 125 |
+
}
|
| 126 |
+
func (h *Handler) PutGeminiKeys(c *gin.Context) {
|
| 127 |
+
data, err := c.GetRawData()
|
| 128 |
+
if err != nil {
|
| 129 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 130 |
+
return
|
| 131 |
+
}
|
| 132 |
+
var arr []config.GeminiKey
|
| 133 |
+
if err = json.Unmarshal(data, &arr); err != nil {
|
| 134 |
+
var obj struct {
|
| 135 |
+
Items []config.GeminiKey `json:"items"`
|
| 136 |
+
}
|
| 137 |
+
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
| 138 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 139 |
+
return
|
| 140 |
+
}
|
| 141 |
+
arr = obj.Items
|
| 142 |
+
}
|
| 143 |
+
h.cfg.GeminiKey = append([]config.GeminiKey(nil), arr...)
|
| 144 |
+
h.cfg.SanitizeGeminiKeys()
|
| 145 |
+
h.persist(c)
|
| 146 |
+
}
|
| 147 |
+
func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
| 148 |
+
type geminiKeyPatch struct {
|
| 149 |
+
APIKey *string `json:"api-key"`
|
| 150 |
+
Prefix *string `json:"prefix"`
|
| 151 |
+
BaseURL *string `json:"base-url"`
|
| 152 |
+
ProxyURL *string `json:"proxy-url"`
|
| 153 |
+
Headers *map[string]string `json:"headers"`
|
| 154 |
+
ExcludedModels *[]string `json:"excluded-models"`
|
| 155 |
+
}
|
| 156 |
+
var body struct {
|
| 157 |
+
Index *int `json:"index"`
|
| 158 |
+
Match *string `json:"match"`
|
| 159 |
+
Value *geminiKeyPatch `json:"value"`
|
| 160 |
+
}
|
| 161 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 162 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 163 |
+
return
|
| 164 |
+
}
|
| 165 |
+
targetIndex := -1
|
| 166 |
+
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
| 167 |
+
targetIndex = *body.Index
|
| 168 |
+
}
|
| 169 |
+
if targetIndex == -1 && body.Match != nil {
|
| 170 |
+
match := strings.TrimSpace(*body.Match)
|
| 171 |
+
if match != "" {
|
| 172 |
+
for i := range h.cfg.GeminiKey {
|
| 173 |
+
if h.cfg.GeminiKey[i].APIKey == match {
|
| 174 |
+
targetIndex = i
|
| 175 |
+
break
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
if targetIndex == -1 {
|
| 181 |
+
c.JSON(404, gin.H{"error": "item not found"})
|
| 182 |
+
return
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
entry := h.cfg.GeminiKey[targetIndex]
|
| 186 |
+
if body.Value.APIKey != nil {
|
| 187 |
+
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
| 188 |
+
if trimmed == "" {
|
| 189 |
+
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
|
| 190 |
+
h.cfg.SanitizeGeminiKeys()
|
| 191 |
+
h.persist(c)
|
| 192 |
+
return
|
| 193 |
+
}
|
| 194 |
+
entry.APIKey = trimmed
|
| 195 |
+
}
|
| 196 |
+
if body.Value.Prefix != nil {
|
| 197 |
+
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
| 198 |
+
}
|
| 199 |
+
if body.Value.BaseURL != nil {
|
| 200 |
+
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
| 201 |
+
}
|
| 202 |
+
if body.Value.ProxyURL != nil {
|
| 203 |
+
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
| 204 |
+
}
|
| 205 |
+
if body.Value.Headers != nil {
|
| 206 |
+
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
| 207 |
+
}
|
| 208 |
+
if body.Value.ExcludedModels != nil {
|
| 209 |
+
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
| 210 |
+
}
|
| 211 |
+
h.cfg.GeminiKey[targetIndex] = entry
|
| 212 |
+
h.cfg.SanitizeGeminiKeys()
|
| 213 |
+
h.persist(c)
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
| 217 |
+
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
| 218 |
+
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
| 219 |
+
for _, v := range h.cfg.GeminiKey {
|
| 220 |
+
if v.APIKey != val {
|
| 221 |
+
out = append(out, v)
|
| 222 |
+
}
|
| 223 |
+
}
|
| 224 |
+
if len(out) != len(h.cfg.GeminiKey) {
|
| 225 |
+
h.cfg.GeminiKey = out
|
| 226 |
+
h.cfg.SanitizeGeminiKeys()
|
| 227 |
+
h.persist(c)
|
| 228 |
+
} else {
|
| 229 |
+
c.JSON(404, gin.H{"error": "item not found"})
|
| 230 |
+
}
|
| 231 |
+
return
|
| 232 |
+
}
|
| 233 |
+
if idxStr := c.Query("index"); idxStr != "" {
|
| 234 |
+
var idx int
|
| 235 |
+
if _, err := fmt.Sscanf(idxStr, "%d", &idx); err == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) {
|
| 236 |
+
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...)
|
| 237 |
+
h.cfg.SanitizeGeminiKeys()
|
| 238 |
+
h.persist(c)
|
| 239 |
+
return
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// claude-api-key: []ClaudeKey
|
| 246 |
+
func (h *Handler) GetClaudeKeys(c *gin.Context) {
|
| 247 |
+
c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey})
|
| 248 |
+
}
|
| 249 |
+
func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
| 250 |
+
data, err := c.GetRawData()
|
| 251 |
+
if err != nil {
|
| 252 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 253 |
+
return
|
| 254 |
+
}
|
| 255 |
+
var arr []config.ClaudeKey
|
| 256 |
+
if err = json.Unmarshal(data, &arr); err != nil {
|
| 257 |
+
var obj struct {
|
| 258 |
+
Items []config.ClaudeKey `json:"items"`
|
| 259 |
+
}
|
| 260 |
+
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
| 261 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 262 |
+
return
|
| 263 |
+
}
|
| 264 |
+
arr = obj.Items
|
| 265 |
+
}
|
| 266 |
+
for i := range arr {
|
| 267 |
+
normalizeClaudeKey(&arr[i])
|
| 268 |
+
}
|
| 269 |
+
h.cfg.ClaudeKey = arr
|
| 270 |
+
h.cfg.SanitizeClaudeKeys()
|
| 271 |
+
h.persist(c)
|
| 272 |
+
}
|
| 273 |
+
func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
| 274 |
+
type claudeKeyPatch struct {
|
| 275 |
+
APIKey *string `json:"api-key"`
|
| 276 |
+
Prefix *string `json:"prefix"`
|
| 277 |
+
BaseURL *string `json:"base-url"`
|
| 278 |
+
ProxyURL *string `json:"proxy-url"`
|
| 279 |
+
Models *[]config.ClaudeModel `json:"models"`
|
| 280 |
+
Headers *map[string]string `json:"headers"`
|
| 281 |
+
ExcludedModels *[]string `json:"excluded-models"`
|
| 282 |
+
}
|
| 283 |
+
var body struct {
|
| 284 |
+
Index *int `json:"index"`
|
| 285 |
+
Match *string `json:"match"`
|
| 286 |
+
Value *claudeKeyPatch `json:"value"`
|
| 287 |
+
}
|
| 288 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 289 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 290 |
+
return
|
| 291 |
+
}
|
| 292 |
+
targetIndex := -1
|
| 293 |
+
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
| 294 |
+
targetIndex = *body.Index
|
| 295 |
+
}
|
| 296 |
+
if targetIndex == -1 && body.Match != nil {
|
| 297 |
+
match := strings.TrimSpace(*body.Match)
|
| 298 |
+
for i := range h.cfg.ClaudeKey {
|
| 299 |
+
if h.cfg.ClaudeKey[i].APIKey == match {
|
| 300 |
+
targetIndex = i
|
| 301 |
+
break
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
if targetIndex == -1 {
|
| 306 |
+
c.JSON(404, gin.H{"error": "item not found"})
|
| 307 |
+
return
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
entry := h.cfg.ClaudeKey[targetIndex]
|
| 311 |
+
if body.Value.APIKey != nil {
|
| 312 |
+
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
| 313 |
+
}
|
| 314 |
+
if body.Value.Prefix != nil {
|
| 315 |
+
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
| 316 |
+
}
|
| 317 |
+
if body.Value.BaseURL != nil {
|
| 318 |
+
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
| 319 |
+
}
|
| 320 |
+
if body.Value.ProxyURL != nil {
|
| 321 |
+
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
| 322 |
+
}
|
| 323 |
+
if body.Value.Models != nil {
|
| 324 |
+
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
|
| 325 |
+
}
|
| 326 |
+
if body.Value.Headers != nil {
|
| 327 |
+
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
| 328 |
+
}
|
| 329 |
+
if body.Value.ExcludedModels != nil {
|
| 330 |
+
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
| 331 |
+
}
|
| 332 |
+
normalizeClaudeKey(&entry)
|
| 333 |
+
h.cfg.ClaudeKey[targetIndex] = entry
|
| 334 |
+
h.cfg.SanitizeClaudeKeys()
|
| 335 |
+
h.persist(c)
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
| 339 |
+
if val := c.Query("api-key"); val != "" {
|
| 340 |
+
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
| 341 |
+
for _, v := range h.cfg.ClaudeKey {
|
| 342 |
+
if v.APIKey != val {
|
| 343 |
+
out = append(out, v)
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
h.cfg.ClaudeKey = out
|
| 347 |
+
h.cfg.SanitizeClaudeKeys()
|
| 348 |
+
h.persist(c)
|
| 349 |
+
return
|
| 350 |
+
}
|
| 351 |
+
if idxStr := c.Query("index"); idxStr != "" {
|
| 352 |
+
var idx int
|
| 353 |
+
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
| 354 |
+
if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) {
|
| 355 |
+
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...)
|
| 356 |
+
h.cfg.SanitizeClaudeKeys()
|
| 357 |
+
h.persist(c)
|
| 358 |
+
return
|
| 359 |
+
}
|
| 360 |
+
}
|
| 361 |
+
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
// openai-compatibility: []OpenAICompatibility
|
| 365 |
+
func (h *Handler) GetOpenAICompat(c *gin.Context) {
|
| 366 |
+
c.JSON(200, gin.H{"openai-compatibility": normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)})
|
| 367 |
+
}
|
| 368 |
+
func (h *Handler) PutOpenAICompat(c *gin.Context) {
|
| 369 |
+
data, err := c.GetRawData()
|
| 370 |
+
if err != nil {
|
| 371 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 372 |
+
return
|
| 373 |
+
}
|
| 374 |
+
var arr []config.OpenAICompatibility
|
| 375 |
+
if err = json.Unmarshal(data, &arr); err != nil {
|
| 376 |
+
var obj struct {
|
| 377 |
+
Items []config.OpenAICompatibility `json:"items"`
|
| 378 |
+
}
|
| 379 |
+
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
| 380 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 381 |
+
return
|
| 382 |
+
}
|
| 383 |
+
arr = obj.Items
|
| 384 |
+
}
|
| 385 |
+
filtered := make([]config.OpenAICompatibility, 0, len(arr))
|
| 386 |
+
for i := range arr {
|
| 387 |
+
normalizeOpenAICompatibilityEntry(&arr[i])
|
| 388 |
+
if strings.TrimSpace(arr[i].BaseURL) != "" {
|
| 389 |
+
filtered = append(filtered, arr[i])
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
h.cfg.OpenAICompatibility = filtered
|
| 393 |
+
h.cfg.SanitizeOpenAICompatibility()
|
| 394 |
+
h.persist(c)
|
| 395 |
+
}
|
| 396 |
+
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
|
| 397 |
+
type openAICompatPatch struct {
|
| 398 |
+
Name *string `json:"name"`
|
| 399 |
+
Prefix *string `json:"prefix"`
|
| 400 |
+
BaseURL *string `json:"base-url"`
|
| 401 |
+
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
|
| 402 |
+
Models *[]config.OpenAICompatibilityModel `json:"models"`
|
| 403 |
+
Headers *map[string]string `json:"headers"`
|
| 404 |
+
}
|
| 405 |
+
var body struct {
|
| 406 |
+
Name *string `json:"name"`
|
| 407 |
+
Index *int `json:"index"`
|
| 408 |
+
Value *openAICompatPatch `json:"value"`
|
| 409 |
+
}
|
| 410 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 411 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 412 |
+
return
|
| 413 |
+
}
|
| 414 |
+
targetIndex := -1
|
| 415 |
+
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
| 416 |
+
targetIndex = *body.Index
|
| 417 |
+
}
|
| 418 |
+
if targetIndex == -1 && body.Name != nil {
|
| 419 |
+
match := strings.TrimSpace(*body.Name)
|
| 420 |
+
for i := range h.cfg.OpenAICompatibility {
|
| 421 |
+
if h.cfg.OpenAICompatibility[i].Name == match {
|
| 422 |
+
targetIndex = i
|
| 423 |
+
break
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
if targetIndex == -1 {
|
| 428 |
+
c.JSON(404, gin.H{"error": "item not found"})
|
| 429 |
+
return
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
entry := h.cfg.OpenAICompatibility[targetIndex]
|
| 433 |
+
if body.Value.Name != nil {
|
| 434 |
+
entry.Name = strings.TrimSpace(*body.Value.Name)
|
| 435 |
+
}
|
| 436 |
+
if body.Value.Prefix != nil {
|
| 437 |
+
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
| 438 |
+
}
|
| 439 |
+
if body.Value.BaseURL != nil {
|
| 440 |
+
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
| 441 |
+
if trimmed == "" {
|
| 442 |
+
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
|
| 443 |
+
h.cfg.SanitizeOpenAICompatibility()
|
| 444 |
+
h.persist(c)
|
| 445 |
+
return
|
| 446 |
+
}
|
| 447 |
+
entry.BaseURL = trimmed
|
| 448 |
+
}
|
| 449 |
+
if body.Value.APIKeyEntries != nil {
|
| 450 |
+
entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
|
| 451 |
+
}
|
| 452 |
+
if body.Value.Models != nil {
|
| 453 |
+
entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
|
| 454 |
+
}
|
| 455 |
+
if body.Value.Headers != nil {
|
| 456 |
+
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
| 457 |
+
}
|
| 458 |
+
normalizeOpenAICompatibilityEntry(&entry)
|
| 459 |
+
h.cfg.OpenAICompatibility[targetIndex] = entry
|
| 460 |
+
h.cfg.SanitizeOpenAICompatibility()
|
| 461 |
+
h.persist(c)
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
| 465 |
+
if name := c.Query("name"); name != "" {
|
| 466 |
+
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
| 467 |
+
for _, v := range h.cfg.OpenAICompatibility {
|
| 468 |
+
if v.Name != name {
|
| 469 |
+
out = append(out, v)
|
| 470 |
+
}
|
| 471 |
+
}
|
| 472 |
+
h.cfg.OpenAICompatibility = out
|
| 473 |
+
h.cfg.SanitizeOpenAICompatibility()
|
| 474 |
+
h.persist(c)
|
| 475 |
+
return
|
| 476 |
+
}
|
| 477 |
+
if idxStr := c.Query("index"); idxStr != "" {
|
| 478 |
+
var idx int
|
| 479 |
+
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
| 480 |
+
if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) {
|
| 481 |
+
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...)
|
| 482 |
+
h.cfg.SanitizeOpenAICompatibility()
|
| 483 |
+
h.persist(c)
|
| 484 |
+
return
|
| 485 |
+
}
|
| 486 |
+
}
|
| 487 |
+
c.JSON(400, gin.H{"error": "missing name or index"})
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
// oauth-excluded-models: map[string][]string
|
| 491 |
+
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
| 492 |
+
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
func (h *Handler) PutOAuthExcludedModels(c *gin.Context) {
|
| 496 |
+
data, err := c.GetRawData()
|
| 497 |
+
if err != nil {
|
| 498 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 499 |
+
return
|
| 500 |
+
}
|
| 501 |
+
var entries map[string][]string
|
| 502 |
+
if err = json.Unmarshal(data, &entries); err != nil {
|
| 503 |
+
var wrapper struct {
|
| 504 |
+
Items map[string][]string `json:"items"`
|
| 505 |
+
}
|
| 506 |
+
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
|
| 507 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 508 |
+
return
|
| 509 |
+
}
|
| 510 |
+
entries = wrapper.Items
|
| 511 |
+
}
|
| 512 |
+
h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries)
|
| 513 |
+
h.persist(c)
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) {
|
| 517 |
+
var body struct {
|
| 518 |
+
Provider *string `json:"provider"`
|
| 519 |
+
Models []string `json:"models"`
|
| 520 |
+
}
|
| 521 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil {
|
| 522 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 523 |
+
return
|
| 524 |
+
}
|
| 525 |
+
provider := strings.ToLower(strings.TrimSpace(*body.Provider))
|
| 526 |
+
if provider == "" {
|
| 527 |
+
c.JSON(400, gin.H{"error": "invalid provider"})
|
| 528 |
+
return
|
| 529 |
+
}
|
| 530 |
+
normalized := config.NormalizeExcludedModels(body.Models)
|
| 531 |
+
if len(normalized) == 0 {
|
| 532 |
+
if h.cfg.OAuthExcludedModels == nil {
|
| 533 |
+
c.JSON(404, gin.H{"error": "provider not found"})
|
| 534 |
+
return
|
| 535 |
+
}
|
| 536 |
+
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
|
| 537 |
+
c.JSON(404, gin.H{"error": "provider not found"})
|
| 538 |
+
return
|
| 539 |
+
}
|
| 540 |
+
delete(h.cfg.OAuthExcludedModels, provider)
|
| 541 |
+
if len(h.cfg.OAuthExcludedModels) == 0 {
|
| 542 |
+
h.cfg.OAuthExcludedModels = nil
|
| 543 |
+
}
|
| 544 |
+
h.persist(c)
|
| 545 |
+
return
|
| 546 |
+
}
|
| 547 |
+
if h.cfg.OAuthExcludedModels == nil {
|
| 548 |
+
h.cfg.OAuthExcludedModels = make(map[string][]string)
|
| 549 |
+
}
|
| 550 |
+
h.cfg.OAuthExcludedModels[provider] = normalized
|
| 551 |
+
h.persist(c)
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
|
| 555 |
+
provider := strings.ToLower(strings.TrimSpace(c.Query("provider")))
|
| 556 |
+
if provider == "" {
|
| 557 |
+
c.JSON(400, gin.H{"error": "missing provider"})
|
| 558 |
+
return
|
| 559 |
+
}
|
| 560 |
+
if h.cfg.OAuthExcludedModels == nil {
|
| 561 |
+
c.JSON(404, gin.H{"error": "provider not found"})
|
| 562 |
+
return
|
| 563 |
+
}
|
| 564 |
+
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
|
| 565 |
+
c.JSON(404, gin.H{"error": "provider not found"})
|
| 566 |
+
return
|
| 567 |
+
}
|
| 568 |
+
delete(h.cfg.OAuthExcludedModels, provider)
|
| 569 |
+
if len(h.cfg.OAuthExcludedModels) == 0 {
|
| 570 |
+
h.cfg.OAuthExcludedModels = nil
|
| 571 |
+
}
|
| 572 |
+
h.persist(c)
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
// codex-api-key: []CodexKey
|
| 576 |
+
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
| 577 |
+
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
| 578 |
+
}
|
| 579 |
+
func (h *Handler) PutCodexKeys(c *gin.Context) {
|
| 580 |
+
data, err := c.GetRawData()
|
| 581 |
+
if err != nil {
|
| 582 |
+
c.JSON(400, gin.H{"error": "failed to read body"})
|
| 583 |
+
return
|
| 584 |
+
}
|
| 585 |
+
var arr []config.CodexKey
|
| 586 |
+
if err = json.Unmarshal(data, &arr); err != nil {
|
| 587 |
+
var obj struct {
|
| 588 |
+
Items []config.CodexKey `json:"items"`
|
| 589 |
+
}
|
| 590 |
+
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
|
| 591 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 592 |
+
return
|
| 593 |
+
}
|
| 594 |
+
arr = obj.Items
|
| 595 |
+
}
|
| 596 |
+
// Filter out codex entries with empty base-url (treat as removed)
|
| 597 |
+
filtered := make([]config.CodexKey, 0, len(arr))
|
| 598 |
+
for i := range arr {
|
| 599 |
+
entry := arr[i]
|
| 600 |
+
normalizeCodexKey(&entry)
|
| 601 |
+
if entry.BaseURL == "" {
|
| 602 |
+
continue
|
| 603 |
+
}
|
| 604 |
+
filtered = append(filtered, entry)
|
| 605 |
+
}
|
| 606 |
+
h.cfg.CodexKey = filtered
|
| 607 |
+
h.cfg.SanitizeCodexKeys()
|
| 608 |
+
h.persist(c)
|
| 609 |
+
}
|
| 610 |
+
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
| 611 |
+
type codexKeyPatch struct {
|
| 612 |
+
APIKey *string `json:"api-key"`
|
| 613 |
+
Prefix *string `json:"prefix"`
|
| 614 |
+
BaseURL *string `json:"base-url"`
|
| 615 |
+
ProxyURL *string `json:"proxy-url"`
|
| 616 |
+
Models *[]config.CodexModel `json:"models"`
|
| 617 |
+
Headers *map[string]string `json:"headers"`
|
| 618 |
+
ExcludedModels *[]string `json:"excluded-models"`
|
| 619 |
+
}
|
| 620 |
+
var body struct {
|
| 621 |
+
Index *int `json:"index"`
|
| 622 |
+
Match *string `json:"match"`
|
| 623 |
+
Value *codexKeyPatch `json:"value"`
|
| 624 |
+
}
|
| 625 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 626 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 627 |
+
return
|
| 628 |
+
}
|
| 629 |
+
targetIndex := -1
|
| 630 |
+
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
| 631 |
+
targetIndex = *body.Index
|
| 632 |
+
}
|
| 633 |
+
if targetIndex == -1 && body.Match != nil {
|
| 634 |
+
match := strings.TrimSpace(*body.Match)
|
| 635 |
+
for i := range h.cfg.CodexKey {
|
| 636 |
+
if h.cfg.CodexKey[i].APIKey == match {
|
| 637 |
+
targetIndex = i
|
| 638 |
+
break
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
}
|
| 642 |
+
if targetIndex == -1 {
|
| 643 |
+
c.JSON(404, gin.H{"error": "item not found"})
|
| 644 |
+
return
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
entry := h.cfg.CodexKey[targetIndex]
|
| 648 |
+
if body.Value.APIKey != nil {
|
| 649 |
+
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
| 650 |
+
}
|
| 651 |
+
if body.Value.Prefix != nil {
|
| 652 |
+
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
| 653 |
+
}
|
| 654 |
+
if body.Value.BaseURL != nil {
|
| 655 |
+
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
| 656 |
+
if trimmed == "" {
|
| 657 |
+
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
|
| 658 |
+
h.cfg.SanitizeCodexKeys()
|
| 659 |
+
h.persist(c)
|
| 660 |
+
return
|
| 661 |
+
}
|
| 662 |
+
entry.BaseURL = trimmed
|
| 663 |
+
}
|
| 664 |
+
if body.Value.ProxyURL != nil {
|
| 665 |
+
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
| 666 |
+
}
|
| 667 |
+
if body.Value.Models != nil {
|
| 668 |
+
entry.Models = append([]config.CodexModel(nil), (*body.Value.Models)...)
|
| 669 |
+
}
|
| 670 |
+
if body.Value.Headers != nil {
|
| 671 |
+
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
| 672 |
+
}
|
| 673 |
+
if body.Value.ExcludedModels != nil {
|
| 674 |
+
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
| 675 |
+
}
|
| 676 |
+
normalizeCodexKey(&entry)
|
| 677 |
+
h.cfg.CodexKey[targetIndex] = entry
|
| 678 |
+
h.cfg.SanitizeCodexKeys()
|
| 679 |
+
h.persist(c)
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
| 683 |
+
if val := c.Query("api-key"); val != "" {
|
| 684 |
+
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
| 685 |
+
for _, v := range h.cfg.CodexKey {
|
| 686 |
+
if v.APIKey != val {
|
| 687 |
+
out = append(out, v)
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
h.cfg.CodexKey = out
|
| 691 |
+
h.cfg.SanitizeCodexKeys()
|
| 692 |
+
h.persist(c)
|
| 693 |
+
return
|
| 694 |
+
}
|
| 695 |
+
if idxStr := c.Query("index"); idxStr != "" {
|
| 696 |
+
var idx int
|
| 697 |
+
_, err := fmt.Sscanf(idxStr, "%d", &idx)
|
| 698 |
+
if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) {
|
| 699 |
+
h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...)
|
| 700 |
+
h.cfg.SanitizeCodexKeys()
|
| 701 |
+
h.persist(c)
|
| 702 |
+
return
|
| 703 |
+
}
|
| 704 |
+
}
|
| 705 |
+
c.JSON(400, gin.H{"error": "missing api-key or index"})
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
func normalizeOpenAICompatibilityEntry(entry *config.OpenAICompatibility) {
|
| 709 |
+
if entry == nil {
|
| 710 |
+
return
|
| 711 |
+
}
|
| 712 |
+
// Trim base-url; empty base-url indicates provider should be removed by sanitization
|
| 713 |
+
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
| 714 |
+
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
| 715 |
+
existing := make(map[string]struct{}, len(entry.APIKeyEntries))
|
| 716 |
+
for i := range entry.APIKeyEntries {
|
| 717 |
+
trimmed := strings.TrimSpace(entry.APIKeyEntries[i].APIKey)
|
| 718 |
+
entry.APIKeyEntries[i].APIKey = trimmed
|
| 719 |
+
if trimmed != "" {
|
| 720 |
+
existing[trimmed] = struct{}{}
|
| 721 |
+
}
|
| 722 |
+
}
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
func normalizedOpenAICompatibilityEntries(entries []config.OpenAICompatibility) []config.OpenAICompatibility {
|
| 726 |
+
if len(entries) == 0 {
|
| 727 |
+
return nil
|
| 728 |
+
}
|
| 729 |
+
out := make([]config.OpenAICompatibility, len(entries))
|
| 730 |
+
for i := range entries {
|
| 731 |
+
copyEntry := entries[i]
|
| 732 |
+
if len(copyEntry.APIKeyEntries) > 0 {
|
| 733 |
+
copyEntry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), copyEntry.APIKeyEntries...)
|
| 734 |
+
}
|
| 735 |
+
normalizeOpenAICompatibilityEntry(©Entry)
|
| 736 |
+
out[i] = copyEntry
|
| 737 |
+
}
|
| 738 |
+
return out
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
func normalizeClaudeKey(entry *config.ClaudeKey) {
|
| 742 |
+
if entry == nil {
|
| 743 |
+
return
|
| 744 |
+
}
|
| 745 |
+
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
| 746 |
+
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
| 747 |
+
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
| 748 |
+
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
| 749 |
+
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
| 750 |
+
if len(entry.Models) == 0 {
|
| 751 |
+
return
|
| 752 |
+
}
|
| 753 |
+
normalized := make([]config.ClaudeModel, 0, len(entry.Models))
|
| 754 |
+
for i := range entry.Models {
|
| 755 |
+
model := entry.Models[i]
|
| 756 |
+
model.Name = strings.TrimSpace(model.Name)
|
| 757 |
+
model.Alias = strings.TrimSpace(model.Alias)
|
| 758 |
+
if model.Name == "" && model.Alias == "" {
|
| 759 |
+
continue
|
| 760 |
+
}
|
| 761 |
+
normalized = append(normalized, model)
|
| 762 |
+
}
|
| 763 |
+
entry.Models = normalized
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
func normalizeCodexKey(entry *config.CodexKey) {
|
| 767 |
+
if entry == nil {
|
| 768 |
+
return
|
| 769 |
+
}
|
| 770 |
+
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
| 771 |
+
entry.Prefix = strings.TrimSpace(entry.Prefix)
|
| 772 |
+
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
| 773 |
+
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
| 774 |
+
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
| 775 |
+
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
| 776 |
+
if len(entry.Models) == 0 {
|
| 777 |
+
return
|
| 778 |
+
}
|
| 779 |
+
normalized := make([]config.CodexModel, 0, len(entry.Models))
|
| 780 |
+
for i := range entry.Models {
|
| 781 |
+
model := entry.Models[i]
|
| 782 |
+
model.Name = strings.TrimSpace(model.Name)
|
| 783 |
+
model.Alias = strings.TrimSpace(model.Alias)
|
| 784 |
+
if model.Name == "" && model.Alias == "" {
|
| 785 |
+
continue
|
| 786 |
+
}
|
| 787 |
+
normalized = append(normalized, model)
|
| 788 |
+
}
|
| 789 |
+
entry.Models = normalized
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
// GetAmpCode returns the complete ampcode configuration.
|
| 793 |
+
func (h *Handler) GetAmpCode(c *gin.Context) {
|
| 794 |
+
if h == nil || h.cfg == nil {
|
| 795 |
+
c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
|
| 796 |
+
return
|
| 797 |
+
}
|
| 798 |
+
c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
// GetAmpUpstreamURL returns the ampcode upstream URL.
|
| 802 |
+
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
|
| 803 |
+
if h == nil || h.cfg == nil {
|
| 804 |
+
c.JSON(200, gin.H{"upstream-url": ""})
|
| 805 |
+
return
|
| 806 |
+
}
|
| 807 |
+
c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
// PutAmpUpstreamURL updates the ampcode upstream URL.
|
| 811 |
+
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
|
| 812 |
+
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
// DeleteAmpUpstreamURL clears the ampcode upstream URL.
|
| 816 |
+
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
|
| 817 |
+
h.cfg.AmpCode.UpstreamURL = ""
|
| 818 |
+
h.persist(c)
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
// GetAmpUpstreamAPIKey returns the ampcode upstream API key.
|
| 822 |
+
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
|
| 823 |
+
if h == nil || h.cfg == nil {
|
| 824 |
+
c.JSON(200, gin.H{"upstream-api-key": ""})
|
| 825 |
+
return
|
| 826 |
+
}
|
| 827 |
+
c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
// PutAmpUpstreamAPIKey updates the ampcode upstream API key.
|
| 831 |
+
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
|
| 832 |
+
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
|
| 836 |
+
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
|
| 837 |
+
h.cfg.AmpCode.UpstreamAPIKey = ""
|
| 838 |
+
h.persist(c)
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
|
| 842 |
+
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
|
| 843 |
+
if h == nil || h.cfg == nil {
|
| 844 |
+
c.JSON(200, gin.H{"restrict-management-to-localhost": true})
|
| 845 |
+
return
|
| 846 |
+
}
|
| 847 |
+
c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
|
| 851 |
+
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
|
| 852 |
+
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
// GetAmpModelMappings returns the ampcode model mappings.
|
| 856 |
+
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
|
| 857 |
+
if h == nil || h.cfg == nil {
|
| 858 |
+
c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
|
| 859 |
+
return
|
| 860 |
+
}
|
| 861 |
+
c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
|
| 862 |
+
}
|
| 863 |
+
|
| 864 |
+
// PutAmpModelMappings replaces all ampcode model mappings.
|
| 865 |
+
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
|
| 866 |
+
var body struct {
|
| 867 |
+
Value []config.AmpModelMapping `json:"value"`
|
| 868 |
+
}
|
| 869 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 870 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 871 |
+
return
|
| 872 |
+
}
|
| 873 |
+
h.cfg.AmpCode.ModelMappings = body.Value
|
| 874 |
+
h.persist(c)
|
| 875 |
+
}
|
| 876 |
+
|
| 877 |
+
// PatchAmpModelMappings adds or updates model mappings.
|
| 878 |
+
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
|
| 879 |
+
var body struct {
|
| 880 |
+
Value []config.AmpModelMapping `json:"value"`
|
| 881 |
+
}
|
| 882 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 883 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 884 |
+
return
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
existing := make(map[string]int)
|
| 888 |
+
for i, m := range h.cfg.AmpCode.ModelMappings {
|
| 889 |
+
existing[strings.TrimSpace(m.From)] = i
|
| 890 |
+
}
|
| 891 |
+
|
| 892 |
+
for _, newMapping := range body.Value {
|
| 893 |
+
from := strings.TrimSpace(newMapping.From)
|
| 894 |
+
if idx, ok := existing[from]; ok {
|
| 895 |
+
h.cfg.AmpCode.ModelMappings[idx] = newMapping
|
| 896 |
+
} else {
|
| 897 |
+
h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
|
| 898 |
+
existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
|
| 899 |
+
}
|
| 900 |
+
}
|
| 901 |
+
h.persist(c)
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
// DeleteAmpModelMappings removes specified model mappings by "from" field.
|
| 905 |
+
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
|
| 906 |
+
var body struct {
|
| 907 |
+
Value []string `json:"value"`
|
| 908 |
+
}
|
| 909 |
+
if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
|
| 910 |
+
h.cfg.AmpCode.ModelMappings = nil
|
| 911 |
+
h.persist(c)
|
| 912 |
+
return
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
toRemove := make(map[string]bool)
|
| 916 |
+
for _, from := range body.Value {
|
| 917 |
+
toRemove[strings.TrimSpace(from)] = true
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
|
| 921 |
+
for _, m := range h.cfg.AmpCode.ModelMappings {
|
| 922 |
+
if !toRemove[strings.TrimSpace(m.From)] {
|
| 923 |
+
newMappings = append(newMappings, m)
|
| 924 |
+
}
|
| 925 |
+
}
|
| 926 |
+
h.cfg.AmpCode.ModelMappings = newMappings
|
| 927 |
+
h.persist(c)
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
// GetAmpForceModelMappings returns whether model mappings are forced.
|
| 931 |
+
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
| 932 |
+
if h == nil || h.cfg == nil {
|
| 933 |
+
c.JSON(200, gin.H{"force-model-mappings": false})
|
| 934 |
+
return
|
| 935 |
+
}
|
| 936 |
+
c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
// PutAmpForceModelMappings updates the force model mappings setting.
|
| 940 |
+
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
| 941 |
+
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
| 942 |
+
}
|
| 943 |
+
|
| 944 |
+
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
| 945 |
+
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
| 946 |
+
if h == nil || h.cfg == nil {
|
| 947 |
+
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
| 948 |
+
return
|
| 949 |
+
}
|
| 950 |
+
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
| 954 |
+
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
| 955 |
+
var body struct {
|
| 956 |
+
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
| 957 |
+
}
|
| 958 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 959 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 960 |
+
return
|
| 961 |
+
}
|
| 962 |
+
// Normalize entries: trim whitespace, filter empty
|
| 963 |
+
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
| 964 |
+
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
| 965 |
+
h.persist(c)
|
| 966 |
+
}
|
| 967 |
+
|
| 968 |
+
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
| 969 |
+
// Matching is done by upstream-api-key value.
|
| 970 |
+
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
| 971 |
+
var body struct {
|
| 972 |
+
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
| 973 |
+
}
|
| 974 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 975 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 976 |
+
return
|
| 977 |
+
}
|
| 978 |
+
|
| 979 |
+
existing := make(map[string]int)
|
| 980 |
+
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
| 981 |
+
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
| 982 |
+
}
|
| 983 |
+
|
| 984 |
+
for _, newEntry := range body.Value {
|
| 985 |
+
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
| 986 |
+
if upstreamKey == "" {
|
| 987 |
+
continue
|
| 988 |
+
}
|
| 989 |
+
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
| 990 |
+
UpstreamAPIKey: upstreamKey,
|
| 991 |
+
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
| 992 |
+
}
|
| 993 |
+
if idx, ok := existing[upstreamKey]; ok {
|
| 994 |
+
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
| 995 |
+
} else {
|
| 996 |
+
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
| 997 |
+
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
| 998 |
+
}
|
| 999 |
+
}
|
| 1000 |
+
h.persist(c)
|
| 1001 |
+
}
|
| 1002 |
+
|
| 1003 |
+
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
| 1004 |
+
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
| 1005 |
+
// If "value" is an empty array, clears all entries.
|
| 1006 |
+
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
| 1007 |
+
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
| 1008 |
+
var body struct {
|
| 1009 |
+
Value []string `json:"value"`
|
| 1010 |
+
}
|
| 1011 |
+
if err := c.ShouldBindJSON(&body); err != nil {
|
| 1012 |
+
c.JSON(400, gin.H{"error": "invalid body"})
|
| 1013 |
+
return
|
| 1014 |
+
}
|
| 1015 |
+
|
| 1016 |
+
if body.Value == nil {
|
| 1017 |
+
c.JSON(400, gin.H{"error": "missing value"})
|
| 1018 |
+
return
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
// Empty array means clear all
|
| 1022 |
+
if len(body.Value) == 0 {
|
| 1023 |
+
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
| 1024 |
+
h.persist(c)
|
| 1025 |
+
return
|
| 1026 |
+
}
|
| 1027 |
+
|
| 1028 |
+
toRemove := make(map[string]bool)
|
| 1029 |
+
for _, key := range body.Value {
|
| 1030 |
+
trimmed := strings.TrimSpace(key)
|
| 1031 |
+
if trimmed == "" {
|
| 1032 |
+
continue
|
| 1033 |
+
}
|
| 1034 |
+
toRemove[trimmed] = true
|
| 1035 |
+
}
|
| 1036 |
+
if len(toRemove) == 0 {
|
| 1037 |
+
c.JSON(400, gin.H{"error": "empty value"})
|
| 1038 |
+
return
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
| 1042 |
+
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
| 1043 |
+
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
| 1044 |
+
newEntries = append(newEntries, entry)
|
| 1045 |
+
}
|
| 1046 |
+
}
|
| 1047 |
+
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
| 1048 |
+
h.persist(c)
|
| 1049 |
+
}
|
| 1050 |
+
|
| 1051 |
+
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
| 1052 |
+
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
| 1053 |
+
if len(entries) == 0 {
|
| 1054 |
+
return nil
|
| 1055 |
+
}
|
| 1056 |
+
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
| 1057 |
+
for _, entry := range entries {
|
| 1058 |
+
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
| 1059 |
+
if upstreamKey == "" {
|
| 1060 |
+
continue
|
| 1061 |
+
}
|
| 1062 |
+
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
| 1063 |
+
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
| 1064 |
+
UpstreamAPIKey: upstreamKey,
|
| 1065 |
+
APIKeys: apiKeys,
|
| 1066 |
+
})
|
| 1067 |
+
}
|
| 1068 |
+
if len(out) == 0 {
|
| 1069 |
+
return nil
|
| 1070 |
+
}
|
| 1071 |
+
return out
|
| 1072 |
+
}
|
| 1073 |
+
|
| 1074 |
+
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
| 1075 |
+
func normalizeAPIKeysList(keys []string) []string {
|
| 1076 |
+
if len(keys) == 0 {
|
| 1077 |
+
return nil
|
| 1078 |
+
}
|
| 1079 |
+
out := make([]string, 0, len(keys))
|
| 1080 |
+
for _, k := range keys {
|
| 1081 |
+
trimmed := strings.TrimSpace(k)
|
| 1082 |
+
if trimmed != "" {
|
| 1083 |
+
out = append(out, trimmed)
|
| 1084 |
+
}
|
| 1085 |
+
}
|
| 1086 |
+
if len(out) == 0 {
|
| 1087 |
+
return nil
|
| 1088 |
+
}
|
| 1089 |
+
return out
|
| 1090 |
+
}
|
internal/api/handlers/management/handler.go
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package management provides the management API handlers and middleware
|
| 2 |
+
// for configuring the server and managing auth files.
|
| 3 |
+
package management
|
| 4 |
+
|
| 5 |
+
import (
|
| 6 |
+
"crypto/subtle"
|
| 7 |
+
"fmt"
|
| 8 |
+
"net/http"
|
| 9 |
+
"os"
|
| 10 |
+
"path/filepath"
|
| 11 |
+
"strings"
|
| 12 |
+
"sync"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"github.com/gin-gonic/gin"
|
| 16 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
| 17 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 18 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
| 19 |
+
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
| 20 |
+
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 21 |
+
"golang.org/x/crypto/bcrypt"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
type attemptInfo struct {
|
| 25 |
+
count int
|
| 26 |
+
blockedUntil time.Time
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// Handler aggregates config reference, persistence path and helpers.
|
| 30 |
+
type Handler struct {
|
| 31 |
+
cfg *config.Config
|
| 32 |
+
configFilePath string
|
| 33 |
+
mu sync.Mutex
|
| 34 |
+
attemptsMu sync.Mutex
|
| 35 |
+
failedAttempts map[string]*attemptInfo // keyed by client IP
|
| 36 |
+
authManager *coreauth.Manager
|
| 37 |
+
usageStats *usage.RequestStatistics
|
| 38 |
+
tokenStore coreauth.Store
|
| 39 |
+
localPassword string
|
| 40 |
+
allowRemoteOverride bool
|
| 41 |
+
envSecret string
|
| 42 |
+
logDir string
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// NewHandler creates a new management handler instance.
|
| 46 |
+
func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Manager) *Handler {
|
| 47 |
+
envSecret, _ := os.LookupEnv("MANAGEMENT_PASSWORD")
|
| 48 |
+
envSecret = strings.TrimSpace(envSecret)
|
| 49 |
+
|
| 50 |
+
return &Handler{
|
| 51 |
+
cfg: cfg,
|
| 52 |
+
configFilePath: configFilePath,
|
| 53 |
+
failedAttempts: make(map[string]*attemptInfo),
|
| 54 |
+
authManager: manager,
|
| 55 |
+
usageStats: usage.GetRequestStatistics(),
|
| 56 |
+
tokenStore: sdkAuth.GetTokenStore(),
|
| 57 |
+
allowRemoteOverride: envSecret != "",
|
| 58 |
+
envSecret: envSecret,
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
// NewHandler creates a new management handler instance.
|
| 63 |
+
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
|
| 64 |
+
return NewHandler(cfg, "", manager)
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// SetConfig updates the in-memory config reference when the server hot-reloads.
|
| 68 |
+
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
|
| 69 |
+
|
| 70 |
+
// SetAuthManager updates the auth manager reference used by management endpoints.
|
| 71 |
+
func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager }
|
| 72 |
+
|
| 73 |
+
// SetUsageStatistics allows replacing the usage statistics reference.
|
| 74 |
+
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
|
| 75 |
+
|
| 76 |
+
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
|
| 77 |
+
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
|
| 78 |
+
|
| 79 |
+
// SetLogDirectory updates the directory where main.log should be looked up.
|
| 80 |
+
func (h *Handler) SetLogDirectory(dir string) {
|
| 81 |
+
if dir == "" {
|
| 82 |
+
return
|
| 83 |
+
}
|
| 84 |
+
if !filepath.IsAbs(dir) {
|
| 85 |
+
if abs, err := filepath.Abs(dir); err == nil {
|
| 86 |
+
dir = abs
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
h.logDir = dir
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// Middleware enforces access control for management endpoints.
|
| 93 |
+
// All requests (local and remote) require a valid management key.
|
| 94 |
+
// Additionally, remote access requires allow-remote-management=true.
|
| 95 |
+
func (h *Handler) Middleware() gin.HandlerFunc {
|
| 96 |
+
const maxFailures = 5
|
| 97 |
+
const banDuration = 30 * time.Minute
|
| 98 |
+
|
| 99 |
+
return func(c *gin.Context) {
|
| 100 |
+
c.Header("X-CPA-VERSION", buildinfo.Version)
|
| 101 |
+
c.Header("X-CPA-COMMIT", buildinfo.Commit)
|
| 102 |
+
c.Header("X-CPA-BUILD-DATE", buildinfo.BuildDate)
|
| 103 |
+
|
| 104 |
+
clientIP := c.ClientIP()
|
| 105 |
+
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
|
| 106 |
+
cfg := h.cfg
|
| 107 |
+
var (
|
| 108 |
+
allowRemote bool
|
| 109 |
+
secretHash string
|
| 110 |
+
)
|
| 111 |
+
if cfg != nil {
|
| 112 |
+
allowRemote = cfg.RemoteManagement.AllowRemote
|
| 113 |
+
secretHash = cfg.RemoteManagement.SecretKey
|
| 114 |
+
}
|
| 115 |
+
if h.allowRemoteOverride {
|
| 116 |
+
allowRemote = true
|
| 117 |
+
}
|
| 118 |
+
envSecret := h.envSecret
|
| 119 |
+
|
| 120 |
+
fail := func() {}
|
| 121 |
+
if !localClient {
|
| 122 |
+
h.attemptsMu.Lock()
|
| 123 |
+
ai := h.failedAttempts[clientIP]
|
| 124 |
+
if ai != nil {
|
| 125 |
+
if !ai.blockedUntil.IsZero() {
|
| 126 |
+
if time.Now().Before(ai.blockedUntil) {
|
| 127 |
+
remaining := time.Until(ai.blockedUntil).Round(time.Second)
|
| 128 |
+
h.attemptsMu.Unlock()
|
| 129 |
+
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
|
| 130 |
+
return
|
| 131 |
+
}
|
| 132 |
+
// Ban expired, reset state
|
| 133 |
+
ai.blockedUntil = time.Time{}
|
| 134 |
+
ai.count = 0
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
h.attemptsMu.Unlock()
|
| 138 |
+
|
| 139 |
+
if !allowRemote {
|
| 140 |
+
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
|
| 141 |
+
return
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
fail = func() {
|
| 145 |
+
h.attemptsMu.Lock()
|
| 146 |
+
aip := h.failedAttempts[clientIP]
|
| 147 |
+
if aip == nil {
|
| 148 |
+
aip = &attemptInfo{}
|
| 149 |
+
h.failedAttempts[clientIP] = aip
|
| 150 |
+
}
|
| 151 |
+
aip.count++
|
| 152 |
+
if aip.count >= maxFailures {
|
| 153 |
+
aip.blockedUntil = time.Now().Add(banDuration)
|
| 154 |
+
aip.count = 0
|
| 155 |
+
}
|
| 156 |
+
h.attemptsMu.Unlock()
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
if secretHash == "" && envSecret == "" {
|
| 160 |
+
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
|
| 161 |
+
return
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// Accept either Authorization: Bearer <key> or X-Management-Key
|
| 165 |
+
var provided string
|
| 166 |
+
if ah := c.GetHeader("Authorization"); ah != "" {
|
| 167 |
+
parts := strings.SplitN(ah, " ", 2)
|
| 168 |
+
if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
|
| 169 |
+
provided = parts[1]
|
| 170 |
+
} else {
|
| 171 |
+
provided = ah
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
if provided == "" {
|
| 175 |
+
provided = c.GetHeader("X-Management-Key")
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if provided == "" {
|
| 179 |
+
if !localClient {
|
| 180 |
+
fail()
|
| 181 |
+
}
|
| 182 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
|
| 183 |
+
return
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
if localClient {
|
| 187 |
+
if lp := h.localPassword; lp != "" {
|
| 188 |
+
if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
|
| 189 |
+
c.Next()
|
| 190 |
+
return
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
|
| 196 |
+
if !localClient {
|
| 197 |
+
h.attemptsMu.Lock()
|
| 198 |
+
if ai := h.failedAttempts[clientIP]; ai != nil {
|
| 199 |
+
ai.count = 0
|
| 200 |
+
ai.blockedUntil = time.Time{}
|
| 201 |
+
}
|
| 202 |
+
h.attemptsMu.Unlock()
|
| 203 |
+
}
|
| 204 |
+
c.Next()
|
| 205 |
+
return
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
|
| 209 |
+
if !localClient {
|
| 210 |
+
fail()
|
| 211 |
+
}
|
| 212 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
|
| 213 |
+
return
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
if !localClient {
|
| 217 |
+
h.attemptsMu.Lock()
|
| 218 |
+
if ai := h.failedAttempts[clientIP]; ai != nil {
|
| 219 |
+
ai.count = 0
|
| 220 |
+
ai.blockedUntil = time.Time{}
|
| 221 |
+
}
|
| 222 |
+
h.attemptsMu.Unlock()
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
c.Next()
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
// persist saves the current in-memory config to disk.
|
| 230 |
+
func (h *Handler) persist(c *gin.Context) bool {
|
| 231 |
+
h.mu.Lock()
|
| 232 |
+
defer h.mu.Unlock()
|
| 233 |
+
// Preserve comments when writing
|
| 234 |
+
if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil {
|
| 235 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)})
|
| 236 |
+
return false
|
| 237 |
+
}
|
| 238 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 239 |
+
return true
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
// Helper methods for simple types
|
| 243 |
+
func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
|
| 244 |
+
var body struct {
|
| 245 |
+
Value *bool `json:"value"`
|
| 246 |
+
}
|
| 247 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 248 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
| 249 |
+
return
|
| 250 |
+
}
|
| 251 |
+
set(*body.Value)
|
| 252 |
+
h.persist(c)
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
func (h *Handler) updateIntField(c *gin.Context, set func(int)) {
|
| 256 |
+
var body struct {
|
| 257 |
+
Value *int `json:"value"`
|
| 258 |
+
}
|
| 259 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 260 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
| 261 |
+
return
|
| 262 |
+
}
|
| 263 |
+
set(*body.Value)
|
| 264 |
+
h.persist(c)
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
func (h *Handler) updateStringField(c *gin.Context, set func(string)) {
|
| 268 |
+
var body struct {
|
| 269 |
+
Value *string `json:"value"`
|
| 270 |
+
}
|
| 271 |
+
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
| 272 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
| 273 |
+
return
|
| 274 |
+
}
|
| 275 |
+
set(*body.Value)
|
| 276 |
+
h.persist(c)
|
| 277 |
+
}
|
internal/api/handlers/management/logs.go
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"fmt"
|
| 6 |
+
"math"
|
| 7 |
+
"net/http"
|
| 8 |
+
"os"
|
| 9 |
+
"path/filepath"
|
| 10 |
+
"sort"
|
| 11 |
+
"strconv"
|
| 12 |
+
"strings"
|
| 13 |
+
"time"
|
| 14 |
+
|
| 15 |
+
"github.com/gin-gonic/gin"
|
| 16 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
const (
|
| 20 |
+
defaultLogFileName = "main.log"
|
| 21 |
+
logScannerInitialBuffer = 64 * 1024
|
| 22 |
+
logScannerMaxBuffer = 8 * 1024 * 1024
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
// GetLogs returns log lines with optional incremental loading.
|
| 26 |
+
func (h *Handler) GetLogs(c *gin.Context) {
|
| 27 |
+
if h == nil {
|
| 28 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
| 29 |
+
return
|
| 30 |
+
}
|
| 31 |
+
if h.cfg == nil {
|
| 32 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
| 33 |
+
return
|
| 34 |
+
}
|
| 35 |
+
if !h.cfg.LoggingToFile {
|
| 36 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
|
| 37 |
+
return
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
logDir := h.logDirectory()
|
| 41 |
+
if strings.TrimSpace(logDir) == "" {
|
| 42 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
| 43 |
+
return
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
files, err := h.collectLogFiles(logDir)
|
| 47 |
+
if err != nil {
|
| 48 |
+
if os.IsNotExist(err) {
|
| 49 |
+
cutoff := parseCutoff(c.Query("after"))
|
| 50 |
+
c.JSON(http.StatusOK, gin.H{
|
| 51 |
+
"lines": []string{},
|
| 52 |
+
"line-count": 0,
|
| 53 |
+
"latest-timestamp": cutoff,
|
| 54 |
+
})
|
| 55 |
+
return
|
| 56 |
+
}
|
| 57 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log files: %v", err)})
|
| 58 |
+
return
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
limit, errLimit := parseLimit(c.Query("limit"))
|
| 62 |
+
if errLimit != nil {
|
| 63 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
|
| 64 |
+
return
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
cutoff := parseCutoff(c.Query("after"))
|
| 68 |
+
acc := newLogAccumulator(cutoff, limit)
|
| 69 |
+
for i := range files {
|
| 70 |
+
if errProcess := acc.consumeFile(files[i]); errProcess != nil {
|
| 71 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
|
| 72 |
+
return
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
lines, total, latest := acc.result()
|
| 77 |
+
if latest == 0 || latest < cutoff {
|
| 78 |
+
latest = cutoff
|
| 79 |
+
}
|
| 80 |
+
c.JSON(http.StatusOK, gin.H{
|
| 81 |
+
"lines": lines,
|
| 82 |
+
"line-count": total,
|
| 83 |
+
"latest-timestamp": latest,
|
| 84 |
+
})
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// DeleteLogs removes all rotated log files and truncates the active log.
|
| 88 |
+
func (h *Handler) DeleteLogs(c *gin.Context) {
|
| 89 |
+
if h == nil {
|
| 90 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
| 91 |
+
return
|
| 92 |
+
}
|
| 93 |
+
if h.cfg == nil {
|
| 94 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
| 95 |
+
return
|
| 96 |
+
}
|
| 97 |
+
if !h.cfg.LoggingToFile {
|
| 98 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
|
| 99 |
+
return
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
dir := h.logDirectory()
|
| 103 |
+
if strings.TrimSpace(dir) == "" {
|
| 104 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
| 105 |
+
return
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
entries, err := os.ReadDir(dir)
|
| 109 |
+
if err != nil {
|
| 110 |
+
if os.IsNotExist(err) {
|
| 111 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
| 112 |
+
return
|
| 113 |
+
}
|
| 114 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
| 115 |
+
return
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
removed := 0
|
| 119 |
+
for _, entry := range entries {
|
| 120 |
+
if entry.IsDir() {
|
| 121 |
+
continue
|
| 122 |
+
}
|
| 123 |
+
name := entry.Name()
|
| 124 |
+
fullPath := filepath.Join(dir, name)
|
| 125 |
+
if name == defaultLogFileName {
|
| 126 |
+
if errTrunc := os.Truncate(fullPath, 0); errTrunc != nil && !os.IsNotExist(errTrunc) {
|
| 127 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to truncate log file: %v", errTrunc)})
|
| 128 |
+
return
|
| 129 |
+
}
|
| 130 |
+
continue
|
| 131 |
+
}
|
| 132 |
+
if isRotatedLogFile(name) {
|
| 133 |
+
if errRemove := os.Remove(fullPath); errRemove != nil && !os.IsNotExist(errRemove) {
|
| 134 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to remove %s: %v", name, errRemove)})
|
| 135 |
+
return
|
| 136 |
+
}
|
| 137 |
+
removed++
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
c.JSON(http.StatusOK, gin.H{
|
| 142 |
+
"success": true,
|
| 143 |
+
"message": "Logs cleared successfully",
|
| 144 |
+
"removed": removed,
|
| 145 |
+
})
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
// GetRequestErrorLogs lists error request log files when RequestLog is disabled.
|
| 149 |
+
// It returns an empty list when RequestLog is enabled.
|
| 150 |
+
func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
| 151 |
+
if h == nil {
|
| 152 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
| 153 |
+
return
|
| 154 |
+
}
|
| 155 |
+
if h.cfg == nil {
|
| 156 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
| 157 |
+
return
|
| 158 |
+
}
|
| 159 |
+
if h.cfg.RequestLog {
|
| 160 |
+
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
| 161 |
+
return
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
dir := h.logDirectory()
|
| 165 |
+
if strings.TrimSpace(dir) == "" {
|
| 166 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
| 167 |
+
return
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
entries, err := os.ReadDir(dir)
|
| 171 |
+
if err != nil {
|
| 172 |
+
if os.IsNotExist(err) {
|
| 173 |
+
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
| 174 |
+
return
|
| 175 |
+
}
|
| 176 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
|
| 177 |
+
return
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
type errorLog struct {
|
| 181 |
+
Name string `json:"name"`
|
| 182 |
+
Size int64 `json:"size"`
|
| 183 |
+
Modified int64 `json:"modified"`
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
files := make([]errorLog, 0, len(entries))
|
| 187 |
+
for _, entry := range entries {
|
| 188 |
+
if entry.IsDir() {
|
| 189 |
+
continue
|
| 190 |
+
}
|
| 191 |
+
name := entry.Name()
|
| 192 |
+
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
| 193 |
+
continue
|
| 194 |
+
}
|
| 195 |
+
info, errInfo := entry.Info()
|
| 196 |
+
if errInfo != nil {
|
| 197 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
|
| 198 |
+
return
|
| 199 |
+
}
|
| 200 |
+
files = append(files, errorLog{
|
| 201 |
+
Name: name,
|
| 202 |
+
Size: info.Size(),
|
| 203 |
+
Modified: info.ModTime().Unix(),
|
| 204 |
+
})
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
sort.Slice(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
|
| 208 |
+
|
| 209 |
+
c.JSON(http.StatusOK, gin.H{"files": files})
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
| 213 |
+
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
| 214 |
+
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
| 215 |
+
if h == nil {
|
| 216 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
| 217 |
+
return
|
| 218 |
+
}
|
| 219 |
+
if h.cfg == nil {
|
| 220 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
| 221 |
+
return
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
dir := h.logDirectory()
|
| 225 |
+
if strings.TrimSpace(dir) == "" {
|
| 226 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
| 227 |
+
return
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
requestID := strings.TrimSpace(c.Param("id"))
|
| 231 |
+
if requestID == "" {
|
| 232 |
+
requestID = strings.TrimSpace(c.Query("id"))
|
| 233 |
+
}
|
| 234 |
+
if requestID == "" {
|
| 235 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
| 236 |
+
return
|
| 237 |
+
}
|
| 238 |
+
if strings.ContainsAny(requestID, "/\\") {
|
| 239 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
| 240 |
+
return
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
entries, err := os.ReadDir(dir)
|
| 244 |
+
if err != nil {
|
| 245 |
+
if os.IsNotExist(err) {
|
| 246 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
| 247 |
+
return
|
| 248 |
+
}
|
| 249 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
| 250 |
+
return
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
suffix := "-" + requestID + ".log"
|
| 254 |
+
var matchedFile string
|
| 255 |
+
for _, entry := range entries {
|
| 256 |
+
if entry.IsDir() {
|
| 257 |
+
continue
|
| 258 |
+
}
|
| 259 |
+
name := entry.Name()
|
| 260 |
+
if strings.HasSuffix(name, suffix) {
|
| 261 |
+
matchedFile = name
|
| 262 |
+
break
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
if matchedFile == "" {
|
| 267 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
| 268 |
+
return
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
dirAbs, errAbs := filepath.Abs(dir)
|
| 272 |
+
if errAbs != nil {
|
| 273 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
| 274 |
+
return
|
| 275 |
+
}
|
| 276 |
+
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
| 277 |
+
prefix := dirAbs + string(os.PathSeparator)
|
| 278 |
+
if !strings.HasPrefix(fullPath, prefix) {
|
| 279 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
| 280 |
+
return
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
info, errStat := os.Stat(fullPath)
|
| 284 |
+
if errStat != nil {
|
| 285 |
+
if os.IsNotExist(errStat) {
|
| 286 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
| 287 |
+
return
|
| 288 |
+
}
|
| 289 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
| 290 |
+
return
|
| 291 |
+
}
|
| 292 |
+
if info.IsDir() {
|
| 293 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
| 294 |
+
return
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
c.FileAttachment(fullPath, matchedFile)
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
| 301 |
+
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
| 302 |
+
if h == nil {
|
| 303 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
| 304 |
+
return
|
| 305 |
+
}
|
| 306 |
+
if h.cfg == nil {
|
| 307 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
| 308 |
+
return
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
dir := h.logDirectory()
|
| 312 |
+
if strings.TrimSpace(dir) == "" {
|
| 313 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
| 314 |
+
return
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
name := strings.TrimSpace(c.Param("name"))
|
| 318 |
+
if name == "" || strings.Contains(name, "/") || strings.Contains(name, "\\") {
|
| 319 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
|
| 320 |
+
return
|
| 321 |
+
}
|
| 322 |
+
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
| 323 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
| 324 |
+
return
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
dirAbs, errAbs := filepath.Abs(dir)
|
| 328 |
+
if errAbs != nil {
|
| 329 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
| 330 |
+
return
|
| 331 |
+
}
|
| 332 |
+
fullPath := filepath.Clean(filepath.Join(dirAbs, name))
|
| 333 |
+
prefix := dirAbs + string(os.PathSeparator)
|
| 334 |
+
if !strings.HasPrefix(fullPath, prefix) {
|
| 335 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
| 336 |
+
return
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
info, errStat := os.Stat(fullPath)
|
| 340 |
+
if errStat != nil {
|
| 341 |
+
if os.IsNotExist(errStat) {
|
| 342 |
+
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
| 343 |
+
return
|
| 344 |
+
}
|
| 345 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
| 346 |
+
return
|
| 347 |
+
}
|
| 348 |
+
if info.IsDir() {
|
| 349 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
| 350 |
+
return
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
c.FileAttachment(fullPath, name)
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
func (h *Handler) logDirectory() string {
|
| 357 |
+
if h == nil {
|
| 358 |
+
return ""
|
| 359 |
+
}
|
| 360 |
+
if h.logDir != "" {
|
| 361 |
+
return h.logDir
|
| 362 |
+
}
|
| 363 |
+
if base := util.WritablePath(); base != "" {
|
| 364 |
+
return filepath.Join(base, "logs")
|
| 365 |
+
}
|
| 366 |
+
if h.configFilePath != "" {
|
| 367 |
+
dir := filepath.Dir(h.configFilePath)
|
| 368 |
+
if dir != "" && dir != "." {
|
| 369 |
+
return filepath.Join(dir, "logs")
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
return "logs"
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
| 376 |
+
entries, err := os.ReadDir(dir)
|
| 377 |
+
if err != nil {
|
| 378 |
+
return nil, err
|
| 379 |
+
}
|
| 380 |
+
type candidate struct {
|
| 381 |
+
path string
|
| 382 |
+
order int64
|
| 383 |
+
}
|
| 384 |
+
cands := make([]candidate, 0, len(entries))
|
| 385 |
+
for _, entry := range entries {
|
| 386 |
+
if entry.IsDir() {
|
| 387 |
+
continue
|
| 388 |
+
}
|
| 389 |
+
name := entry.Name()
|
| 390 |
+
if name == defaultLogFileName {
|
| 391 |
+
cands = append(cands, candidate{path: filepath.Join(dir, name), order: 0})
|
| 392 |
+
continue
|
| 393 |
+
}
|
| 394 |
+
if order, ok := rotationOrder(name); ok {
|
| 395 |
+
cands = append(cands, candidate{path: filepath.Join(dir, name), order: order})
|
| 396 |
+
}
|
| 397 |
+
}
|
| 398 |
+
if len(cands) == 0 {
|
| 399 |
+
return []string{}, nil
|
| 400 |
+
}
|
| 401 |
+
sort.Slice(cands, func(i, j int) bool { return cands[i].order < cands[j].order })
|
| 402 |
+
paths := make([]string, 0, len(cands))
|
| 403 |
+
for i := len(cands) - 1; i >= 0; i-- {
|
| 404 |
+
paths = append(paths, cands[i].path)
|
| 405 |
+
}
|
| 406 |
+
return paths, nil
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
type logAccumulator struct {
|
| 410 |
+
cutoff int64
|
| 411 |
+
limit int
|
| 412 |
+
lines []string
|
| 413 |
+
total int
|
| 414 |
+
latest int64
|
| 415 |
+
include bool
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
|
| 419 |
+
capacity := 256
|
| 420 |
+
if limit > 0 && limit < capacity {
|
| 421 |
+
capacity = limit
|
| 422 |
+
}
|
| 423 |
+
return &logAccumulator{
|
| 424 |
+
cutoff: cutoff,
|
| 425 |
+
limit: limit,
|
| 426 |
+
lines: make([]string, 0, capacity),
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
func (acc *logAccumulator) consumeFile(path string) error {
|
| 431 |
+
file, err := os.Open(path)
|
| 432 |
+
if err != nil {
|
| 433 |
+
if os.IsNotExist(err) {
|
| 434 |
+
return nil
|
| 435 |
+
}
|
| 436 |
+
return err
|
| 437 |
+
}
|
| 438 |
+
defer func() {
|
| 439 |
+
_ = file.Close()
|
| 440 |
+
}()
|
| 441 |
+
|
| 442 |
+
scanner := bufio.NewScanner(file)
|
| 443 |
+
buf := make([]byte, 0, logScannerInitialBuffer)
|
| 444 |
+
scanner.Buffer(buf, logScannerMaxBuffer)
|
| 445 |
+
for scanner.Scan() {
|
| 446 |
+
acc.addLine(scanner.Text())
|
| 447 |
+
}
|
| 448 |
+
if errScan := scanner.Err(); errScan != nil {
|
| 449 |
+
return errScan
|
| 450 |
+
}
|
| 451 |
+
return nil
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
func (acc *logAccumulator) addLine(raw string) {
|
| 455 |
+
line := strings.TrimRight(raw, "\r")
|
| 456 |
+
acc.total++
|
| 457 |
+
ts := parseTimestamp(line)
|
| 458 |
+
if ts > acc.latest {
|
| 459 |
+
acc.latest = ts
|
| 460 |
+
}
|
| 461 |
+
if ts > 0 {
|
| 462 |
+
acc.include = acc.cutoff == 0 || ts > acc.cutoff
|
| 463 |
+
if acc.cutoff == 0 || acc.include {
|
| 464 |
+
acc.append(line)
|
| 465 |
+
}
|
| 466 |
+
return
|
| 467 |
+
}
|
| 468 |
+
if acc.cutoff == 0 || acc.include {
|
| 469 |
+
acc.append(line)
|
| 470 |
+
}
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
func (acc *logAccumulator) append(line string) {
|
| 474 |
+
acc.lines = append(acc.lines, line)
|
| 475 |
+
if acc.limit > 0 && len(acc.lines) > acc.limit {
|
| 476 |
+
acc.lines = acc.lines[len(acc.lines)-acc.limit:]
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
func (acc *logAccumulator) result() ([]string, int, int64) {
|
| 481 |
+
if acc.lines == nil {
|
| 482 |
+
acc.lines = []string{}
|
| 483 |
+
}
|
| 484 |
+
return acc.lines, acc.total, acc.latest
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
func parseCutoff(raw string) int64 {
|
| 488 |
+
value := strings.TrimSpace(raw)
|
| 489 |
+
if value == "" {
|
| 490 |
+
return 0
|
| 491 |
+
}
|
| 492 |
+
ts, err := strconv.ParseInt(value, 10, 64)
|
| 493 |
+
if err != nil || ts <= 0 {
|
| 494 |
+
return 0
|
| 495 |
+
}
|
| 496 |
+
return ts
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
func parseLimit(raw string) (int, error) {
|
| 500 |
+
value := strings.TrimSpace(raw)
|
| 501 |
+
if value == "" {
|
| 502 |
+
return 0, nil
|
| 503 |
+
}
|
| 504 |
+
limit, err := strconv.Atoi(value)
|
| 505 |
+
if err != nil {
|
| 506 |
+
return 0, fmt.Errorf("must be a positive integer")
|
| 507 |
+
}
|
| 508 |
+
if limit <= 0 {
|
| 509 |
+
return 0, fmt.Errorf("must be greater than zero")
|
| 510 |
+
}
|
| 511 |
+
return limit, nil
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
func parseTimestamp(line string) int64 {
|
| 515 |
+
if strings.HasPrefix(line, "[") {
|
| 516 |
+
line = line[1:]
|
| 517 |
+
}
|
| 518 |
+
if len(line) < 19 {
|
| 519 |
+
return 0
|
| 520 |
+
}
|
| 521 |
+
candidate := line[:19]
|
| 522 |
+
t, err := time.ParseInLocation("2006-01-02 15:04:05", candidate, time.Local)
|
| 523 |
+
if err != nil {
|
| 524 |
+
return 0
|
| 525 |
+
}
|
| 526 |
+
return t.Unix()
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
func isRotatedLogFile(name string) bool {
|
| 530 |
+
if _, ok := rotationOrder(name); ok {
|
| 531 |
+
return true
|
| 532 |
+
}
|
| 533 |
+
return false
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
func rotationOrder(name string) (int64, bool) {
|
| 537 |
+
if order, ok := numericRotationOrder(name); ok {
|
| 538 |
+
return order, true
|
| 539 |
+
}
|
| 540 |
+
if order, ok := timestampRotationOrder(name); ok {
|
| 541 |
+
return order, true
|
| 542 |
+
}
|
| 543 |
+
return 0, false
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
func numericRotationOrder(name string) (int64, bool) {
|
| 547 |
+
if !strings.HasPrefix(name, defaultLogFileName+".") {
|
| 548 |
+
return 0, false
|
| 549 |
+
}
|
| 550 |
+
suffix := strings.TrimPrefix(name, defaultLogFileName+".")
|
| 551 |
+
if suffix == "" {
|
| 552 |
+
return 0, false
|
| 553 |
+
}
|
| 554 |
+
n, err := strconv.Atoi(suffix)
|
| 555 |
+
if err != nil {
|
| 556 |
+
return 0, false
|
| 557 |
+
}
|
| 558 |
+
return int64(n), true
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
func timestampRotationOrder(name string) (int64, bool) {
|
| 562 |
+
ext := filepath.Ext(defaultLogFileName)
|
| 563 |
+
base := strings.TrimSuffix(defaultLogFileName, ext)
|
| 564 |
+
if base == "" {
|
| 565 |
+
return 0, false
|
| 566 |
+
}
|
| 567 |
+
prefix := base + "-"
|
| 568 |
+
if !strings.HasPrefix(name, prefix) {
|
| 569 |
+
return 0, false
|
| 570 |
+
}
|
| 571 |
+
clean := strings.TrimPrefix(name, prefix)
|
| 572 |
+
if strings.HasSuffix(clean, ".gz") {
|
| 573 |
+
clean = strings.TrimSuffix(clean, ".gz")
|
| 574 |
+
}
|
| 575 |
+
if ext != "" {
|
| 576 |
+
if !strings.HasSuffix(clean, ext) {
|
| 577 |
+
return 0, false
|
| 578 |
+
}
|
| 579 |
+
clean = strings.TrimSuffix(clean, ext)
|
| 580 |
+
}
|
| 581 |
+
if clean == "" {
|
| 582 |
+
return 0, false
|
| 583 |
+
}
|
| 584 |
+
if idx := strings.IndexByte(clean, '.'); idx != -1 {
|
| 585 |
+
clean = clean[:idx]
|
| 586 |
+
}
|
| 587 |
+
parsed, err := time.ParseInLocation("2006-01-02T15-04-05", clean, time.Local)
|
| 588 |
+
if err != nil {
|
| 589 |
+
return 0, false
|
| 590 |
+
}
|
| 591 |
+
return math.MaxInt64 - parsed.Unix(), true
|
| 592 |
+
}
|
internal/api/handlers/management/oauth_callback.go
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"errors"
|
| 5 |
+
"net/http"
|
| 6 |
+
"net/url"
|
| 7 |
+
"strings"
|
| 8 |
+
|
| 9 |
+
"github.com/gin-gonic/gin"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
type oauthCallbackRequest struct {
|
| 13 |
+
Provider string `json:"provider"`
|
| 14 |
+
RedirectURL string `json:"redirect_url"`
|
| 15 |
+
Code string `json:"code"`
|
| 16 |
+
State string `json:"state"`
|
| 17 |
+
Error string `json:"error"`
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
| 21 |
+
if h == nil || h.cfg == nil {
|
| 22 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
| 23 |
+
return
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
var req oauthCallbackRequest
|
| 27 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 28 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
| 29 |
+
return
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
| 33 |
+
if err != nil {
|
| 34 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
| 35 |
+
return
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
state := strings.TrimSpace(req.State)
|
| 39 |
+
code := strings.TrimSpace(req.Code)
|
| 40 |
+
errMsg := strings.TrimSpace(req.Error)
|
| 41 |
+
|
| 42 |
+
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
| 43 |
+
u, errParse := url.Parse(rawRedirect)
|
| 44 |
+
if errParse != nil {
|
| 45 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
| 46 |
+
return
|
| 47 |
+
}
|
| 48 |
+
q := u.Query()
|
| 49 |
+
if state == "" {
|
| 50 |
+
state = strings.TrimSpace(q.Get("state"))
|
| 51 |
+
}
|
| 52 |
+
if code == "" {
|
| 53 |
+
code = strings.TrimSpace(q.Get("code"))
|
| 54 |
+
}
|
| 55 |
+
if errMsg == "" {
|
| 56 |
+
errMsg = strings.TrimSpace(q.Get("error"))
|
| 57 |
+
if errMsg == "" {
|
| 58 |
+
errMsg = strings.TrimSpace(q.Get("error_description"))
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
if state == "" {
|
| 64 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
| 65 |
+
return
|
| 66 |
+
}
|
| 67 |
+
if err := ValidateOAuthState(state); err != nil {
|
| 68 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
| 69 |
+
return
|
| 70 |
+
}
|
| 71 |
+
if code == "" && errMsg == "" {
|
| 72 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
| 73 |
+
return
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
| 77 |
+
if !ok {
|
| 78 |
+
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
| 79 |
+
return
|
| 80 |
+
}
|
| 81 |
+
if sessionStatus != "" {
|
| 82 |
+
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
| 83 |
+
return
|
| 84 |
+
}
|
| 85 |
+
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
| 86 |
+
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
| 87 |
+
return
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
| 91 |
+
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
| 92 |
+
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
| 93 |
+
return
|
| 94 |
+
}
|
| 95 |
+
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
| 96 |
+
return
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 100 |
+
}
|
internal/api/handlers/management/oauth_sessions.go
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"errors"
|
| 6 |
+
"fmt"
|
| 7 |
+
"os"
|
| 8 |
+
"path/filepath"
|
| 9 |
+
"strings"
|
| 10 |
+
"sync"
|
| 11 |
+
"time"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
const (
|
| 15 |
+
oauthSessionTTL = 10 * time.Minute
|
| 16 |
+
maxOAuthStateLength = 128
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
var (
|
| 20 |
+
errInvalidOAuthState = errors.New("invalid oauth state")
|
| 21 |
+
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
| 22 |
+
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
type oauthSession struct {
|
| 26 |
+
Provider string
|
| 27 |
+
Status string
|
| 28 |
+
CreatedAt time.Time
|
| 29 |
+
ExpiresAt time.Time
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
type oauthSessionStore struct {
|
| 33 |
+
mu sync.RWMutex
|
| 34 |
+
ttl time.Duration
|
| 35 |
+
sessions map[string]oauthSession
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
| 39 |
+
if ttl <= 0 {
|
| 40 |
+
ttl = oauthSessionTTL
|
| 41 |
+
}
|
| 42 |
+
return &oauthSessionStore{
|
| 43 |
+
ttl: ttl,
|
| 44 |
+
sessions: make(map[string]oauthSession),
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
| 49 |
+
for state, session := range s.sessions {
|
| 50 |
+
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
| 51 |
+
delete(s.sessions, state)
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func (s *oauthSessionStore) Register(state, provider string) {
|
| 57 |
+
state = strings.TrimSpace(state)
|
| 58 |
+
provider = strings.ToLower(strings.TrimSpace(provider))
|
| 59 |
+
if state == "" || provider == "" {
|
| 60 |
+
return
|
| 61 |
+
}
|
| 62 |
+
now := time.Now()
|
| 63 |
+
|
| 64 |
+
s.mu.Lock()
|
| 65 |
+
defer s.mu.Unlock()
|
| 66 |
+
|
| 67 |
+
s.purgeExpiredLocked(now)
|
| 68 |
+
s.sessions[state] = oauthSession{
|
| 69 |
+
Provider: provider,
|
| 70 |
+
Status: "",
|
| 71 |
+
CreatedAt: now,
|
| 72 |
+
ExpiresAt: now.Add(s.ttl),
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
func (s *oauthSessionStore) SetError(state, message string) {
|
| 77 |
+
state = strings.TrimSpace(state)
|
| 78 |
+
message = strings.TrimSpace(message)
|
| 79 |
+
if state == "" {
|
| 80 |
+
return
|
| 81 |
+
}
|
| 82 |
+
if message == "" {
|
| 83 |
+
message = "Authentication failed"
|
| 84 |
+
}
|
| 85 |
+
now := time.Now()
|
| 86 |
+
|
| 87 |
+
s.mu.Lock()
|
| 88 |
+
defer s.mu.Unlock()
|
| 89 |
+
|
| 90 |
+
s.purgeExpiredLocked(now)
|
| 91 |
+
session, ok := s.sessions[state]
|
| 92 |
+
if !ok {
|
| 93 |
+
return
|
| 94 |
+
}
|
| 95 |
+
session.Status = message
|
| 96 |
+
session.ExpiresAt = now.Add(s.ttl)
|
| 97 |
+
s.sessions[state] = session
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
func (s *oauthSessionStore) Complete(state string) {
|
| 101 |
+
state = strings.TrimSpace(state)
|
| 102 |
+
if state == "" {
|
| 103 |
+
return
|
| 104 |
+
}
|
| 105 |
+
now := time.Now()
|
| 106 |
+
|
| 107 |
+
s.mu.Lock()
|
| 108 |
+
defer s.mu.Unlock()
|
| 109 |
+
|
| 110 |
+
s.purgeExpiredLocked(now)
|
| 111 |
+
delete(s.sessions, state)
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
func (s *oauthSessionStore) CompleteProvider(provider string) int {
|
| 115 |
+
provider = strings.ToLower(strings.TrimSpace(provider))
|
| 116 |
+
if provider == "" {
|
| 117 |
+
return 0
|
| 118 |
+
}
|
| 119 |
+
now := time.Now()
|
| 120 |
+
|
| 121 |
+
s.mu.Lock()
|
| 122 |
+
defer s.mu.Unlock()
|
| 123 |
+
|
| 124 |
+
s.purgeExpiredLocked(now)
|
| 125 |
+
removed := 0
|
| 126 |
+
for state, session := range s.sessions {
|
| 127 |
+
if strings.EqualFold(session.Provider, provider) {
|
| 128 |
+
delete(s.sessions, state)
|
| 129 |
+
removed++
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
return removed
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
| 136 |
+
state = strings.TrimSpace(state)
|
| 137 |
+
now := time.Now()
|
| 138 |
+
|
| 139 |
+
s.mu.Lock()
|
| 140 |
+
defer s.mu.Unlock()
|
| 141 |
+
|
| 142 |
+
s.purgeExpiredLocked(now)
|
| 143 |
+
session, ok := s.sessions[state]
|
| 144 |
+
return session, ok
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
| 148 |
+
state = strings.TrimSpace(state)
|
| 149 |
+
provider = strings.ToLower(strings.TrimSpace(provider))
|
| 150 |
+
now := time.Now()
|
| 151 |
+
|
| 152 |
+
s.mu.Lock()
|
| 153 |
+
defer s.mu.Unlock()
|
| 154 |
+
|
| 155 |
+
s.purgeExpiredLocked(now)
|
| 156 |
+
session, ok := s.sessions[state]
|
| 157 |
+
if !ok {
|
| 158 |
+
return false
|
| 159 |
+
}
|
| 160 |
+
if session.Status != "" {
|
| 161 |
+
if !strings.EqualFold(session.Provider, "kiro") {
|
| 162 |
+
return false
|
| 163 |
+
}
|
| 164 |
+
if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
|
| 165 |
+
return false
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
if provider == "" {
|
| 169 |
+
return true
|
| 170 |
+
}
|
| 171 |
+
return strings.EqualFold(session.Provider, provider)
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
| 175 |
+
|
| 176 |
+
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
| 177 |
+
|
| 178 |
+
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
| 179 |
+
|
| 180 |
+
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
| 181 |
+
|
| 182 |
+
func CompleteOAuthSessionsByProvider(provider string) int {
|
| 183 |
+
return oauthSessions.CompleteProvider(provider)
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
| 187 |
+
session, ok := oauthSessions.Get(state)
|
| 188 |
+
if !ok {
|
| 189 |
+
return "", "", false
|
| 190 |
+
}
|
| 191 |
+
return session.Provider, session.Status, true
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
func IsOAuthSessionPending(state, provider string) bool {
|
| 195 |
+
return oauthSessions.IsPending(state, provider)
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
func ValidateOAuthState(state string) error {
|
| 199 |
+
trimmed := strings.TrimSpace(state)
|
| 200 |
+
if trimmed == "" {
|
| 201 |
+
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
| 202 |
+
}
|
| 203 |
+
if len(trimmed) > maxOAuthStateLength {
|
| 204 |
+
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
| 205 |
+
}
|
| 206 |
+
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
| 207 |
+
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
| 208 |
+
}
|
| 209 |
+
if strings.Contains(trimmed, "..") {
|
| 210 |
+
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
| 211 |
+
}
|
| 212 |
+
for _, r := range trimmed {
|
| 213 |
+
switch {
|
| 214 |
+
case r >= 'a' && r <= 'z':
|
| 215 |
+
case r >= 'A' && r <= 'Z':
|
| 216 |
+
case r >= '0' && r <= '9':
|
| 217 |
+
case r == '-' || r == '_' || r == '.':
|
| 218 |
+
default:
|
| 219 |
+
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
return nil
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
func NormalizeOAuthProvider(provider string) (string, error) {
|
| 226 |
+
switch strings.ToLower(strings.TrimSpace(provider)) {
|
| 227 |
+
case "anthropic", "claude":
|
| 228 |
+
return "anthropic", nil
|
| 229 |
+
case "codex", "openai":
|
| 230 |
+
return "codex", nil
|
| 231 |
+
case "gemini", "google":
|
| 232 |
+
return "gemini", nil
|
| 233 |
+
case "iflow", "i-flow":
|
| 234 |
+
return "iflow", nil
|
| 235 |
+
case "antigravity", "anti-gravity":
|
| 236 |
+
return "antigravity", nil
|
| 237 |
+
case "qwen":
|
| 238 |
+
return "qwen", nil
|
| 239 |
+
case "kiro":
|
| 240 |
+
return "kiro", nil
|
| 241 |
+
default:
|
| 242 |
+
return "", errUnsupportedOAuthFlow
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
type oauthCallbackFilePayload struct {
|
| 247 |
+
Code string `json:"code"`
|
| 248 |
+
State string `json:"state"`
|
| 249 |
+
Error string `json:"error"`
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
| 253 |
+
if strings.TrimSpace(authDir) == "" {
|
| 254 |
+
return "", fmt.Errorf("auth dir is empty")
|
| 255 |
+
}
|
| 256 |
+
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
| 257 |
+
if err != nil {
|
| 258 |
+
return "", err
|
| 259 |
+
}
|
| 260 |
+
if err := ValidateOAuthState(state); err != nil {
|
| 261 |
+
return "", err
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
| 265 |
+
filePath := filepath.Join(authDir, fileName)
|
| 266 |
+
payload := oauthCallbackFilePayload{
|
| 267 |
+
Code: strings.TrimSpace(code),
|
| 268 |
+
State: strings.TrimSpace(state),
|
| 269 |
+
Error: strings.TrimSpace(errorMessage),
|
| 270 |
+
}
|
| 271 |
+
data, err := json.Marshal(payload)
|
| 272 |
+
if err != nil {
|
| 273 |
+
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
| 274 |
+
}
|
| 275 |
+
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
| 276 |
+
return "", fmt.Errorf("write oauth callback file: %w", err)
|
| 277 |
+
}
|
| 278 |
+
return filePath, nil
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
| 282 |
+
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
| 283 |
+
if err != nil {
|
| 284 |
+
return "", err
|
| 285 |
+
}
|
| 286 |
+
if !IsOAuthSessionPending(state, canonicalProvider) {
|
| 287 |
+
return "", errOAuthSessionNotPending
|
| 288 |
+
}
|
| 289 |
+
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
| 290 |
+
}
|
internal/api/handlers/management/quota.go
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import "github.com/gin-gonic/gin"
|
| 4 |
+
|
| 5 |
+
// Quota exceeded toggles
|
| 6 |
+
func (h *Handler) GetSwitchProject(c *gin.Context) {
|
| 7 |
+
c.JSON(200, gin.H{"switch-project": h.cfg.QuotaExceeded.SwitchProject})
|
| 8 |
+
}
|
| 9 |
+
func (h *Handler) PutSwitchProject(c *gin.Context) {
|
| 10 |
+
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchProject = v })
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
func (h *Handler) GetSwitchPreviewModel(c *gin.Context) {
|
| 14 |
+
c.JSON(200, gin.H{"switch-preview-model": h.cfg.QuotaExceeded.SwitchPreviewModel})
|
| 15 |
+
}
|
| 16 |
+
func (h *Handler) PutSwitchPreviewModel(c *gin.Context) {
|
| 17 |
+
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v })
|
| 18 |
+
}
|
internal/api/handlers/management/usage.go
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"net/http"
|
| 6 |
+
"time"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
type usageExportPayload struct {
|
| 13 |
+
Version int `json:"version"`
|
| 14 |
+
ExportedAt time.Time `json:"exported_at"`
|
| 15 |
+
Usage usage.StatisticsSnapshot `json:"usage"`
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
type usageImportPayload struct {
|
| 19 |
+
Version int `json:"version"`
|
| 20 |
+
Usage usage.StatisticsSnapshot `json:"usage"`
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
| 24 |
+
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
| 25 |
+
var snapshot usage.StatisticsSnapshot
|
| 26 |
+
if h != nil && h.usageStats != nil {
|
| 27 |
+
snapshot = h.usageStats.Snapshot()
|
| 28 |
+
}
|
| 29 |
+
c.JSON(http.StatusOK, gin.H{
|
| 30 |
+
"usage": snapshot,
|
| 31 |
+
"failed_requests": snapshot.FailureCount,
|
| 32 |
+
})
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
| 36 |
+
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
| 37 |
+
var snapshot usage.StatisticsSnapshot
|
| 38 |
+
if h != nil && h.usageStats != nil {
|
| 39 |
+
snapshot = h.usageStats.Snapshot()
|
| 40 |
+
}
|
| 41 |
+
c.JSON(http.StatusOK, usageExportPayload{
|
| 42 |
+
Version: 1,
|
| 43 |
+
ExportedAt: time.Now().UTC(),
|
| 44 |
+
Usage: snapshot,
|
| 45 |
+
})
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
| 49 |
+
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
| 50 |
+
if h == nil || h.usageStats == nil {
|
| 51 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
| 52 |
+
return
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
data, err := c.GetRawData()
|
| 56 |
+
if err != nil {
|
| 57 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
| 58 |
+
return
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
var payload usageImportPayload
|
| 62 |
+
if err := json.Unmarshal(data, &payload); err != nil {
|
| 63 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
| 64 |
+
return
|
| 65 |
+
}
|
| 66 |
+
if payload.Version != 0 && payload.Version != 1 {
|
| 67 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
| 68 |
+
return
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
result := h.usageStats.MergeSnapshot(payload.Usage)
|
| 72 |
+
snapshot := h.usageStats.Snapshot()
|
| 73 |
+
c.JSON(http.StatusOK, gin.H{
|
| 74 |
+
"added": result.Added,
|
| 75 |
+
"skipped": result.Skipped,
|
| 76 |
+
"total_requests": snapshot.TotalRequests,
|
| 77 |
+
"failed_requests": snapshot.FailureCount,
|
| 78 |
+
})
|
| 79 |
+
}
|
internal/api/handlers/management/vertex_import.go
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package management
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"net/http"
|
| 9 |
+
"strings"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
| 13 |
+
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
|
| 17 |
+
func (h *Handler) ImportVertexCredential(c *gin.Context) {
|
| 18 |
+
if h == nil || h.cfg == nil {
|
| 19 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
|
| 20 |
+
return
|
| 21 |
+
}
|
| 22 |
+
if h.cfg.AuthDir == "" {
|
| 23 |
+
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
|
| 24 |
+
return
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
fileHeader, err := c.FormFile("file")
|
| 28 |
+
if err != nil {
|
| 29 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
|
| 30 |
+
return
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
file, err := fileHeader.Open()
|
| 34 |
+
if err != nil {
|
| 35 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
| 36 |
+
return
|
| 37 |
+
}
|
| 38 |
+
defer file.Close()
|
| 39 |
+
|
| 40 |
+
data, err := io.ReadAll(file)
|
| 41 |
+
if err != nil {
|
| 42 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
| 43 |
+
return
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
var serviceAccount map[string]any
|
| 47 |
+
if err := json.Unmarshal(data, &serviceAccount); err != nil {
|
| 48 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
|
| 49 |
+
return
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
|
| 53 |
+
if err != nil {
|
| 54 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
|
| 55 |
+
return
|
| 56 |
+
}
|
| 57 |
+
serviceAccount = normalizedSA
|
| 58 |
+
|
| 59 |
+
projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
|
| 60 |
+
if projectID == "" {
|
| 61 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
|
| 62 |
+
return
|
| 63 |
+
}
|
| 64 |
+
email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
|
| 65 |
+
|
| 66 |
+
location := strings.TrimSpace(c.PostForm("location"))
|
| 67 |
+
if location == "" {
|
| 68 |
+
location = strings.TrimSpace(c.Query("location"))
|
| 69 |
+
}
|
| 70 |
+
if location == "" {
|
| 71 |
+
location = "us-central1"
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
|
| 75 |
+
label := labelForVertex(projectID, email)
|
| 76 |
+
storage := &vertex.VertexCredentialStorage{
|
| 77 |
+
ServiceAccount: serviceAccount,
|
| 78 |
+
ProjectID: projectID,
|
| 79 |
+
Email: email,
|
| 80 |
+
Location: location,
|
| 81 |
+
Type: "vertex",
|
| 82 |
+
}
|
| 83 |
+
metadata := map[string]any{
|
| 84 |
+
"service_account": serviceAccount,
|
| 85 |
+
"project_id": projectID,
|
| 86 |
+
"email": email,
|
| 87 |
+
"location": location,
|
| 88 |
+
"type": "vertex",
|
| 89 |
+
"label": label,
|
| 90 |
+
}
|
| 91 |
+
record := &coreauth.Auth{
|
| 92 |
+
ID: fileName,
|
| 93 |
+
Provider: "vertex",
|
| 94 |
+
FileName: fileName,
|
| 95 |
+
Storage: storage,
|
| 96 |
+
Label: label,
|
| 97 |
+
Metadata: metadata,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
ctx := context.Background()
|
| 101 |
+
if reqCtx := c.Request.Context(); reqCtx != nil {
|
| 102 |
+
ctx = reqCtx
|
| 103 |
+
}
|
| 104 |
+
savedPath, err := h.saveTokenRecord(ctx, record)
|
| 105 |
+
if err != nil {
|
| 106 |
+
c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
|
| 107 |
+
return
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
c.JSON(http.StatusOK, gin.H{
|
| 111 |
+
"status": "ok",
|
| 112 |
+
"auth-file": savedPath,
|
| 113 |
+
"project_id": projectID,
|
| 114 |
+
"email": email,
|
| 115 |
+
"location": location,
|
| 116 |
+
})
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
func valueAsString(v any) string {
|
| 120 |
+
if v == nil {
|
| 121 |
+
return ""
|
| 122 |
+
}
|
| 123 |
+
switch t := v.(type) {
|
| 124 |
+
case string:
|
| 125 |
+
return t
|
| 126 |
+
default:
|
| 127 |
+
return fmt.Sprint(t)
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
func sanitizeVertexFilePart(s string) string {
|
| 132 |
+
out := strings.TrimSpace(s)
|
| 133 |
+
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
|
| 134 |
+
for i := 0; i < len(replacers); i += 2 {
|
| 135 |
+
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
|
| 136 |
+
}
|
| 137 |
+
if out == "" {
|
| 138 |
+
return "vertex"
|
| 139 |
+
}
|
| 140 |
+
return out
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
func labelForVertex(projectID, email string) string {
|
| 144 |
+
p := strings.TrimSpace(projectID)
|
| 145 |
+
e := strings.TrimSpace(email)
|
| 146 |
+
if p != "" && e != "" {
|
| 147 |
+
return fmt.Sprintf("%s (%s)", p, e)
|
| 148 |
+
}
|
| 149 |
+
if p != "" {
|
| 150 |
+
return p
|
| 151 |
+
}
|
| 152 |
+
if e != "" {
|
| 153 |
+
return e
|
| 154 |
+
}
|
| 155 |
+
return "vertex"
|
| 156 |
+
}
|
internal/api/middleware/request_logging.go
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
|
| 2 |
+
// This file contains the request logging middleware that captures comprehensive
|
| 3 |
+
// request and response data when enabled through configuration.
|
| 4 |
+
package middleware
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"bytes"
|
| 8 |
+
"io"
|
| 9 |
+
"net/http"
|
| 10 |
+
"strings"
|
| 11 |
+
|
| 12 |
+
"github.com/gin-gonic/gin"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
| 14 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
| 18 |
+
// It captures detailed information about the request and response, including headers and body,
|
| 19 |
+
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
| 20 |
+
// logger, it still captures data so that upstream errors can be persisted.
|
| 21 |
+
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
| 22 |
+
return func(c *gin.Context) {
|
| 23 |
+
if logger == nil {
|
| 24 |
+
c.Next()
|
| 25 |
+
return
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
if c.Request.Method == http.MethodGet {
|
| 29 |
+
c.Next()
|
| 30 |
+
return
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
path := c.Request.URL.Path
|
| 34 |
+
if !shouldLogRequest(path) {
|
| 35 |
+
c.Next()
|
| 36 |
+
return
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Capture request information
|
| 40 |
+
requestInfo, err := captureRequestInfo(c)
|
| 41 |
+
if err != nil {
|
| 42 |
+
// Log error but continue processing
|
| 43 |
+
// In a real implementation, you might want to use a proper logger here
|
| 44 |
+
c.Next()
|
| 45 |
+
return
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// Create response writer wrapper
|
| 49 |
+
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
| 50 |
+
if !logger.IsEnabled() {
|
| 51 |
+
wrapper.logOnErrorOnly = true
|
| 52 |
+
}
|
| 53 |
+
c.Writer = wrapper
|
| 54 |
+
|
| 55 |
+
// Process the request
|
| 56 |
+
c.Next()
|
| 57 |
+
|
| 58 |
+
// Finalize logging after request processing
|
| 59 |
+
if err = wrapper.Finalize(c); err != nil {
|
| 60 |
+
// Log error but don't interrupt the response
|
| 61 |
+
// In a real implementation, you might want to use a proper logger here
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
| 67 |
+
// It captures the URL, method, headers, and body. The request body is read and then
|
| 68 |
+
// restored so that it can be processed by subsequent handlers.
|
| 69 |
+
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
| 70 |
+
// Capture URL with sensitive query parameters masked
|
| 71 |
+
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
| 72 |
+
url := c.Request.URL.Path
|
| 73 |
+
if maskedQuery != "" {
|
| 74 |
+
url += "?" + maskedQuery
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// Capture method
|
| 78 |
+
method := c.Request.Method
|
| 79 |
+
|
| 80 |
+
// Capture headers
|
| 81 |
+
headers := make(map[string][]string)
|
| 82 |
+
for key, values := range c.Request.Header {
|
| 83 |
+
headers[key] = values
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Capture request body
|
| 87 |
+
var body []byte
|
| 88 |
+
if c.Request.Body != nil {
|
| 89 |
+
// Read the body
|
| 90 |
+
bodyBytes, err := io.ReadAll(c.Request.Body)
|
| 91 |
+
if err != nil {
|
| 92 |
+
return nil, err
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// Restore the body for the actual request processing
|
| 96 |
+
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
| 97 |
+
body = bodyBytes
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
return &RequestInfo{
|
| 101 |
+
URL: url,
|
| 102 |
+
Method: method,
|
| 103 |
+
Headers: headers,
|
| 104 |
+
Body: body,
|
| 105 |
+
RequestID: logging.GetGinRequestID(c),
|
| 106 |
+
}, nil
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// shouldLogRequest determines whether the request should be logged.
|
| 110 |
+
// It skips management endpoints to avoid leaking secrets but allows
|
| 111 |
+
// all other routes, including module-provided ones, to honor request-log.
|
| 112 |
+
func shouldLogRequest(path string) bool {
|
| 113 |
+
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
| 114 |
+
return false
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
if strings.HasPrefix(path, "/api") {
|
| 118 |
+
return strings.HasPrefix(path, "/api/provider")
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
return true
|
| 122 |
+
}
|
internal/api/middleware/response_writer.go
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package middleware provides Gin HTTP middleware for the CLI Proxy API server.
|
| 2 |
+
// It includes a sophisticated response writer wrapper designed to capture and log request and response data,
|
| 3 |
+
// including support for streaming responses, without impacting latency.
|
| 4 |
+
package middleware
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"bytes"
|
| 8 |
+
"net/http"
|
| 9 |
+
"strings"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
| 17 |
+
type RequestInfo struct {
|
| 18 |
+
URL string // URL is the request URL.
|
| 19 |
+
Method string // Method is the HTTP method (e.g., GET, POST).
|
| 20 |
+
Headers map[string][]string // Headers contains the request headers.
|
| 21 |
+
Body []byte // Body is the raw request body.
|
| 22 |
+
RequestID string // RequestID is the unique identifier for the request.
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
| 26 |
+
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
| 27 |
+
type ResponseWriterWrapper struct {
|
| 28 |
+
gin.ResponseWriter
|
| 29 |
+
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
| 30 |
+
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
| 31 |
+
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
| 32 |
+
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
| 33 |
+
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
| 34 |
+
logger logging.RequestLogger // logger is the instance of the request logger service.
|
| 35 |
+
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
| 36 |
+
statusCode int // statusCode stores the HTTP status code of the response.
|
| 37 |
+
headers map[string][]string // headers stores the response headers.
|
| 38 |
+
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
| 42 |
+
// It takes the original gin.ResponseWriter, a logger instance, and request information.
|
| 43 |
+
//
|
| 44 |
+
// Parameters:
|
| 45 |
+
// - w: The original gin.ResponseWriter to wrap.
|
| 46 |
+
// - logger: The logging service to use for recording requests.
|
| 47 |
+
// - requestInfo: The pre-captured information about the incoming request.
|
| 48 |
+
//
|
| 49 |
+
// Returns:
|
| 50 |
+
// - A pointer to a new ResponseWriterWrapper.
|
| 51 |
+
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
|
| 52 |
+
return &ResponseWriterWrapper{
|
| 53 |
+
ResponseWriter: w,
|
| 54 |
+
body: &bytes.Buffer{},
|
| 55 |
+
logger: logger,
|
| 56 |
+
requestInfo: requestInfo,
|
| 57 |
+
headers: make(map[string][]string),
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// Write wraps the underlying ResponseWriter's Write method to capture response data.
|
| 62 |
+
// For non-streaming responses, it writes to an internal buffer. For streaming responses,
|
| 63 |
+
// it sends data chunks to a non-blocking channel for asynchronous logging.
|
| 64 |
+
// CRITICAL: This method prioritizes writing to the client to ensure zero latency,
|
| 65 |
+
// handling logging operations subsequently.
|
| 66 |
+
func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
| 67 |
+
// Ensure headers are captured before first write
|
| 68 |
+
// This is critical because Write() may trigger WriteHeader() internally
|
| 69 |
+
w.ensureHeadersCaptured()
|
| 70 |
+
|
| 71 |
+
// CRITICAL: Write to client first (zero latency)
|
| 72 |
+
n, err := w.ResponseWriter.Write(data)
|
| 73 |
+
|
| 74 |
+
// THEN: Handle logging based on response type
|
| 75 |
+
if w.isStreaming && w.chunkChannel != nil {
|
| 76 |
+
// For streaming responses: Send to async logging channel (non-blocking)
|
| 77 |
+
select {
|
| 78 |
+
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
| 79 |
+
default: // Channel full, skip logging to avoid blocking
|
| 80 |
+
}
|
| 81 |
+
return n, err
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
if w.shouldBufferResponseBody() {
|
| 85 |
+
w.body.Write(data)
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
return n, err
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
| 92 |
+
if w.logger != nil && w.logger.IsEnabled() {
|
| 93 |
+
return true
|
| 94 |
+
}
|
| 95 |
+
if !w.logOnErrorOnly {
|
| 96 |
+
return false
|
| 97 |
+
}
|
| 98 |
+
status := w.statusCode
|
| 99 |
+
if status == 0 {
|
| 100 |
+
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
| 101 |
+
status = statusWriter.Status()
|
| 102 |
+
} else {
|
| 103 |
+
status = http.StatusOK
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
return status >= http.StatusBadRequest
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
| 110 |
+
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
| 111 |
+
// bypass Write() and would be missing from request logs.
|
| 112 |
+
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
| 113 |
+
w.ensureHeadersCaptured()
|
| 114 |
+
|
| 115 |
+
// CRITICAL: Write to client first (zero latency)
|
| 116 |
+
n, err := w.ResponseWriter.WriteString(data)
|
| 117 |
+
|
| 118 |
+
// THEN: Capture for logging
|
| 119 |
+
if w.isStreaming && w.chunkChannel != nil {
|
| 120 |
+
select {
|
| 121 |
+
case w.chunkChannel <- []byte(data):
|
| 122 |
+
default:
|
| 123 |
+
}
|
| 124 |
+
return n, err
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
if w.shouldBufferResponseBody() {
|
| 128 |
+
w.body.WriteString(data)
|
| 129 |
+
}
|
| 130 |
+
return n, err
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
| 134 |
+
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
| 135 |
+
// and initializes the appropriate logging mechanism (standard or streaming).
|
| 136 |
+
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
| 137 |
+
w.statusCode = statusCode
|
| 138 |
+
|
| 139 |
+
// Capture response headers using the new method
|
| 140 |
+
w.captureCurrentHeaders()
|
| 141 |
+
|
| 142 |
+
// Detect streaming based on Content-Type
|
| 143 |
+
contentType := w.ResponseWriter.Header().Get("Content-Type")
|
| 144 |
+
w.isStreaming = w.detectStreaming(contentType)
|
| 145 |
+
|
| 146 |
+
// If streaming, initialize streaming log writer
|
| 147 |
+
if w.isStreaming && w.logger.IsEnabled() {
|
| 148 |
+
streamWriter, err := w.logger.LogStreamingRequest(
|
| 149 |
+
w.requestInfo.URL,
|
| 150 |
+
w.requestInfo.Method,
|
| 151 |
+
w.requestInfo.Headers,
|
| 152 |
+
w.requestInfo.Body,
|
| 153 |
+
w.requestInfo.RequestID,
|
| 154 |
+
)
|
| 155 |
+
if err == nil {
|
| 156 |
+
w.streamWriter = streamWriter
|
| 157 |
+
w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes
|
| 158 |
+
doneChan := make(chan struct{})
|
| 159 |
+
w.streamDone = doneChan
|
| 160 |
+
|
| 161 |
+
// Start async chunk processor
|
| 162 |
+
go w.processStreamingChunks(doneChan)
|
| 163 |
+
|
| 164 |
+
// Write status immediately
|
| 165 |
+
_ = streamWriter.WriteStatus(statusCode, w.headers)
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// Call original WriteHeader
|
| 170 |
+
w.ResponseWriter.WriteHeader(statusCode)
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// ensureHeadersCaptured is a helper function to make sure response headers are captured.
|
| 174 |
+
// It is safe to call this method multiple times; it will always refresh the headers
|
| 175 |
+
// with the latest state from the underlying ResponseWriter.
|
| 176 |
+
func (w *ResponseWriterWrapper) ensureHeadersCaptured() {
|
| 177 |
+
// Always capture the current headers to ensure we have the latest state
|
| 178 |
+
w.captureCurrentHeaders()
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
// captureCurrentHeaders reads all headers from the underlying ResponseWriter and stores them
|
| 182 |
+
// in the wrapper's headers map. It creates copies of the header values to prevent race conditions.
|
| 183 |
+
func (w *ResponseWriterWrapper) captureCurrentHeaders() {
|
| 184 |
+
// Initialize headers map if needed
|
| 185 |
+
if w.headers == nil {
|
| 186 |
+
w.headers = make(map[string][]string)
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// Capture all current headers from the underlying ResponseWriter
|
| 190 |
+
for key, values := range w.ResponseWriter.Header() {
|
| 191 |
+
// Make a copy of the values slice to avoid reference issues
|
| 192 |
+
headerValues := make([]string, len(values))
|
| 193 |
+
copy(headerValues, values)
|
| 194 |
+
w.headers[key] = headerValues
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// detectStreaming determines if a response should be treated as a streaming response.
|
| 199 |
+
// It checks for a "text/event-stream" Content-Type or a '"stream": true'
|
| 200 |
+
// field in the original request body.
|
| 201 |
+
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
| 202 |
+
// Check Content-Type for Server-Sent Events
|
| 203 |
+
if strings.Contains(contentType, "text/event-stream") {
|
| 204 |
+
return true
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
| 208 |
+
// treat it as non-streaming instead of inferring from the request payload.
|
| 209 |
+
if strings.TrimSpace(contentType) != "" {
|
| 210 |
+
return false
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
// Only fall back to request payload hints when Content-Type is not set yet.
|
| 214 |
+
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
| 215 |
+
bodyStr := string(w.requestInfo.Body)
|
| 216 |
+
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
return false
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
// processStreamingChunks runs in a separate goroutine to process response chunks from the chunkChannel.
|
| 223 |
+
// It asynchronously writes each chunk to the streaming log writer.
|
| 224 |
+
func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
|
| 225 |
+
if done == nil {
|
| 226 |
+
return
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
defer close(done)
|
| 230 |
+
|
| 231 |
+
if w.streamWriter == nil || w.chunkChannel == nil {
|
| 232 |
+
return
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
for chunk := range w.chunkChannel {
|
| 236 |
+
w.streamWriter.WriteChunkAsync(chunk)
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
// Finalize completes the logging process for the request and response.
|
| 241 |
+
// For streaming responses, it closes the chunk channel and the stream writer.
|
| 242 |
+
// For non-streaming responses, it logs the complete request and response details,
|
| 243 |
+
// including any API-specific request/response data stored in the Gin context.
|
| 244 |
+
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
| 245 |
+
if w.logger == nil {
|
| 246 |
+
return nil
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
finalStatusCode := w.statusCode
|
| 250 |
+
if finalStatusCode == 0 {
|
| 251 |
+
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
| 252 |
+
finalStatusCode = statusWriter.Status()
|
| 253 |
+
} else {
|
| 254 |
+
finalStatusCode = 200
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
var slicesAPIResponseError []*interfaces.ErrorMessage
|
| 259 |
+
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
| 260 |
+
if isExist {
|
| 261 |
+
if apiErrors, ok := apiResponseError.([]*interfaces.ErrorMessage); ok {
|
| 262 |
+
slicesAPIResponseError = apiErrors
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest
|
| 267 |
+
forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled()
|
| 268 |
+
if !w.logger.IsEnabled() && !forceLog {
|
| 269 |
+
return nil
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
if w.isStreaming && w.streamWriter != nil {
|
| 273 |
+
if w.chunkChannel != nil {
|
| 274 |
+
close(w.chunkChannel)
|
| 275 |
+
w.chunkChannel = nil
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
if w.streamDone != nil {
|
| 279 |
+
<-w.streamDone
|
| 280 |
+
w.streamDone = nil
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// Write API Request and Response to the streaming log before closing
|
| 284 |
+
apiRequest := w.extractAPIRequest(c)
|
| 285 |
+
if len(apiRequest) > 0 {
|
| 286 |
+
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
| 287 |
+
}
|
| 288 |
+
apiResponse := w.extractAPIResponse(c)
|
| 289 |
+
if len(apiResponse) > 0 {
|
| 290 |
+
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
| 291 |
+
}
|
| 292 |
+
if err := w.streamWriter.Close(); err != nil {
|
| 293 |
+
w.streamWriter = nil
|
| 294 |
+
return err
|
| 295 |
+
}
|
| 296 |
+
w.streamWriter = nil
|
| 297 |
+
return nil
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
| 304 |
+
w.ensureHeadersCaptured()
|
| 305 |
+
|
| 306 |
+
finalHeaders := make(map[string][]string, len(w.headers))
|
| 307 |
+
for key, values := range w.headers {
|
| 308 |
+
headerValues := make([]string, len(values))
|
| 309 |
+
copy(headerValues, values)
|
| 310 |
+
finalHeaders[key] = headerValues
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
return finalHeaders
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
|
| 317 |
+
apiRequest, isExist := c.Get("API_REQUEST")
|
| 318 |
+
if !isExist {
|
| 319 |
+
return nil
|
| 320 |
+
}
|
| 321 |
+
data, ok := apiRequest.([]byte)
|
| 322 |
+
if !ok || len(data) == 0 {
|
| 323 |
+
return nil
|
| 324 |
+
}
|
| 325 |
+
return data
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
| 329 |
+
apiResponse, isExist := c.Get("API_RESPONSE")
|
| 330 |
+
if !isExist {
|
| 331 |
+
return nil
|
| 332 |
+
}
|
| 333 |
+
data, ok := apiResponse.([]byte)
|
| 334 |
+
if !ok || len(data) == 0 {
|
| 335 |
+
return nil
|
| 336 |
+
}
|
| 337 |
+
return data
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
| 341 |
+
if w.requestInfo == nil {
|
| 342 |
+
return nil
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
var requestBody []byte
|
| 346 |
+
if len(w.requestInfo.Body) > 0 {
|
| 347 |
+
requestBody = w.requestInfo.Body
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
if loggerWithOptions, ok := w.logger.(interface {
|
| 351 |
+
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
| 352 |
+
}); ok {
|
| 353 |
+
return loggerWithOptions.LogRequestWithOptions(
|
| 354 |
+
w.requestInfo.URL,
|
| 355 |
+
w.requestInfo.Method,
|
| 356 |
+
w.requestInfo.Headers,
|
| 357 |
+
requestBody,
|
| 358 |
+
statusCode,
|
| 359 |
+
headers,
|
| 360 |
+
body,
|
| 361 |
+
apiRequestBody,
|
| 362 |
+
apiResponseBody,
|
| 363 |
+
apiResponseErrors,
|
| 364 |
+
forceLog,
|
| 365 |
+
w.requestInfo.RequestID,
|
| 366 |
+
)
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
return w.logger.LogRequest(
|
| 370 |
+
w.requestInfo.URL,
|
| 371 |
+
w.requestInfo.Method,
|
| 372 |
+
w.requestInfo.Headers,
|
| 373 |
+
requestBody,
|
| 374 |
+
statusCode,
|
| 375 |
+
headers,
|
| 376 |
+
body,
|
| 377 |
+
apiRequestBody,
|
| 378 |
+
apiResponseBody,
|
| 379 |
+
apiResponseErrors,
|
| 380 |
+
w.requestInfo.RequestID,
|
| 381 |
+
)
|
| 382 |
+
}
|
internal/api/modules/amp/amp.go
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package amp implements the Amp CLI routing module, providing OAuth-based
|
| 2 |
+
// integration with Amp CLI for ChatGPT and Anthropic subscriptions.
|
| 3 |
+
package amp
|
| 4 |
+
|
| 5 |
+
import (
|
| 6 |
+
"fmt"
|
| 7 |
+
"net/http/httputil"
|
| 8 |
+
"strings"
|
| 9 |
+
"sync"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 14 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 15 |
+
log "github.com/sirupsen/logrus"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// Option configures the AmpModule.
|
| 19 |
+
type Option func(*AmpModule)
|
| 20 |
+
|
| 21 |
+
// AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
|
| 22 |
+
// It provides:
|
| 23 |
+
// - Reverse proxy to Amp control plane for OAuth/management
|
| 24 |
+
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
| 25 |
+
// - Automatic gzip decompression for misconfigured upstreams
|
| 26 |
+
// - Model mapping for routing unavailable models to alternatives
|
| 27 |
+
type AmpModule struct {
|
| 28 |
+
secretSource SecretSource
|
| 29 |
+
proxy *httputil.ReverseProxy
|
| 30 |
+
proxyMu sync.RWMutex // protects proxy for hot-reload
|
| 31 |
+
accessManager *sdkaccess.Manager
|
| 32 |
+
authMiddleware_ gin.HandlerFunc
|
| 33 |
+
modelMapper *DefaultModelMapper
|
| 34 |
+
enabled bool
|
| 35 |
+
registerOnce sync.Once
|
| 36 |
+
|
| 37 |
+
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
|
| 38 |
+
restrictToLocalhost bool
|
| 39 |
+
restrictMu sync.RWMutex
|
| 40 |
+
|
| 41 |
+
// configMu protects lastConfig for partial reload comparison
|
| 42 |
+
configMu sync.RWMutex
|
| 43 |
+
lastConfig *config.AmpCode
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// New creates a new Amp routing module with the given options.
|
| 47 |
+
// This is the preferred constructor using the Option pattern.
|
| 48 |
+
//
|
| 49 |
+
// Example:
|
| 50 |
+
//
|
| 51 |
+
// ampModule := amp.New(
|
| 52 |
+
// amp.WithAccessManager(accessManager),
|
| 53 |
+
// amp.WithAuthMiddleware(authMiddleware),
|
| 54 |
+
// amp.WithSecretSource(customSecret),
|
| 55 |
+
// )
|
| 56 |
+
func New(opts ...Option) *AmpModule {
|
| 57 |
+
m := &AmpModule{
|
| 58 |
+
secretSource: nil, // Will be created on demand if not provided
|
| 59 |
+
}
|
| 60 |
+
for _, opt := range opts {
|
| 61 |
+
opt(m)
|
| 62 |
+
}
|
| 63 |
+
return m
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// NewLegacy creates a new Amp routing module using the legacy constructor signature.
|
| 67 |
+
// This is provided for backwards compatibility.
|
| 68 |
+
//
|
| 69 |
+
// DEPRECATED: Use New with options instead.
|
| 70 |
+
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
|
| 71 |
+
return New(
|
| 72 |
+
WithAccessManager(accessManager),
|
| 73 |
+
WithAuthMiddleware(authMiddleware),
|
| 74 |
+
)
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// WithSecretSource sets a custom secret source for the module.
|
| 78 |
+
func WithSecretSource(source SecretSource) Option {
|
| 79 |
+
return func(m *AmpModule) {
|
| 80 |
+
m.secretSource = source
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// WithAccessManager sets the access manager for the module.
|
| 85 |
+
func WithAccessManager(am *sdkaccess.Manager) Option {
|
| 86 |
+
return func(m *AmpModule) {
|
| 87 |
+
m.accessManager = am
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// WithAuthMiddleware sets the authentication middleware for provider routes.
|
| 92 |
+
func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
|
| 93 |
+
return func(m *AmpModule) {
|
| 94 |
+
m.authMiddleware_ = middleware
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
// Name returns the module identifier
|
| 99 |
+
func (m *AmpModule) Name() string {
|
| 100 |
+
return "amp-routing"
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// forceModelMappings returns whether model mappings should take precedence over local API keys
|
| 104 |
+
func (m *AmpModule) forceModelMappings() bool {
|
| 105 |
+
m.configMu.RLock()
|
| 106 |
+
defer m.configMu.RUnlock()
|
| 107 |
+
if m.lastConfig == nil {
|
| 108 |
+
return false
|
| 109 |
+
}
|
| 110 |
+
return m.lastConfig.ForceModelMappings
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Register sets up Amp routes if configured.
|
| 114 |
+
// This implements the RouteModuleV2 interface with Context.
|
| 115 |
+
// Routes are registered only once via sync.Once for idempotent behavior.
|
| 116 |
+
func (m *AmpModule) Register(ctx modules.Context) error {
|
| 117 |
+
settings := ctx.Config.AmpCode
|
| 118 |
+
upstreamURL := strings.TrimSpace(settings.UpstreamURL)
|
| 119 |
+
|
| 120 |
+
// Determine auth middleware (from module or context)
|
| 121 |
+
auth := m.getAuthMiddleware(ctx)
|
| 122 |
+
|
| 123 |
+
// Use registerOnce to ensure routes are only registered once
|
| 124 |
+
var regErr error
|
| 125 |
+
m.registerOnce.Do(func() {
|
| 126 |
+
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
| 127 |
+
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
| 128 |
+
|
| 129 |
+
// Store initial config for partial reload comparison
|
| 130 |
+
settingsCopy := settings
|
| 131 |
+
m.lastConfig = &settingsCopy
|
| 132 |
+
|
| 133 |
+
// Initialize localhost restriction setting (hot-reloadable)
|
| 134 |
+
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
| 135 |
+
|
| 136 |
+
// Always register provider aliases - these work without an upstream
|
| 137 |
+
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
| 138 |
+
|
| 139 |
+
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
| 140 |
+
// Pass auth middleware to require valid API key for all management routes.
|
| 141 |
+
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
| 142 |
+
|
| 143 |
+
// If no upstream URL, skip proxy routes but provider aliases are still available
|
| 144 |
+
if upstreamURL == "" {
|
| 145 |
+
log.Debug("amp upstream proxy disabled (no upstream URL configured)")
|
| 146 |
+
log.Debug("amp provider alias routes registered")
|
| 147 |
+
m.enabled = false
|
| 148 |
+
return
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
|
| 152 |
+
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
| 153 |
+
return
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
log.Debug("amp provider alias routes registered")
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
return regErr
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// getAuthMiddleware returns the authentication middleware, preferring the
|
| 163 |
+
// module's configured middleware, then the context middleware, then a fallback.
|
| 164 |
+
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
| 165 |
+
if m.authMiddleware_ != nil {
|
| 166 |
+
return m.authMiddleware_
|
| 167 |
+
}
|
| 168 |
+
if ctx.AuthMiddleware != nil {
|
| 169 |
+
return ctx.AuthMiddleware
|
| 170 |
+
}
|
| 171 |
+
// Fallback: no authentication (should not happen in production)
|
| 172 |
+
log.Warn("amp module: no auth middleware provided, allowing all requests")
|
| 173 |
+
return func(c *gin.Context) {
|
| 174 |
+
c.Next()
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
// OnConfigUpdated handles configuration updates with partial reload support.
|
| 179 |
+
// Only updates components that have actually changed to avoid unnecessary work.
|
| 180 |
+
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
|
| 181 |
+
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
| 182 |
+
newSettings := cfg.AmpCode
|
| 183 |
+
|
| 184 |
+
// Get previous config for comparison
|
| 185 |
+
m.configMu.RLock()
|
| 186 |
+
oldSettings := m.lastConfig
|
| 187 |
+
m.configMu.RUnlock()
|
| 188 |
+
|
| 189 |
+
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
| 190 |
+
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
| 194 |
+
oldUpstreamURL := ""
|
| 195 |
+
if oldSettings != nil {
|
| 196 |
+
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
if !m.enabled && newUpstreamURL != "" {
|
| 200 |
+
if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
|
| 201 |
+
log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Check model mappings change
|
| 206 |
+
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
|
| 207 |
+
if modelMappingsChanged {
|
| 208 |
+
if m.modelMapper != nil {
|
| 209 |
+
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
|
| 210 |
+
} else if m.enabled {
|
| 211 |
+
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
if m.enabled {
|
| 216 |
+
// Check upstream URL change - now supports hot-reload
|
| 217 |
+
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
| 218 |
+
m.setProxy(nil)
|
| 219 |
+
m.enabled = false
|
| 220 |
+
} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
| 221 |
+
// Recreate proxy with new URL
|
| 222 |
+
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
| 223 |
+
if err != nil {
|
| 224 |
+
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
| 225 |
+
} else {
|
| 226 |
+
m.setProxy(proxy)
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
// Check API key change (both default and per-client mappings)
|
| 231 |
+
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
| 232 |
+
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
| 233 |
+
if apiKeyChanged || upstreamAPIKeysChanged {
|
| 234 |
+
if m.secretSource != nil {
|
| 235 |
+
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
| 236 |
+
if apiKeyChanged {
|
| 237 |
+
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
| 238 |
+
ms.InvalidateCache()
|
| 239 |
+
}
|
| 240 |
+
if upstreamAPIKeysChanged {
|
| 241 |
+
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
| 242 |
+
}
|
| 243 |
+
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
| 244 |
+
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
| 245 |
+
ms.InvalidateCache()
|
| 246 |
+
}
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
// Store current config for next comparison
|
| 253 |
+
m.configMu.Lock()
|
| 254 |
+
settingsCopy := newSettings // copy struct
|
| 255 |
+
m.lastConfig = &settingsCopy
|
| 256 |
+
m.configMu.Unlock()
|
| 257 |
+
|
| 258 |
+
return nil
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
| 262 |
+
if m.secretSource == nil {
|
| 263 |
+
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
| 264 |
+
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
| 265 |
+
mappedSource := NewMappedSecretSource(defaultSource)
|
| 266 |
+
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
| 267 |
+
m.secretSource = mappedSource
|
| 268 |
+
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
| 269 |
+
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
| 270 |
+
ms.InvalidateCache()
|
| 271 |
+
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
| 272 |
+
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
| 273 |
+
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
| 274 |
+
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
| 275 |
+
ms.InvalidateCache()
|
| 276 |
+
mappedSource := NewMappedSecretSource(ms)
|
| 277 |
+
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
| 278 |
+
m.secretSource = mappedSource
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
| 282 |
+
if err != nil {
|
| 283 |
+
return err
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
m.setProxy(proxy)
|
| 287 |
+
m.enabled = true
|
| 288 |
+
|
| 289 |
+
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
| 290 |
+
return nil
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
// hasModelMappingsChanged compares old and new model mappings.
|
| 294 |
+
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
| 295 |
+
if old == nil {
|
| 296 |
+
return len(new.ModelMappings) > 0
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
if len(old.ModelMappings) != len(new.ModelMappings) {
|
| 300 |
+
return true
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// Build map for efficient and robust comparison
|
| 304 |
+
type mappingInfo struct {
|
| 305 |
+
to string
|
| 306 |
+
regex bool
|
| 307 |
+
}
|
| 308 |
+
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
| 309 |
+
for _, mapping := range old.ModelMappings {
|
| 310 |
+
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
| 311 |
+
to: strings.TrimSpace(mapping.To),
|
| 312 |
+
regex: mapping.Regex,
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
for _, mapping := range new.ModelMappings {
|
| 317 |
+
from := strings.TrimSpace(mapping.From)
|
| 318 |
+
to := strings.TrimSpace(mapping.To)
|
| 319 |
+
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
| 320 |
+
return true
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
return false
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
// hasAPIKeyChanged compares old and new API keys.
|
| 328 |
+
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
| 329 |
+
oldKey := ""
|
| 330 |
+
if old != nil {
|
| 331 |
+
oldKey = strings.TrimSpace(old.UpstreamAPIKey)
|
| 332 |
+
}
|
| 333 |
+
newKey := strings.TrimSpace(new.UpstreamAPIKey)
|
| 334 |
+
return oldKey != newKey
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
| 338 |
+
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
| 339 |
+
if old == nil {
|
| 340 |
+
return len(new.UpstreamAPIKeys) > 0
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
| 344 |
+
return true
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// Build map for comparison: upstreamKey -> set of clientKeys
|
| 348 |
+
type entryInfo struct {
|
| 349 |
+
upstreamKey string
|
| 350 |
+
clientKeys map[string]struct{}
|
| 351 |
+
}
|
| 352 |
+
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
| 353 |
+
for i, entry := range old.UpstreamAPIKeys {
|
| 354 |
+
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
| 355 |
+
for _, k := range entry.APIKeys {
|
| 356 |
+
trimmed := strings.TrimSpace(k)
|
| 357 |
+
if trimmed == "" {
|
| 358 |
+
continue
|
| 359 |
+
}
|
| 360 |
+
clientKeys[trimmed] = struct{}{}
|
| 361 |
+
}
|
| 362 |
+
oldEntries[i] = entryInfo{
|
| 363 |
+
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
| 364 |
+
clientKeys: clientKeys,
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
for i, newEntry := range new.UpstreamAPIKeys {
|
| 369 |
+
if i >= len(oldEntries) {
|
| 370 |
+
return true
|
| 371 |
+
}
|
| 372 |
+
oldE := oldEntries[i]
|
| 373 |
+
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
| 374 |
+
return true
|
| 375 |
+
}
|
| 376 |
+
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
| 377 |
+
for _, k := range newEntry.APIKeys {
|
| 378 |
+
trimmed := strings.TrimSpace(k)
|
| 379 |
+
if trimmed == "" {
|
| 380 |
+
continue
|
| 381 |
+
}
|
| 382 |
+
newKeys[trimmed] = struct{}{}
|
| 383 |
+
}
|
| 384 |
+
if len(newKeys) != len(oldE.clientKeys) {
|
| 385 |
+
return true
|
| 386 |
+
}
|
| 387 |
+
for k := range newKeys {
|
| 388 |
+
if _, ok := oldE.clientKeys[k]; !ok {
|
| 389 |
+
return true
|
| 390 |
+
}
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
return false
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
| 398 |
+
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
| 399 |
+
return m.modelMapper
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
// getProxy returns the current proxy instance (thread-safe for hot-reload).
|
| 403 |
+
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
|
| 404 |
+
m.proxyMu.RLock()
|
| 405 |
+
defer m.proxyMu.RUnlock()
|
| 406 |
+
return m.proxy
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
// setProxy updates the proxy instance (thread-safe for hot-reload).
|
| 410 |
+
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
|
| 411 |
+
m.proxyMu.Lock()
|
| 412 |
+
defer m.proxyMu.Unlock()
|
| 413 |
+
m.proxy = proxy
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
|
| 417 |
+
func (m *AmpModule) IsRestrictedToLocalhost() bool {
|
| 418 |
+
m.restrictMu.RLock()
|
| 419 |
+
defer m.restrictMu.RUnlock()
|
| 420 |
+
return m.restrictToLocalhost
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
// setRestrictToLocalhost updates the localhost restriction setting.
|
| 424 |
+
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
|
| 425 |
+
m.restrictMu.Lock()
|
| 426 |
+
defer m.restrictMu.Unlock()
|
| 427 |
+
m.restrictToLocalhost = restrict
|
| 428 |
+
}
|
internal/api/modules/amp/amp_test.go
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"net/http/httptest"
|
| 6 |
+
"os"
|
| 7 |
+
"path/filepath"
|
| 8 |
+
"testing"
|
| 9 |
+
"time"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 14 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 15 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
func TestAmpModule_Name(t *testing.T) {
|
| 19 |
+
m := New()
|
| 20 |
+
if m.Name() != "amp-routing" {
|
| 21 |
+
t.Fatalf("want amp-routing, got %s", m.Name())
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
func TestAmpModule_New(t *testing.T) {
|
| 26 |
+
accessManager := sdkaccess.NewManager()
|
| 27 |
+
authMiddleware := func(c *gin.Context) { c.Next() }
|
| 28 |
+
|
| 29 |
+
m := NewLegacy(accessManager, authMiddleware)
|
| 30 |
+
|
| 31 |
+
if m.accessManager != accessManager {
|
| 32 |
+
t.Fatal("accessManager not set")
|
| 33 |
+
}
|
| 34 |
+
if m.authMiddleware_ == nil {
|
| 35 |
+
t.Fatal("authMiddleware not set")
|
| 36 |
+
}
|
| 37 |
+
if m.enabled {
|
| 38 |
+
t.Fatal("enabled should be false initially")
|
| 39 |
+
}
|
| 40 |
+
if m.proxy != nil {
|
| 41 |
+
t.Fatal("proxy should be nil initially")
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
func TestAmpModule_Register_WithUpstream(t *testing.T) {
|
| 46 |
+
gin.SetMode(gin.TestMode)
|
| 47 |
+
r := gin.New()
|
| 48 |
+
|
| 49 |
+
// Fake upstream to ensure URL is valid
|
| 50 |
+
upstream := httptest.NewServer(nil)
|
| 51 |
+
defer upstream.Close()
|
| 52 |
+
|
| 53 |
+
accessManager := sdkaccess.NewManager()
|
| 54 |
+
base := &handlers.BaseAPIHandler{}
|
| 55 |
+
|
| 56 |
+
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
| 57 |
+
|
| 58 |
+
cfg := &config.Config{
|
| 59 |
+
AmpCode: config.AmpCode{
|
| 60 |
+
UpstreamURL: upstream.URL,
|
| 61 |
+
UpstreamAPIKey: "test-key",
|
| 62 |
+
},
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
| 66 |
+
if err := m.Register(ctx); err != nil {
|
| 67 |
+
t.Fatalf("register error: %v", err)
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
if !m.enabled {
|
| 71 |
+
t.Fatal("module should be enabled with upstream URL")
|
| 72 |
+
}
|
| 73 |
+
if m.proxy == nil {
|
| 74 |
+
t.Fatal("proxy should be initialized")
|
| 75 |
+
}
|
| 76 |
+
if m.secretSource == nil {
|
| 77 |
+
t.Fatal("secretSource should be initialized")
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
func TestAmpModule_Register_WithoutUpstream(t *testing.T) {
|
| 82 |
+
gin.SetMode(gin.TestMode)
|
| 83 |
+
r := gin.New()
|
| 84 |
+
|
| 85 |
+
accessManager := sdkaccess.NewManager()
|
| 86 |
+
base := &handlers.BaseAPIHandler{}
|
| 87 |
+
|
| 88 |
+
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
| 89 |
+
|
| 90 |
+
cfg := &config.Config{
|
| 91 |
+
AmpCode: config.AmpCode{
|
| 92 |
+
UpstreamURL: "", // No upstream
|
| 93 |
+
},
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
| 97 |
+
if err := m.Register(ctx); err != nil {
|
| 98 |
+
t.Fatalf("register should not error without upstream: %v", err)
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
if m.enabled {
|
| 102 |
+
t.Fatal("module should be disabled without upstream URL")
|
| 103 |
+
}
|
| 104 |
+
if m.proxy != nil {
|
| 105 |
+
t.Fatal("proxy should not be initialized without upstream")
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// But provider aliases should still be registered
|
| 109 |
+
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
| 110 |
+
w := httptest.NewRecorder()
|
| 111 |
+
r.ServeHTTP(w, req)
|
| 112 |
+
|
| 113 |
+
if w.Code == 404 {
|
| 114 |
+
t.Fatal("provider aliases should be registered even without upstream")
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
func TestAmpModule_Register_InvalidUpstream(t *testing.T) {
|
| 119 |
+
gin.SetMode(gin.TestMode)
|
| 120 |
+
r := gin.New()
|
| 121 |
+
|
| 122 |
+
accessManager := sdkaccess.NewManager()
|
| 123 |
+
base := &handlers.BaseAPIHandler{}
|
| 124 |
+
|
| 125 |
+
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
| 126 |
+
|
| 127 |
+
cfg := &config.Config{
|
| 128 |
+
AmpCode: config.AmpCode{
|
| 129 |
+
UpstreamURL: "://invalid-url",
|
| 130 |
+
},
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
| 134 |
+
if err := m.Register(ctx); err == nil {
|
| 135 |
+
t.Fatal("expected error for invalid upstream URL")
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
| 140 |
+
tmpDir := t.TempDir()
|
| 141 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 142 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
| 143 |
+
t.Fatal(err)
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
m := &AmpModule{enabled: true}
|
| 147 |
+
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
| 148 |
+
m.secretSource = ms
|
| 149 |
+
m.lastConfig = &config.AmpCode{
|
| 150 |
+
UpstreamAPIKey: "old-key",
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// Warm the cache
|
| 154 |
+
if _, err := ms.Get(context.Background()); err != nil {
|
| 155 |
+
t.Fatal(err)
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
if ms.cache == nil {
|
| 159 |
+
t.Fatal("expected cache to be set")
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Update config - should invalidate cache
|
| 163 |
+
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
| 164 |
+
t.Fatal(err)
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
if ms.cache != nil {
|
| 168 |
+
t.Fatal("expected cache to be invalidated")
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) {
|
| 173 |
+
m := &AmpModule{enabled: false}
|
| 174 |
+
|
| 175 |
+
// Should not error or panic when disabled
|
| 176 |
+
if err := m.OnConfigUpdated(&config.Config{}); err != nil {
|
| 177 |
+
t.Fatalf("unexpected error: %v", err)
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) {
|
| 182 |
+
m := &AmpModule{enabled: true}
|
| 183 |
+
ms := NewMultiSourceSecret("", 0)
|
| 184 |
+
m.secretSource = ms
|
| 185 |
+
|
| 186 |
+
// Config update with empty URL - should log warning but not error
|
| 187 |
+
cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: ""}}
|
| 188 |
+
|
| 189 |
+
if err := m.OnConfigUpdated(cfg); err != nil {
|
| 190 |
+
t.Fatalf("unexpected error: %v", err)
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) {
|
| 195 |
+
// Test that OnConfigUpdated doesn't panic with StaticSecretSource
|
| 196 |
+
m := &AmpModule{enabled: true}
|
| 197 |
+
m.secretSource = NewStaticSecretSource("static-key")
|
| 198 |
+
|
| 199 |
+
cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://example.com"}}
|
| 200 |
+
|
| 201 |
+
// Should not error or panic
|
| 202 |
+
if err := m.OnConfigUpdated(cfg); err != nil {
|
| 203 |
+
t.Fatalf("unexpected error: %v", err)
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) {
|
| 208 |
+
gin.SetMode(gin.TestMode)
|
| 209 |
+
r := gin.New()
|
| 210 |
+
|
| 211 |
+
// Create module with no auth middleware
|
| 212 |
+
m := &AmpModule{authMiddleware_: nil}
|
| 213 |
+
|
| 214 |
+
// Get the fallback middleware via getAuthMiddleware
|
| 215 |
+
ctx := modules.Context{Engine: r, AuthMiddleware: nil}
|
| 216 |
+
middleware := m.getAuthMiddleware(ctx)
|
| 217 |
+
|
| 218 |
+
if middleware == nil {
|
| 219 |
+
t.Fatal("getAuthMiddleware should return a fallback, not nil")
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
// Test that it works
|
| 223 |
+
called := false
|
| 224 |
+
r.GET("/test", middleware, func(c *gin.Context) {
|
| 225 |
+
called = true
|
| 226 |
+
c.String(200, "ok")
|
| 227 |
+
})
|
| 228 |
+
|
| 229 |
+
req := httptest.NewRequest("GET", "/test", nil)
|
| 230 |
+
w := httptest.NewRecorder()
|
| 231 |
+
r.ServeHTTP(w, req)
|
| 232 |
+
|
| 233 |
+
if !called {
|
| 234 |
+
t.Fatal("fallback middleware should allow requests through")
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
func TestAmpModule_SecretSource_FromConfig(t *testing.T) {
|
| 239 |
+
gin.SetMode(gin.TestMode)
|
| 240 |
+
r := gin.New()
|
| 241 |
+
|
| 242 |
+
upstream := httptest.NewServer(nil)
|
| 243 |
+
defer upstream.Close()
|
| 244 |
+
|
| 245 |
+
accessManager := sdkaccess.NewManager()
|
| 246 |
+
base := &handlers.BaseAPIHandler{}
|
| 247 |
+
|
| 248 |
+
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
| 249 |
+
|
| 250 |
+
// Config with explicit API key
|
| 251 |
+
cfg := &config.Config{
|
| 252 |
+
AmpCode: config.AmpCode{
|
| 253 |
+
UpstreamURL: upstream.URL,
|
| 254 |
+
UpstreamAPIKey: "config-key",
|
| 255 |
+
},
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
| 259 |
+
if err := m.Register(ctx); err != nil {
|
| 260 |
+
t.Fatalf("register error: %v", err)
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
// Secret source should be MultiSourceSecret with config key
|
| 264 |
+
if m.secretSource == nil {
|
| 265 |
+
t.Fatal("secretSource should be set")
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// Verify it returns the config key
|
| 269 |
+
key, err := m.secretSource.Get(context.Background())
|
| 270 |
+
if err != nil {
|
| 271 |
+
t.Fatalf("Get error: %v", err)
|
| 272 |
+
}
|
| 273 |
+
if key != "config-key" {
|
| 274 |
+
t.Fatalf("want config-key, got %s", key)
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
| 279 |
+
gin.SetMode(gin.TestMode)
|
| 280 |
+
|
| 281 |
+
scenarios := []struct {
|
| 282 |
+
name string
|
| 283 |
+
configURL string
|
| 284 |
+
}{
|
| 285 |
+
{"with_upstream", "http://example.com"},
|
| 286 |
+
{"without_upstream", ""},
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
for _, scenario := range scenarios {
|
| 290 |
+
t.Run(scenario.name, func(t *testing.T) {
|
| 291 |
+
r := gin.New()
|
| 292 |
+
accessManager := sdkaccess.NewManager()
|
| 293 |
+
base := &handlers.BaseAPIHandler{}
|
| 294 |
+
|
| 295 |
+
m := NewLegacy(accessManager, func(c *gin.Context) { c.Next() })
|
| 296 |
+
|
| 297 |
+
cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: scenario.configURL}}
|
| 298 |
+
|
| 299 |
+
ctx := modules.Context{Engine: r, BaseHandler: base, Config: cfg, AuthMiddleware: func(c *gin.Context) { c.Next() }}
|
| 300 |
+
if err := m.Register(ctx); err != nil && scenario.configURL != "" {
|
| 301 |
+
t.Fatalf("register error: %v", err)
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
// Provider aliases should always be available
|
| 305 |
+
req := httptest.NewRequest("GET", "/api/provider/openai/models", nil)
|
| 306 |
+
w := httptest.NewRecorder()
|
| 307 |
+
r.ServeHTTP(w, req)
|
| 308 |
+
|
| 309 |
+
if w.Code == 404 {
|
| 310 |
+
t.Fatal("provider aliases should be registered")
|
| 311 |
+
}
|
| 312 |
+
})
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
| 317 |
+
m := &AmpModule{}
|
| 318 |
+
|
| 319 |
+
oldCfg := &config.AmpCode{
|
| 320 |
+
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
| 321 |
+
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
| 322 |
+
},
|
| 323 |
+
}
|
| 324 |
+
newCfg := &config.AmpCode{
|
| 325 |
+
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
| 326 |
+
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
| 327 |
+
},
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
| 331 |
+
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
| 336 |
+
m := &AmpModule{}
|
| 337 |
+
|
| 338 |
+
oldCfg := &config.AmpCode{
|
| 339 |
+
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
| 340 |
+
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
| 341 |
+
},
|
| 342 |
+
}
|
| 343 |
+
newCfg := &config.AmpCode{
|
| 344 |
+
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
| 345 |
+
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
| 346 |
+
},
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
| 350 |
+
t.Fatal("expected no change when only whitespace/empty entries differ")
|
| 351 |
+
}
|
| 352 |
+
}
|
internal/api/modules/amp/fallback_handlers.go
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"io"
|
| 6 |
+
"net/http/httputil"
|
| 7 |
+
"strings"
|
| 8 |
+
"time"
|
| 9 |
+
|
| 10 |
+
"github.com/gin-gonic/gin"
|
| 11 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 12 |
+
log "github.com/sirupsen/logrus"
|
| 13 |
+
"github.com/tidwall/gjson"
|
| 14 |
+
"github.com/tidwall/sjson"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
// AmpRouteType represents the type of routing decision made for an Amp request
|
| 18 |
+
type AmpRouteType string
|
| 19 |
+
|
| 20 |
+
const (
|
| 21 |
+
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free)
|
| 22 |
+
RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
|
| 23 |
+
// RouteTypeModelMapping indicates the request was remapped to another available model (free)
|
| 24 |
+
RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
|
| 25 |
+
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits)
|
| 26 |
+
RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
|
| 27 |
+
// RouteTypeNoProvider indicates no provider or fallback available
|
| 28 |
+
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
| 32 |
+
const MappedModelContextKey = "mapped_model"
|
| 33 |
+
|
| 34 |
+
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
| 35 |
+
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
| 36 |
+
fields := log.Fields{
|
| 37 |
+
"component": "amp-routing",
|
| 38 |
+
"route_type": string(routeType),
|
| 39 |
+
"requested_model": requestedModel,
|
| 40 |
+
"path": path,
|
| 41 |
+
"timestamp": time.Now().Format(time.RFC3339),
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
if resolvedModel != "" && resolvedModel != requestedModel {
|
| 45 |
+
fields["resolved_model"] = resolvedModel
|
| 46 |
+
}
|
| 47 |
+
if provider != "" {
|
| 48 |
+
fields["provider"] = provider
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
switch routeType {
|
| 52 |
+
case RouteTypeLocalProvider:
|
| 53 |
+
fields["cost"] = "free"
|
| 54 |
+
fields["source"] = "local_oauth"
|
| 55 |
+
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
| 56 |
+
|
| 57 |
+
case RouteTypeModelMapping:
|
| 58 |
+
fields["cost"] = "free"
|
| 59 |
+
fields["source"] = "local_oauth"
|
| 60 |
+
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
| 61 |
+
// model mapping already logged in mapper; avoid duplicate here
|
| 62 |
+
|
| 63 |
+
case RouteTypeAmpCredits:
|
| 64 |
+
fields["cost"] = "amp_credits"
|
| 65 |
+
fields["source"] = "ampcode.com"
|
| 66 |
+
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
| 67 |
+
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
| 68 |
+
|
| 69 |
+
case RouteTypeNoProvider:
|
| 70 |
+
fields["cost"] = "none"
|
| 71 |
+
fields["source"] = "error"
|
| 72 |
+
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
| 73 |
+
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
| 78 |
+
// when the model's provider is not available in CLIProxyAPI
|
| 79 |
+
type FallbackHandler struct {
|
| 80 |
+
getProxy func() *httputil.ReverseProxy
|
| 81 |
+
modelMapper ModelMapper
|
| 82 |
+
forceModelMappings func() bool
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// NewFallbackHandler creates a new fallback handler wrapper
|
| 86 |
+
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
| 87 |
+
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
| 88 |
+
return &FallbackHandler{
|
| 89 |
+
getProxy: getProxy,
|
| 90 |
+
forceModelMappings: func() bool { return false },
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
| 95 |
+
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
| 96 |
+
if forceModelMappings == nil {
|
| 97 |
+
forceModelMappings = func() bool { return false }
|
| 98 |
+
}
|
| 99 |
+
return &FallbackHandler{
|
| 100 |
+
getProxy: getProxy,
|
| 101 |
+
modelMapper: mapper,
|
| 102 |
+
forceModelMappings: forceModelMappings,
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// SetModelMapper sets the model mapper for this handler (allows late binding)
|
| 107 |
+
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
| 108 |
+
fh.modelMapper = mapper
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
| 112 |
+
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
| 113 |
+
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
| 114 |
+
return func(c *gin.Context) {
|
| 115 |
+
requestPath := c.Request.URL.Path
|
| 116 |
+
|
| 117 |
+
// Read the request body to extract the model name
|
| 118 |
+
bodyBytes, err := io.ReadAll(c.Request.Body)
|
| 119 |
+
if err != nil {
|
| 120 |
+
log.Errorf("amp fallback: failed to read request body: %v", err)
|
| 121 |
+
handler(c)
|
| 122 |
+
return
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Restore the body for the handler to read
|
| 126 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 127 |
+
|
| 128 |
+
// Try to extract model from request body or URL path (for Gemini)
|
| 129 |
+
modelName := extractModelFromRequest(bodyBytes, c)
|
| 130 |
+
if modelName == "" {
|
| 131 |
+
// Can't determine model, proceed with normal handler
|
| 132 |
+
handler(c)
|
| 133 |
+
return
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// Normalize model (handles dynamic thinking suffixes)
|
| 137 |
+
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
| 138 |
+
thinkingSuffix := ""
|
| 139 |
+
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
| 140 |
+
thinkingSuffix = modelName[len(normalizedModel):]
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
resolveMappedModel := func() (string, []string) {
|
| 144 |
+
if fh.modelMapper == nil {
|
| 145 |
+
return "", nil
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
mappedModel := fh.modelMapper.MapModel(modelName)
|
| 149 |
+
if mappedModel == "" {
|
| 150 |
+
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
| 151 |
+
}
|
| 152 |
+
mappedModel = strings.TrimSpace(mappedModel)
|
| 153 |
+
if mappedModel == "" {
|
| 154 |
+
return "", nil
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
| 158 |
+
// already specifies its own thinking suffix.
|
| 159 |
+
if thinkingSuffix != "" {
|
| 160 |
+
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
| 161 |
+
if mappedThinkingMetadata == nil {
|
| 162 |
+
mappedModel += thinkingSuffix
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
| 167 |
+
mappedProviders := util.GetProviderName(mappedBaseModel)
|
| 168 |
+
if len(mappedProviders) == 0 {
|
| 169 |
+
return "", nil
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
return mappedModel, mappedProviders
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// Track resolved model for logging (may change if mapping is applied)
|
| 176 |
+
resolvedModel := normalizedModel
|
| 177 |
+
usedMapping := false
|
| 178 |
+
var providers []string
|
| 179 |
+
|
| 180 |
+
// Check if model mappings should be forced ahead of local API keys
|
| 181 |
+
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
| 182 |
+
|
| 183 |
+
if forceMappings {
|
| 184 |
+
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
| 185 |
+
// This allows users to route Amp requests to their preferred OAuth providers
|
| 186 |
+
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
| 187 |
+
// Mapping found and provider available - rewrite the model in request body
|
| 188 |
+
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
| 189 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 190 |
+
// Store mapped model in context for handlers that check it (like gemini bridge)
|
| 191 |
+
c.Set(MappedModelContextKey, mappedModel)
|
| 192 |
+
resolvedModel = mappedModel
|
| 193 |
+
usedMapping = true
|
| 194 |
+
providers = mappedProviders
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// If no mapping applied, check for local providers
|
| 198 |
+
if !usedMapping {
|
| 199 |
+
providers = util.GetProviderName(normalizedModel)
|
| 200 |
+
}
|
| 201 |
+
} else {
|
| 202 |
+
// DEFAULT MODE: Check local providers first, then mappings as fallback
|
| 203 |
+
providers = util.GetProviderName(normalizedModel)
|
| 204 |
+
|
| 205 |
+
if len(providers) == 0 {
|
| 206 |
+
// No providers configured - check if we have a model mapping
|
| 207 |
+
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
| 208 |
+
// Mapping found and provider available - rewrite the model in request body
|
| 209 |
+
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
| 210 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 211 |
+
// Store mapped model in context for handlers that check it (like gemini bridge)
|
| 212 |
+
c.Set(MappedModelContextKey, mappedModel)
|
| 213 |
+
resolvedModel = mappedModel
|
| 214 |
+
usedMapping = true
|
| 215 |
+
providers = mappedProviders
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
// If no providers available, fallback to ampcode.com
|
| 221 |
+
if len(providers) == 0 {
|
| 222 |
+
proxy := fh.getProxy()
|
| 223 |
+
if proxy != nil {
|
| 224 |
+
// Log: Forwarding to ampcode.com (uses Amp credits)
|
| 225 |
+
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
|
| 226 |
+
|
| 227 |
+
// Restore body again for the proxy
|
| 228 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 229 |
+
|
| 230 |
+
// Forward to ampcode.com
|
| 231 |
+
proxy.ServeHTTP(c.Writer, c.Request)
|
| 232 |
+
return
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// No proxy available, let the normal handler return the error
|
| 236 |
+
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
// Log the routing decision
|
| 240 |
+
providerName := ""
|
| 241 |
+
if len(providers) > 0 {
|
| 242 |
+
providerName = providers[0]
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
if usedMapping {
|
| 246 |
+
// Log: Model was mapped to another model
|
| 247 |
+
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
| 248 |
+
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
| 249 |
+
rewriter := NewResponseRewriter(c.Writer, modelName)
|
| 250 |
+
c.Writer = rewriter
|
| 251 |
+
// Filter Anthropic-Beta header only for local handling paths
|
| 252 |
+
filterAntropicBetaHeader(c)
|
| 253 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 254 |
+
handler(c)
|
| 255 |
+
rewriter.Flush()
|
| 256 |
+
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
| 257 |
+
} else if len(providers) > 0 {
|
| 258 |
+
// Log: Using local provider (free)
|
| 259 |
+
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
| 260 |
+
// Filter Anthropic-Beta header only for local handling paths
|
| 261 |
+
filterAntropicBetaHeader(c)
|
| 262 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 263 |
+
handler(c)
|
| 264 |
+
} else {
|
| 265 |
+
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
|
| 266 |
+
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
| 267 |
+
handler(c)
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
// filterAntropicBetaHeader filters Anthropic-Beta header to remove features requiring special subscription
|
| 273 |
+
// This is needed when using local providers (bypassing the Amp proxy)
|
| 274 |
+
func filterAntropicBetaHeader(c *gin.Context) {
|
| 275 |
+
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
| 276 |
+
if filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07"); filtered != "" {
|
| 277 |
+
c.Request.Header.Set("Anthropic-Beta", filtered)
|
| 278 |
+
} else {
|
| 279 |
+
c.Request.Header.Del("Anthropic-Beta")
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// rewriteModelInRequest replaces the model name in a JSON request body
|
| 285 |
+
func rewriteModelInRequest(body []byte, newModel string) []byte {
|
| 286 |
+
if !gjson.GetBytes(body, "model").Exists() {
|
| 287 |
+
return body
|
| 288 |
+
}
|
| 289 |
+
result, err := sjson.SetBytes(body, "model", newModel)
|
| 290 |
+
if err != nil {
|
| 291 |
+
log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
|
| 292 |
+
return body
|
| 293 |
+
}
|
| 294 |
+
return result
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
// extractModelFromRequest attempts to extract the model name from various request formats
|
| 298 |
+
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
| 299 |
+
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
| 300 |
+
// Check common model field names
|
| 301 |
+
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
| 302 |
+
return result.String()
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
// For Gemini requests, model is in the URL path
|
| 306 |
+
// Standard format: /models/{model}:generateContent -> :action parameter
|
| 307 |
+
if action := c.Param("action"); action != "" {
|
| 308 |
+
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
| 309 |
+
parts := strings.Split(action, ":")
|
| 310 |
+
if len(parts) > 0 && parts[0] != "" {
|
| 311 |
+
return parts[0]
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
| 316 |
+
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
| 317 |
+
if path := c.Param("path"); path != "" {
|
| 318 |
+
// Look for /models/{model}:method pattern
|
| 319 |
+
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
| 320 |
+
modelPart := path[idx+8:] // Skip "/models/"
|
| 321 |
+
// Split by colon to get model name
|
| 322 |
+
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
| 323 |
+
return modelPart[:colonIdx]
|
| 324 |
+
}
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
return ""
|
| 329 |
+
}
|
internal/api/modules/amp/fallback_handlers_test.go
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"net/http"
|
| 7 |
+
"net/http/httptest"
|
| 8 |
+
"net/http/httputil"
|
| 9 |
+
"testing"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
| 17 |
+
gin.SetMode(gin.TestMode)
|
| 18 |
+
|
| 19 |
+
reg := registry.GetGlobalRegistry()
|
| 20 |
+
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
| 21 |
+
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
| 22 |
+
})
|
| 23 |
+
defer reg.UnregisterClient("test-client-amp-fallback")
|
| 24 |
+
|
| 25 |
+
mapper := NewModelMapper([]config.AmpModelMapping{
|
| 26 |
+
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
| 27 |
+
})
|
| 28 |
+
|
| 29 |
+
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
| 30 |
+
|
| 31 |
+
handler := func(c *gin.Context) {
|
| 32 |
+
var req struct {
|
| 33 |
+
Model string `json:"model"`
|
| 34 |
+
}
|
| 35 |
+
if err := c.ShouldBindJSON(&req); err != nil {
|
| 36 |
+
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
| 37 |
+
return
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
c.JSON(http.StatusOK, gin.H{
|
| 41 |
+
"model": req.Model,
|
| 42 |
+
"seen_model": req.Model,
|
| 43 |
+
})
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
r := gin.New()
|
| 47 |
+
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
| 48 |
+
|
| 49 |
+
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
| 50 |
+
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
| 51 |
+
req.Header.Set("Content-Type", "application/json")
|
| 52 |
+
w := httptest.NewRecorder()
|
| 53 |
+
r.ServeHTTP(w, req)
|
| 54 |
+
|
| 55 |
+
if w.Code != http.StatusOK {
|
| 56 |
+
t.Fatalf("Expected status 200, got %d", w.Code)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
var resp struct {
|
| 60 |
+
Model string `json:"model"`
|
| 61 |
+
SeenModel string `json:"seen_model"`
|
| 62 |
+
}
|
| 63 |
+
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
| 64 |
+
t.Fatalf("Failed to parse response JSON: %v", err)
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
if resp.Model != "gpt-5.2(xhigh)" {
|
| 68 |
+
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
| 69 |
+
}
|
| 70 |
+
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
| 71 |
+
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
| 72 |
+
}
|
| 73 |
+
}
|
internal/api/modules/amp/gemini_bridge.go
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"strings"
|
| 5 |
+
|
| 6 |
+
"github.com/gin-gonic/gin"
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
| 10 |
+
// to our standard Gemini handler by rewriting the request context.
|
| 11 |
+
//
|
| 12 |
+
// AMP CLI format: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
| 13 |
+
// Standard format: /models/gemini-3-pro-preview:streamGenerateContent
|
| 14 |
+
//
|
| 15 |
+
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
| 16 |
+
// so the standard Gemini handler can process it.
|
| 17 |
+
//
|
| 18 |
+
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
|
| 19 |
+
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
| 20 |
+
return func(c *gin.Context) {
|
| 21 |
+
// Get the full path from the catch-all parameter
|
| 22 |
+
path := c.Param("path")
|
| 23 |
+
|
| 24 |
+
// Extract model:method from AMP CLI path format
|
| 25 |
+
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
| 26 |
+
const modelsPrefix = "/models/"
|
| 27 |
+
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
|
| 28 |
+
// Extract everything after modelsPrefix
|
| 29 |
+
actionPart := path[idx+len(modelsPrefix):]
|
| 30 |
+
|
| 31 |
+
// Check if model was mapped by FallbackHandler
|
| 32 |
+
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
| 33 |
+
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
| 34 |
+
// Replace the model part in the action
|
| 35 |
+
// actionPart is like "model-name:method"
|
| 36 |
+
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
| 37 |
+
method := actionPart[colonIdx:] // ":method"
|
| 38 |
+
actionPart = strModel + method
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// Set this as the :action parameter that the Gemini handler expects
|
| 44 |
+
c.Params = append(c.Params, gin.Param{
|
| 45 |
+
Key: "action",
|
| 46 |
+
Value: actionPart,
|
| 47 |
+
})
|
| 48 |
+
|
| 49 |
+
// Call the handler
|
| 50 |
+
handler(c)
|
| 51 |
+
return
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// If we can't parse the path, return 400
|
| 55 |
+
c.JSON(400, gin.H{
|
| 56 |
+
"error": "Invalid Gemini API path format",
|
| 57 |
+
})
|
| 58 |
+
}
|
| 59 |
+
}
|
internal/api/modules/amp/gemini_bridge_test.go
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
"net/http/httptest"
|
| 6 |
+
"testing"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
|
| 12 |
+
gin.SetMode(gin.TestMode)
|
| 13 |
+
|
| 14 |
+
tests := []struct {
|
| 15 |
+
name string
|
| 16 |
+
path string
|
| 17 |
+
mappedModel string // empty string means no mapping
|
| 18 |
+
expectedAction string
|
| 19 |
+
}{
|
| 20 |
+
{
|
| 21 |
+
name: "no_mapping_uses_url_model",
|
| 22 |
+
path: "/publishers/google/models/gemini-pro:generateContent",
|
| 23 |
+
mappedModel: "",
|
| 24 |
+
expectedAction: "gemini-pro:generateContent",
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
name: "mapped_model_replaces_url_model",
|
| 28 |
+
path: "/publishers/google/models/gemini-exp:generateContent",
|
| 29 |
+
mappedModel: "gemini-2.0-flash",
|
| 30 |
+
expectedAction: "gemini-2.0-flash:generateContent",
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
name: "mapping_preserves_method",
|
| 34 |
+
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
|
| 35 |
+
mappedModel: "gemini-flash",
|
| 36 |
+
expectedAction: "gemini-flash:streamGenerateContent",
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
for _, tt := range tests {
|
| 41 |
+
t.Run(tt.name, func(t *testing.T) {
|
| 42 |
+
var capturedAction string
|
| 43 |
+
|
| 44 |
+
mockGeminiHandler := func(c *gin.Context) {
|
| 45 |
+
capturedAction = c.Param("action")
|
| 46 |
+
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// Use the actual createGeminiBridgeHandler function
|
| 50 |
+
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
|
| 51 |
+
|
| 52 |
+
r := gin.New()
|
| 53 |
+
if tt.mappedModel != "" {
|
| 54 |
+
r.Use(func(c *gin.Context) {
|
| 55 |
+
c.Set(MappedModelContextKey, tt.mappedModel)
|
| 56 |
+
c.Next()
|
| 57 |
+
})
|
| 58 |
+
}
|
| 59 |
+
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
| 60 |
+
|
| 61 |
+
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
|
| 62 |
+
w := httptest.NewRecorder()
|
| 63 |
+
r.ServeHTTP(w, req)
|
| 64 |
+
|
| 65 |
+
if w.Code != http.StatusOK {
|
| 66 |
+
t.Fatalf("Expected status 200, got %d", w.Code)
|
| 67 |
+
}
|
| 68 |
+
if capturedAction != tt.expectedAction {
|
| 69 |
+
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
|
| 70 |
+
}
|
| 71 |
+
})
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
|
| 76 |
+
gin.SetMode(gin.TestMode)
|
| 77 |
+
|
| 78 |
+
mockHandler := func(c *gin.Context) {
|
| 79 |
+
c.JSON(http.StatusOK, gin.H{"ok": true})
|
| 80 |
+
}
|
| 81 |
+
bridgeHandler := createGeminiBridgeHandler(mockHandler)
|
| 82 |
+
|
| 83 |
+
r := gin.New()
|
| 84 |
+
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
| 85 |
+
|
| 86 |
+
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
|
| 87 |
+
w := httptest.NewRecorder()
|
| 88 |
+
r.ServeHTTP(w, req)
|
| 89 |
+
|
| 90 |
+
if w.Code != http.StatusBadRequest {
|
| 91 |
+
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
|
| 92 |
+
}
|
| 93 |
+
}
|
internal/api/modules/amp/model_mapping.go
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package amp provides model mapping functionality for routing Amp CLI requests
|
| 2 |
+
// to alternative models when the requested model is not available locally.
|
| 3 |
+
package amp
|
| 4 |
+
|
| 5 |
+
import (
|
| 6 |
+
"regexp"
|
| 7 |
+
"strings"
|
| 8 |
+
"sync"
|
| 9 |
+
|
| 10 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 11 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 12 |
+
log "github.com/sirupsen/logrus"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
// ModelMapper provides model name mapping/aliasing for Amp CLI requests.
|
| 16 |
+
// When an Amp request comes in for a model that isn't available locally,
|
| 17 |
+
// this mapper can redirect it to an alternative model that IS available.
|
| 18 |
+
type ModelMapper interface {
|
| 19 |
+
// MapModel returns the target model name if a mapping exists and the target
|
| 20 |
+
// model has available providers. Returns empty string if no mapping applies.
|
| 21 |
+
MapModel(requestedModel string) string
|
| 22 |
+
|
| 23 |
+
// UpdateMappings refreshes the mapping configuration (for hot-reload).
|
| 24 |
+
UpdateMappings(mappings []config.AmpModelMapping)
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
| 28 |
+
type DefaultModelMapper struct {
|
| 29 |
+
mu sync.RWMutex
|
| 30 |
+
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
| 31 |
+
regexps []regexMapping // regex rules evaluated in order
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
// NewModelMapper creates a new model mapper with the given initial mappings.
|
| 35 |
+
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
| 36 |
+
m := &DefaultModelMapper{
|
| 37 |
+
mappings: make(map[string]string),
|
| 38 |
+
regexps: nil,
|
| 39 |
+
}
|
| 40 |
+
m.UpdateMappings(mappings)
|
| 41 |
+
return m
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// MapModel checks if a mapping exists for the requested model and if the
|
| 45 |
+
// target model has available local providers. Returns the mapped model name
|
| 46 |
+
// or empty string if no valid mapping exists.
|
| 47 |
+
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
| 48 |
+
if requestedModel == "" {
|
| 49 |
+
return ""
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
m.mu.RLock()
|
| 53 |
+
defer m.mu.RUnlock()
|
| 54 |
+
|
| 55 |
+
// Normalize the requested model for lookup
|
| 56 |
+
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
| 57 |
+
|
| 58 |
+
// Check for direct mapping
|
| 59 |
+
targetModel, exists := m.mappings[normalizedRequest]
|
| 60 |
+
if !exists {
|
| 61 |
+
// Try regex mappings in order
|
| 62 |
+
base, _ := util.NormalizeThinkingModel(requestedModel)
|
| 63 |
+
for _, rm := range m.regexps {
|
| 64 |
+
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
|
| 65 |
+
targetModel = rm.to
|
| 66 |
+
exists = true
|
| 67 |
+
break
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
if !exists {
|
| 71 |
+
return ""
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// Verify target model has available providers
|
| 76 |
+
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
| 77 |
+
providers := util.GetProviderName(normalizedTarget)
|
| 78 |
+
if len(providers) == 0 {
|
| 79 |
+
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
| 80 |
+
return ""
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
| 84 |
+
return targetModel
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// UpdateMappings refreshes the mapping configuration from config.
|
| 88 |
+
// This is called during initialization and on config hot-reload.
|
| 89 |
+
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
| 90 |
+
m.mu.Lock()
|
| 91 |
+
defer m.mu.Unlock()
|
| 92 |
+
|
| 93 |
+
// Clear and rebuild mappings
|
| 94 |
+
m.mappings = make(map[string]string, len(mappings))
|
| 95 |
+
m.regexps = make([]regexMapping, 0, len(mappings))
|
| 96 |
+
|
| 97 |
+
for _, mapping := range mappings {
|
| 98 |
+
from := strings.TrimSpace(mapping.From)
|
| 99 |
+
to := strings.TrimSpace(mapping.To)
|
| 100 |
+
|
| 101 |
+
if from == "" || to == "" {
|
| 102 |
+
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
|
| 103 |
+
continue
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if mapping.Regex {
|
| 107 |
+
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
| 108 |
+
pattern := "(?i)" + from
|
| 109 |
+
re, err := regexp.Compile(pattern)
|
| 110 |
+
if err != nil {
|
| 111 |
+
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
| 112 |
+
continue
|
| 113 |
+
}
|
| 114 |
+
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
| 115 |
+
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
| 116 |
+
} else {
|
| 117 |
+
// Store with normalized lowercase key for case-insensitive lookup
|
| 118 |
+
normalizedFrom := strings.ToLower(from)
|
| 119 |
+
m.mappings[normalizedFrom] = to
|
| 120 |
+
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if len(m.mappings) > 0 {
|
| 125 |
+
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
| 126 |
+
}
|
| 127 |
+
if n := len(m.regexps); n > 0 {
|
| 128 |
+
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
// GetMappings returns a copy of current mappings (for debugging/status).
|
| 133 |
+
func (m *DefaultModelMapper) GetMappings() map[string]string {
|
| 134 |
+
m.mu.RLock()
|
| 135 |
+
defer m.mu.RUnlock()
|
| 136 |
+
|
| 137 |
+
result := make(map[string]string, len(m.mappings))
|
| 138 |
+
for k, v := range m.mappings {
|
| 139 |
+
result[k] = v
|
| 140 |
+
}
|
| 141 |
+
return result
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
type regexMapping struct {
|
| 145 |
+
re *regexp.Regexp
|
| 146 |
+
to string
|
| 147 |
+
}
|
internal/api/modules/amp/model_mapping_test.go
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"testing"
|
| 5 |
+
|
| 6 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 7 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
func TestNewModelMapper(t *testing.T) {
|
| 11 |
+
mappings := []config.AmpModelMapping{
|
| 12 |
+
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
| 13 |
+
{From: "gpt-5", To: "gemini-2.5-pro"},
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
mapper := NewModelMapper(mappings)
|
| 17 |
+
if mapper == nil {
|
| 18 |
+
t.Fatal("Expected non-nil mapper")
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
result := mapper.GetMappings()
|
| 22 |
+
if len(result) != 2 {
|
| 23 |
+
t.Errorf("Expected 2 mappings, got %d", len(result))
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
func TestNewModelMapper_Empty(t *testing.T) {
|
| 28 |
+
mapper := NewModelMapper(nil)
|
| 29 |
+
if mapper == nil {
|
| 30 |
+
t.Fatal("Expected non-nil mapper")
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
result := mapper.GetMappings()
|
| 34 |
+
if len(result) != 0 {
|
| 35 |
+
t.Errorf("Expected 0 mappings, got %d", len(result))
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
func TestModelMapper_MapModel_NoProvider(t *testing.T) {
|
| 40 |
+
mappings := []config.AmpModelMapping{
|
| 41 |
+
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
mapper := NewModelMapper(mappings)
|
| 45 |
+
|
| 46 |
+
// Without a registered provider for the target, mapping should return empty
|
| 47 |
+
result := mapper.MapModel("claude-opus-4.5")
|
| 48 |
+
if result != "" {
|
| 49 |
+
t.Errorf("Expected empty result when target has no provider, got %s", result)
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
| 54 |
+
// Register a mock provider for the target model
|
| 55 |
+
reg := registry.GetGlobalRegistry()
|
| 56 |
+
reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
|
| 57 |
+
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
| 58 |
+
})
|
| 59 |
+
defer reg.UnregisterClient("test-client")
|
| 60 |
+
|
| 61 |
+
mappings := []config.AmpModelMapping{
|
| 62 |
+
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
mapper := NewModelMapper(mappings)
|
| 66 |
+
|
| 67 |
+
// With a registered provider, mapping should work
|
| 68 |
+
result := mapper.MapModel("claude-opus-4.5")
|
| 69 |
+
if result != "claude-sonnet-4" {
|
| 70 |
+
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
| 75 |
+
reg := registry.GetGlobalRegistry()
|
| 76 |
+
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
| 77 |
+
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
| 78 |
+
})
|
| 79 |
+
defer reg.UnregisterClient("test-client-thinking")
|
| 80 |
+
|
| 81 |
+
mappings := []config.AmpModelMapping{
|
| 82 |
+
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
mapper := NewModelMapper(mappings)
|
| 86 |
+
|
| 87 |
+
result := mapper.MapModel("gpt-5.2-alias")
|
| 88 |
+
if result != "gpt-5.2(xhigh)" {
|
| 89 |
+
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
| 94 |
+
reg := registry.GetGlobalRegistry()
|
| 95 |
+
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
| 96 |
+
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
| 97 |
+
})
|
| 98 |
+
defer reg.UnregisterClient("test-client2")
|
| 99 |
+
|
| 100 |
+
mappings := []config.AmpModelMapping{
|
| 101 |
+
{From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
mapper := NewModelMapper(mappings)
|
| 105 |
+
|
| 106 |
+
// Should match case-insensitively
|
| 107 |
+
result := mapper.MapModel("claude-opus-4.5")
|
| 108 |
+
if result != "claude-sonnet-4" {
|
| 109 |
+
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
func TestModelMapper_MapModel_NotFound(t *testing.T) {
|
| 114 |
+
mappings := []config.AmpModelMapping{
|
| 115 |
+
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
mapper := NewModelMapper(mappings)
|
| 119 |
+
|
| 120 |
+
// Unknown model should return empty
|
| 121 |
+
result := mapper.MapModel("unknown-model")
|
| 122 |
+
if result != "" {
|
| 123 |
+
t.Errorf("Expected empty for unknown model, got %s", result)
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
|
| 128 |
+
mappings := []config.AmpModelMapping{
|
| 129 |
+
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
mapper := NewModelMapper(mappings)
|
| 133 |
+
|
| 134 |
+
result := mapper.MapModel("")
|
| 135 |
+
if result != "" {
|
| 136 |
+
t.Errorf("Expected empty for empty input, got %s", result)
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
func TestModelMapper_UpdateMappings(t *testing.T) {
|
| 141 |
+
mapper := NewModelMapper(nil)
|
| 142 |
+
|
| 143 |
+
// Initially empty
|
| 144 |
+
if len(mapper.GetMappings()) != 0 {
|
| 145 |
+
t.Error("Expected 0 initial mappings")
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
// Update with new mappings
|
| 149 |
+
mapper.UpdateMappings([]config.AmpModelMapping{
|
| 150 |
+
{From: "model-a", To: "model-b"},
|
| 151 |
+
{From: "model-c", To: "model-d"},
|
| 152 |
+
})
|
| 153 |
+
|
| 154 |
+
result := mapper.GetMappings()
|
| 155 |
+
if len(result) != 2 {
|
| 156 |
+
t.Errorf("Expected 2 mappings after update, got %d", len(result))
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// Update again should replace, not append
|
| 160 |
+
mapper.UpdateMappings([]config.AmpModelMapping{
|
| 161 |
+
{From: "model-x", To: "model-y"},
|
| 162 |
+
})
|
| 163 |
+
|
| 164 |
+
result = mapper.GetMappings()
|
| 165 |
+
if len(result) != 1 {
|
| 166 |
+
t.Errorf("Expected 1 mapping after second update, got %d", len(result))
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
|
| 171 |
+
mapper := NewModelMapper(nil)
|
| 172 |
+
|
| 173 |
+
mapper.UpdateMappings([]config.AmpModelMapping{
|
| 174 |
+
{From: "", To: "model-b"}, // Invalid: empty from
|
| 175 |
+
{From: "model-a", To: ""}, // Invalid: empty to
|
| 176 |
+
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
| 177 |
+
{From: "model-c", To: "model-d"}, // Valid
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
result := mapper.GetMappings()
|
| 181 |
+
if len(result) != 1 {
|
| 182 |
+
t.Errorf("Expected 1 valid mapping, got %d", len(result))
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
| 187 |
+
mappings := []config.AmpModelMapping{
|
| 188 |
+
{From: "model-a", To: "model-b"},
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
mapper := NewModelMapper(mappings)
|
| 192 |
+
|
| 193 |
+
// Get mappings and modify the returned map
|
| 194 |
+
result := mapper.GetMappings()
|
| 195 |
+
result["new-key"] = "new-value"
|
| 196 |
+
|
| 197 |
+
// Original should be unchanged
|
| 198 |
+
original := mapper.GetMappings()
|
| 199 |
+
if len(original) != 1 {
|
| 200 |
+
t.Errorf("Expected original to have 1 mapping, got %d", len(original))
|
| 201 |
+
}
|
| 202 |
+
if _, exists := original["new-key"]; exists {
|
| 203 |
+
t.Error("Original map was modified")
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
| 208 |
+
reg := registry.GetGlobalRegistry()
|
| 209 |
+
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
| 210 |
+
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
| 211 |
+
})
|
| 212 |
+
defer reg.UnregisterClient("test-client-regex-1")
|
| 213 |
+
|
| 214 |
+
mappings := []config.AmpModelMapping{
|
| 215 |
+
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
mapper := NewModelMapper(mappings)
|
| 219 |
+
|
| 220 |
+
// Incoming model has reasoning suffix but should match base via regex
|
| 221 |
+
result := mapper.MapModel("gpt-5(high)")
|
| 222 |
+
if result != "gemini-2.5-pro" {
|
| 223 |
+
t.Errorf("Expected gemini-2.5-pro, got %s", result)
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
| 228 |
+
reg := registry.GetGlobalRegistry()
|
| 229 |
+
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
| 230 |
+
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
| 231 |
+
})
|
| 232 |
+
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
| 233 |
+
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
| 234 |
+
})
|
| 235 |
+
defer reg.UnregisterClient("test-client-regex-2")
|
| 236 |
+
defer reg.UnregisterClient("test-client-regex-3")
|
| 237 |
+
|
| 238 |
+
mappings := []config.AmpModelMapping{
|
| 239 |
+
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
| 240 |
+
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
mapper := NewModelMapper(mappings)
|
| 244 |
+
|
| 245 |
+
// Exact match should win over regex
|
| 246 |
+
result := mapper.MapModel("gpt-5")
|
| 247 |
+
if result != "claude-sonnet-4" {
|
| 248 |
+
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
| 249 |
+
}
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
| 253 |
+
// Invalid regex should be skipped and not cause panic
|
| 254 |
+
mappings := []config.AmpModelMapping{
|
| 255 |
+
{From: "(", To: "target", Regex: true},
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
mapper := NewModelMapper(mappings)
|
| 259 |
+
|
| 260 |
+
result := mapper.MapModel("anything")
|
| 261 |
+
if result != "" {
|
| 262 |
+
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
| 267 |
+
reg := registry.GetGlobalRegistry()
|
| 268 |
+
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
| 269 |
+
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
| 270 |
+
})
|
| 271 |
+
defer reg.UnregisterClient("test-client-regex-4")
|
| 272 |
+
|
| 273 |
+
mappings := []config.AmpModelMapping{
|
| 274 |
+
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
mapper := NewModelMapper(mappings)
|
| 278 |
+
|
| 279 |
+
result := mapper.MapModel("claude-opus-4.5")
|
| 280 |
+
if result != "claude-sonnet-4" {
|
| 281 |
+
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
| 282 |
+
}
|
| 283 |
+
}
|
internal/api/modules/amp/proxy.go
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"compress/gzip"
|
| 6 |
+
"context"
|
| 7 |
+
"errors"
|
| 8 |
+
"fmt"
|
| 9 |
+
"io"
|
| 10 |
+
"net"
|
| 11 |
+
"net/http"
|
| 12 |
+
"net/http/httputil"
|
| 13 |
+
"net/url"
|
| 14 |
+
"strconv"
|
| 15 |
+
"strings"
|
| 16 |
+
|
| 17 |
+
"github.com/gin-gonic/gin"
|
| 18 |
+
log "github.com/sirupsen/logrus"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
| 22 |
+
if req == nil || req.URL == nil || match == "" {
|
| 23 |
+
return
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
q := req.URL.Query()
|
| 27 |
+
values, ok := q[key]
|
| 28 |
+
if !ok || len(values) == 0 {
|
| 29 |
+
return
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
kept := make([]string, 0, len(values))
|
| 33 |
+
for _, v := range values {
|
| 34 |
+
if v == match {
|
| 35 |
+
continue
|
| 36 |
+
}
|
| 37 |
+
kept = append(kept, v)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
if len(kept) == 0 {
|
| 41 |
+
q.Del(key)
|
| 42 |
+
} else {
|
| 43 |
+
q[key] = kept
|
| 44 |
+
}
|
| 45 |
+
req.URL.RawQuery = q.Encode()
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// readCloser wraps a reader and forwards Close to a separate closer.
|
| 49 |
+
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
| 50 |
+
type readCloser struct {
|
| 51 |
+
r io.Reader
|
| 52 |
+
c io.Closer
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
func (rc *readCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
|
| 56 |
+
func (rc *readCloser) Close() error { return rc.c.Close() }
|
| 57 |
+
|
| 58 |
+
// createReverseProxy creates a reverse proxy handler for Amp upstream
|
| 59 |
+
// with automatic gzip decompression via ModifyResponse
|
| 60 |
+
func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) {
|
| 61 |
+
parsed, err := url.Parse(upstreamURL)
|
| 62 |
+
if err != nil {
|
| 63 |
+
return nil, fmt.Errorf("invalid amp upstream url: %w", err)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
proxy := httputil.NewSingleHostReverseProxy(parsed)
|
| 67 |
+
originalDirector := proxy.Director
|
| 68 |
+
|
| 69 |
+
// Modify outgoing requests to inject API key and fix routing
|
| 70 |
+
proxy.Director = func(req *http.Request) {
|
| 71 |
+
originalDirector(req)
|
| 72 |
+
req.Host = parsed.Host
|
| 73 |
+
|
| 74 |
+
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
| 75 |
+
// We will set our own Authorization using the configured upstream-api-key
|
| 76 |
+
req.Header.Del("Authorization")
|
| 77 |
+
req.Header.Del("X-Api-Key")
|
| 78 |
+
req.Header.Del("X-Goog-Api-Key")
|
| 79 |
+
|
| 80 |
+
// Remove query-based credentials if they match the authenticated client API key.
|
| 81 |
+
// This prevents leaking client auth material to the Amp upstream while avoiding
|
| 82 |
+
// breaking unrelated upstream query parameters.
|
| 83 |
+
clientKey := getClientAPIKeyFromContext(req.Context())
|
| 84 |
+
removeQueryValuesMatching(req, "key", clientKey)
|
| 85 |
+
removeQueryValuesMatching(req, "auth_token", clientKey)
|
| 86 |
+
|
| 87 |
+
// Preserve correlation headers for debugging
|
| 88 |
+
if req.Header.Get("X-Request-ID") == "" {
|
| 89 |
+
// Could generate one here if needed
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// Note: We do NOT filter Anthropic-Beta headers in the proxy path
|
| 93 |
+
// Users going through ampcode.com proxy are paying for the service and should get all features
|
| 94 |
+
// including 1M context window (context-1m-2025-08-07)
|
| 95 |
+
|
| 96 |
+
// Inject API key from secret source (only uses upstream-api-key from config)
|
| 97 |
+
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
| 98 |
+
req.Header.Set("X-Api-Key", key)
|
| 99 |
+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
| 100 |
+
} else if err != nil {
|
| 101 |
+
log.Warnf("amp secret source error (continuing without auth): %v", err)
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// Modify incoming responses to handle gzip without Content-Encoding
|
| 106 |
+
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
| 107 |
+
proxy.ModifyResponse = func(resp *http.Response) error {
|
| 108 |
+
// Log upstream error responses for diagnostics (502, 503, etc.)
|
| 109 |
+
// These are NOT proxy connection errors - the upstream responded with an error status
|
| 110 |
+
if resp.StatusCode >= 500 {
|
| 111 |
+
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
| 112 |
+
} else if resp.StatusCode >= 400 {
|
| 113 |
+
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// Only process successful responses for gzip decompression
|
| 117 |
+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
| 118 |
+
return nil
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
// Skip if already marked as gzip (Content-Encoding set)
|
| 122 |
+
if resp.Header.Get("Content-Encoding") != "" {
|
| 123 |
+
return nil
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Skip streaming responses (SSE, chunked)
|
| 127 |
+
if isStreamingResponse(resp) {
|
| 128 |
+
return nil
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// Save reference to original upstream body for proper cleanup
|
| 132 |
+
originalBody := resp.Body
|
| 133 |
+
|
| 134 |
+
// Peek at first 2 bytes to detect gzip magic bytes
|
| 135 |
+
header := make([]byte, 2)
|
| 136 |
+
n, _ := io.ReadFull(originalBody, header)
|
| 137 |
+
|
| 138 |
+
// Check for gzip magic bytes (0x1f 0x8b)
|
| 139 |
+
// If n < 2, we didn't get enough bytes, so it's not gzip
|
| 140 |
+
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
|
| 141 |
+
// It's gzip - read the rest of the body
|
| 142 |
+
rest, err := io.ReadAll(originalBody)
|
| 143 |
+
if err != nil {
|
| 144 |
+
// Restore what we read and return original body (preserve Close behavior)
|
| 145 |
+
resp.Body = &readCloser{
|
| 146 |
+
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
| 147 |
+
c: originalBody,
|
| 148 |
+
}
|
| 149 |
+
return nil
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// Reconstruct complete gzipped data
|
| 153 |
+
gzippedData := append(header[:n], rest...)
|
| 154 |
+
|
| 155 |
+
// Decompress
|
| 156 |
+
gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData))
|
| 157 |
+
if err != nil {
|
| 158 |
+
log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err)
|
| 159 |
+
// Close original body and return in-memory copy
|
| 160 |
+
_ = originalBody.Close()
|
| 161 |
+
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
| 162 |
+
return nil
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
decompressed, err := io.ReadAll(gzipReader)
|
| 166 |
+
_ = gzipReader.Close()
|
| 167 |
+
if err != nil {
|
| 168 |
+
log.Warnf("amp proxy: gzip decompress error: %v", err)
|
| 169 |
+
// Close original body and return in-memory copy
|
| 170 |
+
_ = originalBody.Close()
|
| 171 |
+
resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
|
| 172 |
+
return nil
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// Close original body since we're replacing with in-memory decompressed content
|
| 176 |
+
_ = originalBody.Close()
|
| 177 |
+
|
| 178 |
+
// Replace body with decompressed content
|
| 179 |
+
resp.Body = io.NopCloser(bytes.NewReader(decompressed))
|
| 180 |
+
resp.ContentLength = int64(len(decompressed))
|
| 181 |
+
|
| 182 |
+
// Update headers to reflect decompressed state
|
| 183 |
+
resp.Header.Del("Content-Encoding") // No longer compressed
|
| 184 |
+
resp.Header.Del("Content-Length") // Remove stale compressed length
|
| 185 |
+
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
|
| 186 |
+
|
| 187 |
+
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
|
| 188 |
+
} else {
|
| 189 |
+
// Not gzip - restore peeked bytes while preserving Close behavior
|
| 190 |
+
// Handle edge cases: n might be 0, 1, or 2 depending on EOF
|
| 191 |
+
resp.Body = &readCloser{
|
| 192 |
+
r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
|
| 193 |
+
c: originalBody,
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
return nil
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
// Error handler for proxy failures with detailed error classification for diagnostics
|
| 201 |
+
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
| 202 |
+
// Classify the error type for better diagnostics
|
| 203 |
+
var errType string
|
| 204 |
+
if errors.Is(err, context.DeadlineExceeded) {
|
| 205 |
+
errType = "timeout"
|
| 206 |
+
} else if errors.Is(err, context.Canceled) {
|
| 207 |
+
errType = "canceled"
|
| 208 |
+
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
| 209 |
+
errType = "dial_timeout"
|
| 210 |
+
} else if _, ok := err.(net.Error); ok {
|
| 211 |
+
errType = "network_error"
|
| 212 |
+
} else {
|
| 213 |
+
errType = "connection_error"
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// Don't log as error for context canceled - it's usually client closing connection
|
| 217 |
+
if errors.Is(err, context.Canceled) {
|
| 218 |
+
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
| 219 |
+
} else {
|
| 220 |
+
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
rw.Header().Set("Content-Type", "application/json")
|
| 224 |
+
rw.WriteHeader(http.StatusBadGateway)
|
| 225 |
+
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
return proxy, nil
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
// isStreamingResponse detects if the response is streaming (SSE only)
|
| 232 |
+
// Note: We only treat text/event-stream as streaming. Chunked transfer encoding
|
| 233 |
+
// is a transport-level detail and doesn't mean we can't decompress the full response.
|
| 234 |
+
// Many JSON APIs use chunked encoding for normal responses.
|
| 235 |
+
func isStreamingResponse(resp *http.Response) bool {
|
| 236 |
+
contentType := resp.Header.Get("Content-Type")
|
| 237 |
+
|
| 238 |
+
// Only Server-Sent Events are true streaming responses
|
| 239 |
+
if strings.Contains(contentType, "text/event-stream") {
|
| 240 |
+
return true
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
return false
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
// proxyHandler converts httputil.ReverseProxy to gin.HandlerFunc
|
| 247 |
+
func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc {
|
| 248 |
+
return func(c *gin.Context) {
|
| 249 |
+
proxy.ServeHTTP(c.Writer, c.Request)
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
// filterBetaFeatures removes a specific beta feature from comma-separated list
|
| 254 |
+
func filterBetaFeatures(header, featureToRemove string) string {
|
| 255 |
+
features := strings.Split(header, ",")
|
| 256 |
+
filtered := make([]string, 0, len(features))
|
| 257 |
+
|
| 258 |
+
for _, feature := range features {
|
| 259 |
+
trimmed := strings.TrimSpace(feature)
|
| 260 |
+
if trimmed != "" && trimmed != featureToRemove {
|
| 261 |
+
filtered = append(filtered, trimmed)
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
return strings.Join(filtered, ",")
|
| 266 |
+
}
|
internal/api/modules/amp/proxy_test.go
ADDED
|
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"compress/gzip"
|
| 6 |
+
"context"
|
| 7 |
+
"fmt"
|
| 8 |
+
"io"
|
| 9 |
+
"net/http"
|
| 10 |
+
"net/http/httptest"
|
| 11 |
+
"strings"
|
| 12 |
+
"testing"
|
| 13 |
+
|
| 14 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
// Helper: compress data with gzip
|
| 18 |
+
func gzipBytes(b []byte) []byte {
|
| 19 |
+
var buf bytes.Buffer
|
| 20 |
+
zw := gzip.NewWriter(&buf)
|
| 21 |
+
zw.Write(b)
|
| 22 |
+
zw.Close()
|
| 23 |
+
return buf.Bytes()
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
// Helper: create a mock http.Response
|
| 27 |
+
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
|
| 28 |
+
if hdr == nil {
|
| 29 |
+
hdr = http.Header{}
|
| 30 |
+
}
|
| 31 |
+
return &http.Response{
|
| 32 |
+
StatusCode: status,
|
| 33 |
+
Header: hdr,
|
| 34 |
+
Body: io.NopCloser(bytes.NewReader(body)),
|
| 35 |
+
ContentLength: int64(len(body)),
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
func TestCreateReverseProxy_ValidURL(t *testing.T) {
|
| 40 |
+
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
|
| 41 |
+
if err != nil {
|
| 42 |
+
t.Fatalf("expected no error, got: %v", err)
|
| 43 |
+
}
|
| 44 |
+
if proxy == nil {
|
| 45 |
+
t.Fatal("expected proxy to be created")
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
|
| 50 |
+
_, err := createReverseProxy("://invalid", NewStaticSecretSource("key"))
|
| 51 |
+
if err == nil {
|
| 52 |
+
t.Fatal("expected error for invalid URL")
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
func TestModifyResponse_GzipScenarios(t *testing.T) {
|
| 57 |
+
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
| 58 |
+
if err != nil {
|
| 59 |
+
t.Fatal(err)
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
goodJSON := []byte(`{"ok":true}`)
|
| 63 |
+
good := gzipBytes(goodJSON)
|
| 64 |
+
truncated := good[:10]
|
| 65 |
+
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
|
| 66 |
+
|
| 67 |
+
cases := []struct {
|
| 68 |
+
name string
|
| 69 |
+
header http.Header
|
| 70 |
+
body []byte
|
| 71 |
+
status int
|
| 72 |
+
wantBody []byte
|
| 73 |
+
wantCE string
|
| 74 |
+
}{
|
| 75 |
+
{
|
| 76 |
+
name: "decompresses_valid_gzip_no_header",
|
| 77 |
+
header: http.Header{},
|
| 78 |
+
body: good,
|
| 79 |
+
status: 200,
|
| 80 |
+
wantBody: goodJSON,
|
| 81 |
+
wantCE: "",
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
name: "skips_when_ce_present",
|
| 85 |
+
header: http.Header{"Content-Encoding": []string{"gzip"}},
|
| 86 |
+
body: good,
|
| 87 |
+
status: 200,
|
| 88 |
+
wantBody: good,
|
| 89 |
+
wantCE: "gzip",
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
name: "passes_truncated_unchanged",
|
| 93 |
+
header: http.Header{},
|
| 94 |
+
body: truncated,
|
| 95 |
+
status: 200,
|
| 96 |
+
wantBody: truncated,
|
| 97 |
+
wantCE: "",
|
| 98 |
+
},
|
| 99 |
+
{
|
| 100 |
+
name: "passes_corrupted_unchanged",
|
| 101 |
+
header: http.Header{},
|
| 102 |
+
body: corrupted,
|
| 103 |
+
status: 200,
|
| 104 |
+
wantBody: corrupted,
|
| 105 |
+
wantCE: "",
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
name: "non_gzip_unchanged",
|
| 109 |
+
header: http.Header{},
|
| 110 |
+
body: []byte("plain"),
|
| 111 |
+
status: 200,
|
| 112 |
+
wantBody: []byte("plain"),
|
| 113 |
+
wantCE: "",
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
name: "empty_body",
|
| 117 |
+
header: http.Header{},
|
| 118 |
+
body: []byte{},
|
| 119 |
+
status: 200,
|
| 120 |
+
wantBody: []byte{},
|
| 121 |
+
wantCE: "",
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
name: "single_byte_body",
|
| 125 |
+
header: http.Header{},
|
| 126 |
+
body: []byte{0x1f},
|
| 127 |
+
status: 200,
|
| 128 |
+
wantBody: []byte{0x1f},
|
| 129 |
+
wantCE: "",
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
name: "skips_non_2xx_status",
|
| 133 |
+
header: http.Header{},
|
| 134 |
+
body: good,
|
| 135 |
+
status: 404,
|
| 136 |
+
wantBody: good,
|
| 137 |
+
wantCE: "",
|
| 138 |
+
},
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
for _, tc := range cases {
|
| 142 |
+
t.Run(tc.name, func(t *testing.T) {
|
| 143 |
+
resp := mkResp(tc.status, tc.header, tc.body)
|
| 144 |
+
if err := proxy.ModifyResponse(resp); err != nil {
|
| 145 |
+
t.Fatalf("ModifyResponse error: %v", err)
|
| 146 |
+
}
|
| 147 |
+
got, err := io.ReadAll(resp.Body)
|
| 148 |
+
if err != nil {
|
| 149 |
+
t.Fatalf("ReadAll error: %v", err)
|
| 150 |
+
}
|
| 151 |
+
if !bytes.Equal(got, tc.wantBody) {
|
| 152 |
+
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
|
| 153 |
+
}
|
| 154 |
+
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
|
| 155 |
+
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
|
| 156 |
+
}
|
| 157 |
+
})
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
|
| 162 |
+
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
| 163 |
+
if err != nil {
|
| 164 |
+
t.Fatal(err)
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
goodJSON := []byte(`{"message":"test response"}`)
|
| 168 |
+
gzipped := gzipBytes(goodJSON)
|
| 169 |
+
|
| 170 |
+
// Simulate upstream response with gzip body AND Content-Length header
|
| 171 |
+
// (this is the scenario the bot flagged - stale Content-Length after decompression)
|
| 172 |
+
resp := mkResp(200, http.Header{
|
| 173 |
+
"Content-Length": []string{fmt.Sprintf("%d", len(gzipped))}, // Compressed size
|
| 174 |
+
}, gzipped)
|
| 175 |
+
|
| 176 |
+
if err := proxy.ModifyResponse(resp); err != nil {
|
| 177 |
+
t.Fatalf("ModifyResponse error: %v", err)
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Verify body is decompressed
|
| 181 |
+
got, _ := io.ReadAll(resp.Body)
|
| 182 |
+
if !bytes.Equal(got, goodJSON) {
|
| 183 |
+
t.Fatalf("body should be decompressed, got: %q, want: %q", got, goodJSON)
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// Verify Content-Length header is updated to decompressed size
|
| 187 |
+
wantCL := fmt.Sprintf("%d", len(goodJSON))
|
| 188 |
+
gotCL := resp.Header.Get("Content-Length")
|
| 189 |
+
if gotCL != wantCL {
|
| 190 |
+
t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Verify struct field also matches
|
| 194 |
+
if resp.ContentLength != int64(len(goodJSON)) {
|
| 195 |
+
t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(goodJSON), resp.ContentLength)
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
|
| 200 |
+
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
| 201 |
+
if err != nil {
|
| 202 |
+
t.Fatal(err)
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
goodJSON := []byte(`{"ok":true}`)
|
| 206 |
+
gzipped := gzipBytes(goodJSON)
|
| 207 |
+
|
| 208 |
+
t.Run("sse_skips_decompression", func(t *testing.T) {
|
| 209 |
+
resp := mkResp(200, http.Header{"Content-Type": []string{"text/event-stream"}}, gzipped)
|
| 210 |
+
if err := proxy.ModifyResponse(resp); err != nil {
|
| 211 |
+
t.Fatalf("ModifyResponse error: %v", err)
|
| 212 |
+
}
|
| 213 |
+
// SSE should NOT be decompressed
|
| 214 |
+
got, _ := io.ReadAll(resp.Body)
|
| 215 |
+
if !bytes.Equal(got, gzipped) {
|
| 216 |
+
t.Fatal("SSE response should not be decompressed")
|
| 217 |
+
}
|
| 218 |
+
})
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
|
| 222 |
+
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
|
| 223 |
+
if err != nil {
|
| 224 |
+
t.Fatal(err)
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
goodJSON := []byte(`{"ok":true}`)
|
| 228 |
+
gzipped := gzipBytes(goodJSON)
|
| 229 |
+
|
| 230 |
+
t.Run("chunked_json_decompresses", func(t *testing.T) {
|
| 231 |
+
// Chunked JSON responses (like thread APIs) should be decompressed
|
| 232 |
+
resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
|
| 233 |
+
if err := proxy.ModifyResponse(resp); err != nil {
|
| 234 |
+
t.Fatalf("ModifyResponse error: %v", err)
|
| 235 |
+
}
|
| 236 |
+
// Should decompress because it's not SSE
|
| 237 |
+
got, _ := io.ReadAll(resp.Body)
|
| 238 |
+
if !bytes.Equal(got, goodJSON) {
|
| 239 |
+
t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
|
| 240 |
+
}
|
| 241 |
+
})
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
func TestReverseProxy_InjectsHeaders(t *testing.T) {
|
| 245 |
+
gotHeaders := make(chan http.Header, 1)
|
| 246 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 247 |
+
gotHeaders <- r.Header.Clone()
|
| 248 |
+
w.WriteHeader(200)
|
| 249 |
+
w.Write([]byte(`ok`))
|
| 250 |
+
}))
|
| 251 |
+
defer upstream.Close()
|
| 252 |
+
|
| 253 |
+
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
|
| 254 |
+
if err != nil {
|
| 255 |
+
t.Fatal(err)
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 259 |
+
proxy.ServeHTTP(w, r)
|
| 260 |
+
}))
|
| 261 |
+
defer srv.Close()
|
| 262 |
+
|
| 263 |
+
res, err := http.Get(srv.URL + "/test")
|
| 264 |
+
if err != nil {
|
| 265 |
+
t.Fatal(err)
|
| 266 |
+
}
|
| 267 |
+
res.Body.Close()
|
| 268 |
+
|
| 269 |
+
hdr := <-gotHeaders
|
| 270 |
+
if hdr.Get("X-Api-Key") != "secret" {
|
| 271 |
+
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
| 272 |
+
}
|
| 273 |
+
if hdr.Get("Authorization") != "Bearer secret" {
|
| 274 |
+
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
func TestReverseProxy_EmptySecret(t *testing.T) {
|
| 279 |
+
gotHeaders := make(chan http.Header, 1)
|
| 280 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 281 |
+
gotHeaders <- r.Header.Clone()
|
| 282 |
+
w.WriteHeader(200)
|
| 283 |
+
w.Write([]byte(`ok`))
|
| 284 |
+
}))
|
| 285 |
+
defer upstream.Close()
|
| 286 |
+
|
| 287 |
+
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
|
| 288 |
+
if err != nil {
|
| 289 |
+
t.Fatal(err)
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 293 |
+
proxy.ServeHTTP(w, r)
|
| 294 |
+
}))
|
| 295 |
+
defer srv.Close()
|
| 296 |
+
|
| 297 |
+
res, err := http.Get(srv.URL + "/test")
|
| 298 |
+
if err != nil {
|
| 299 |
+
t.Fatal(err)
|
| 300 |
+
}
|
| 301 |
+
res.Body.Close()
|
| 302 |
+
|
| 303 |
+
hdr := <-gotHeaders
|
| 304 |
+
// Should NOT inject headers when secret is empty
|
| 305 |
+
if hdr.Get("X-Api-Key") != "" {
|
| 306 |
+
t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
|
| 307 |
+
}
|
| 308 |
+
if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
|
| 309 |
+
t.Fatalf("Authorization should not be set, got: %q", authVal)
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
| 314 |
+
type captured struct {
|
| 315 |
+
headers http.Header
|
| 316 |
+
query string
|
| 317 |
+
}
|
| 318 |
+
got := make(chan captured, 1)
|
| 319 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 320 |
+
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
| 321 |
+
w.WriteHeader(200)
|
| 322 |
+
w.Write([]byte(`ok`))
|
| 323 |
+
}))
|
| 324 |
+
defer upstream.Close()
|
| 325 |
+
|
| 326 |
+
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
| 327 |
+
if err != nil {
|
| 328 |
+
t.Fatal(err)
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 332 |
+
// Simulate clientAPIKeyMiddleware injection (per-request)
|
| 333 |
+
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
| 334 |
+
proxy.ServeHTTP(w, r.WithContext(ctx))
|
| 335 |
+
}))
|
| 336 |
+
defer srv.Close()
|
| 337 |
+
|
| 338 |
+
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
| 339 |
+
if err != nil {
|
| 340 |
+
t.Fatal(err)
|
| 341 |
+
}
|
| 342 |
+
req.Header.Set("Authorization", "Bearer client-key")
|
| 343 |
+
req.Header.Set("X-Api-Key", "client-key")
|
| 344 |
+
req.Header.Set("X-Goog-Api-Key", "client-key")
|
| 345 |
+
|
| 346 |
+
res, err := http.DefaultClient.Do(req)
|
| 347 |
+
if err != nil {
|
| 348 |
+
t.Fatal(err)
|
| 349 |
+
}
|
| 350 |
+
res.Body.Close()
|
| 351 |
+
|
| 352 |
+
c := <-got
|
| 353 |
+
|
| 354 |
+
// These are client-provided credentials and must not reach the upstream.
|
| 355 |
+
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
| 356 |
+
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
| 360 |
+
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
| 361 |
+
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
| 362 |
+
}
|
| 363 |
+
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
| 364 |
+
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
// Query-based credentials should be stripped only when they match the authenticated client key.
|
| 368 |
+
// Should keep unrelated values and parameters.
|
| 369 |
+
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
| 370 |
+
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
| 371 |
+
}
|
| 372 |
+
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
| 373 |
+
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
| 378 |
+
gotHeaders := make(chan http.Header, 1)
|
| 379 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 380 |
+
gotHeaders <- r.Header.Clone()
|
| 381 |
+
w.WriteHeader(200)
|
| 382 |
+
w.Write([]byte(`ok`))
|
| 383 |
+
}))
|
| 384 |
+
defer upstream.Close()
|
| 385 |
+
|
| 386 |
+
defaultSource := NewStaticSecretSource("default")
|
| 387 |
+
mapped := NewMappedSecretSource(defaultSource)
|
| 388 |
+
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
| 389 |
+
{
|
| 390 |
+
UpstreamAPIKey: "u1",
|
| 391 |
+
APIKeys: []string{"k1"},
|
| 392 |
+
},
|
| 393 |
+
})
|
| 394 |
+
|
| 395 |
+
proxy, err := createReverseProxy(upstream.URL, mapped)
|
| 396 |
+
if err != nil {
|
| 397 |
+
t.Fatal(err)
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 401 |
+
// Simulate clientAPIKeyMiddleware injection (per-request)
|
| 402 |
+
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
| 403 |
+
proxy.ServeHTTP(w, r.WithContext(ctx))
|
| 404 |
+
}))
|
| 405 |
+
defer srv.Close()
|
| 406 |
+
|
| 407 |
+
res, err := http.Get(srv.URL + "/test")
|
| 408 |
+
if err != nil {
|
| 409 |
+
t.Fatal(err)
|
| 410 |
+
}
|
| 411 |
+
res.Body.Close()
|
| 412 |
+
|
| 413 |
+
hdr := <-gotHeaders
|
| 414 |
+
if hdr.Get("X-Api-Key") != "u1" {
|
| 415 |
+
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
| 416 |
+
}
|
| 417 |
+
if hdr.Get("Authorization") != "Bearer u1" {
|
| 418 |
+
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
| 423 |
+
gotHeaders := make(chan http.Header, 1)
|
| 424 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 425 |
+
gotHeaders <- r.Header.Clone()
|
| 426 |
+
w.WriteHeader(200)
|
| 427 |
+
w.Write([]byte(`ok`))
|
| 428 |
+
}))
|
| 429 |
+
defer upstream.Close()
|
| 430 |
+
|
| 431 |
+
defaultSource := NewStaticSecretSource("default")
|
| 432 |
+
mapped := NewMappedSecretSource(defaultSource)
|
| 433 |
+
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
| 434 |
+
{
|
| 435 |
+
UpstreamAPIKey: "u1",
|
| 436 |
+
APIKeys: []string{"k1"},
|
| 437 |
+
},
|
| 438 |
+
})
|
| 439 |
+
|
| 440 |
+
proxy, err := createReverseProxy(upstream.URL, mapped)
|
| 441 |
+
if err != nil {
|
| 442 |
+
t.Fatal(err)
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 446 |
+
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
| 447 |
+
proxy.ServeHTTP(w, r.WithContext(ctx))
|
| 448 |
+
}))
|
| 449 |
+
defer srv.Close()
|
| 450 |
+
|
| 451 |
+
res, err := http.Get(srv.URL + "/test")
|
| 452 |
+
if err != nil {
|
| 453 |
+
t.Fatal(err)
|
| 454 |
+
}
|
| 455 |
+
res.Body.Close()
|
| 456 |
+
|
| 457 |
+
hdr := <-gotHeaders
|
| 458 |
+
if hdr.Get("X-Api-Key") != "default" {
|
| 459 |
+
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
| 460 |
+
}
|
| 461 |
+
if hdr.Get("Authorization") != "Bearer default" {
|
| 462 |
+
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
| 463 |
+
}
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
| 467 |
+
// Point proxy to a non-routable address to trigger error
|
| 468 |
+
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
| 469 |
+
if err != nil {
|
| 470 |
+
t.Fatal(err)
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 474 |
+
proxy.ServeHTTP(w, r)
|
| 475 |
+
}))
|
| 476 |
+
defer srv.Close()
|
| 477 |
+
|
| 478 |
+
res, err := http.Get(srv.URL + "/any")
|
| 479 |
+
if err != nil {
|
| 480 |
+
t.Fatal(err)
|
| 481 |
+
}
|
| 482 |
+
body, _ := io.ReadAll(res.Body)
|
| 483 |
+
res.Body.Close()
|
| 484 |
+
|
| 485 |
+
if res.StatusCode != http.StatusBadGateway {
|
| 486 |
+
t.Fatalf("want 502, got %d", res.StatusCode)
|
| 487 |
+
}
|
| 488 |
+
if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
|
| 489 |
+
t.Fatalf("unexpected body: %s", body)
|
| 490 |
+
}
|
| 491 |
+
if ct := res.Header.Get("Content-Type"); ct != "application/json" {
|
| 492 |
+
t.Fatalf("content-type: want application/json, got %s", ct)
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
| 497 |
+
// Upstream returns gzipped JSON without Content-Encoding header
|
| 498 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 499 |
+
w.WriteHeader(200)
|
| 500 |
+
w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
|
| 501 |
+
}))
|
| 502 |
+
defer upstream.Close()
|
| 503 |
+
|
| 504 |
+
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
| 505 |
+
if err != nil {
|
| 506 |
+
t.Fatal(err)
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 510 |
+
proxy.ServeHTTP(w, r)
|
| 511 |
+
}))
|
| 512 |
+
defer srv.Close()
|
| 513 |
+
|
| 514 |
+
res, err := http.Get(srv.URL + "/test")
|
| 515 |
+
if err != nil {
|
| 516 |
+
t.Fatal(err)
|
| 517 |
+
}
|
| 518 |
+
body, _ := io.ReadAll(res.Body)
|
| 519 |
+
res.Body.Close()
|
| 520 |
+
|
| 521 |
+
expected := []byte(`{"upstream":"ok"}`)
|
| 522 |
+
if !bytes.Equal(body, expected) {
|
| 523 |
+
t.Fatalf("want decompressed JSON, got: %s", body)
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
|
| 528 |
+
// Upstream returns plain JSON
|
| 529 |
+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 530 |
+
w.Header().Set("Content-Type", "application/json")
|
| 531 |
+
w.WriteHeader(200)
|
| 532 |
+
w.Write([]byte(`{"plain":"json"}`))
|
| 533 |
+
}))
|
| 534 |
+
defer upstream.Close()
|
| 535 |
+
|
| 536 |
+
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
|
| 537 |
+
if err != nil {
|
| 538 |
+
t.Fatal(err)
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 542 |
+
proxy.ServeHTTP(w, r)
|
| 543 |
+
}))
|
| 544 |
+
defer srv.Close()
|
| 545 |
+
|
| 546 |
+
res, err := http.Get(srv.URL + "/test")
|
| 547 |
+
if err != nil {
|
| 548 |
+
t.Fatal(err)
|
| 549 |
+
}
|
| 550 |
+
body, _ := io.ReadAll(res.Body)
|
| 551 |
+
res.Body.Close()
|
| 552 |
+
|
| 553 |
+
expected := []byte(`{"plain":"json"}`)
|
| 554 |
+
if !bytes.Equal(body, expected) {
|
| 555 |
+
t.Fatalf("want plain JSON unchanged, got: %s", body)
|
| 556 |
+
}
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
func TestIsStreamingResponse(t *testing.T) {
|
| 560 |
+
cases := []struct {
|
| 561 |
+
name string
|
| 562 |
+
header http.Header
|
| 563 |
+
want bool
|
| 564 |
+
}{
|
| 565 |
+
{
|
| 566 |
+
name: "sse",
|
| 567 |
+
header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
| 568 |
+
want: true,
|
| 569 |
+
},
|
| 570 |
+
{
|
| 571 |
+
name: "chunked_not_streaming",
|
| 572 |
+
header: http.Header{"Transfer-Encoding": []string{"chunked"}},
|
| 573 |
+
want: false, // Chunked is transport-level, not streaming
|
| 574 |
+
},
|
| 575 |
+
{
|
| 576 |
+
name: "normal_json",
|
| 577 |
+
header: http.Header{"Content-Type": []string{"application/json"}},
|
| 578 |
+
want: false,
|
| 579 |
+
},
|
| 580 |
+
{
|
| 581 |
+
name: "empty",
|
| 582 |
+
header: http.Header{},
|
| 583 |
+
want: false,
|
| 584 |
+
},
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
for _, tc := range cases {
|
| 588 |
+
t.Run(tc.name, func(t *testing.T) {
|
| 589 |
+
resp := &http.Response{Header: tc.header}
|
| 590 |
+
got := isStreamingResponse(resp)
|
| 591 |
+
if got != tc.want {
|
| 592 |
+
t.Fatalf("want %v, got %v", tc.want, got)
|
| 593 |
+
}
|
| 594 |
+
})
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
func TestFilterBetaFeatures(t *testing.T) {
|
| 599 |
+
tests := []struct {
|
| 600 |
+
name string
|
| 601 |
+
header string
|
| 602 |
+
featureToRemove string
|
| 603 |
+
expected string
|
| 604 |
+
}{
|
| 605 |
+
{
|
| 606 |
+
name: "Remove context-1m from middle",
|
| 607 |
+
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
| 608 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 609 |
+
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
| 610 |
+
},
|
| 611 |
+
{
|
| 612 |
+
name: "Remove context-1m from start",
|
| 613 |
+
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
| 614 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 615 |
+
expected: "fine-grained-tool-streaming-2025-05-14",
|
| 616 |
+
},
|
| 617 |
+
{
|
| 618 |
+
name: "Remove context-1m from end",
|
| 619 |
+
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
| 620 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 621 |
+
expected: "fine-grained-tool-streaming-2025-05-14",
|
| 622 |
+
},
|
| 623 |
+
{
|
| 624 |
+
name: "Feature not present",
|
| 625 |
+
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
| 626 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 627 |
+
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
| 628 |
+
},
|
| 629 |
+
{
|
| 630 |
+
name: "Only feature to remove",
|
| 631 |
+
header: "context-1m-2025-08-07",
|
| 632 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 633 |
+
expected: "",
|
| 634 |
+
},
|
| 635 |
+
{
|
| 636 |
+
name: "Empty header",
|
| 637 |
+
header: "",
|
| 638 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 639 |
+
expected: "",
|
| 640 |
+
},
|
| 641 |
+
{
|
| 642 |
+
name: "Header with spaces",
|
| 643 |
+
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
| 644 |
+
featureToRemove: "context-1m-2025-08-07",
|
| 645 |
+
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
| 646 |
+
},
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
for _, tt := range tests {
|
| 650 |
+
t.Run(tt.name, func(t *testing.T) {
|
| 651 |
+
result := filterBetaFeatures(tt.header, tt.featureToRemove)
|
| 652 |
+
if result != tt.expected {
|
| 653 |
+
t.Errorf("filterBetaFeatures() = %q, want %q", result, tt.expected)
|
| 654 |
+
}
|
| 655 |
+
})
|
| 656 |
+
}
|
| 657 |
+
}
|
internal/api/modules/amp/response_rewriter.go
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"net/http"
|
| 6 |
+
"strings"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
log "github.com/sirupsen/logrus"
|
| 10 |
+
"github.com/tidwall/gjson"
|
| 11 |
+
"github.com/tidwall/sjson"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
|
| 15 |
+
// It's used to rewrite model names in responses when model mapping is used
|
| 16 |
+
type ResponseRewriter struct {
|
| 17 |
+
gin.ResponseWriter
|
| 18 |
+
body *bytes.Buffer
|
| 19 |
+
originalModel string
|
| 20 |
+
isStreaming bool
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// NewResponseRewriter creates a new response rewriter for model name substitution
|
| 24 |
+
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
|
| 25 |
+
return &ResponseRewriter{
|
| 26 |
+
ResponseWriter: w,
|
| 27 |
+
body: &bytes.Buffer{},
|
| 28 |
+
originalModel: originalModel,
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
| 33 |
+
|
| 34 |
+
func looksLikeSSEChunk(data []byte) bool {
|
| 35 |
+
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
| 36 |
+
// Heuristics are intentionally simple and cheap.
|
| 37 |
+
return bytes.Contains(data, []byte("data:")) ||
|
| 38 |
+
bytes.Contains(data, []byte("event:")) ||
|
| 39 |
+
bytes.Contains(data, []byte("message_start")) ||
|
| 40 |
+
bytes.Contains(data, []byte("message_delta")) ||
|
| 41 |
+
bytes.Contains(data, []byte("content_block_start")) ||
|
| 42 |
+
bytes.Contains(data, []byte("content_block_delta")) ||
|
| 43 |
+
bytes.Contains(data, []byte("content_block_stop")) ||
|
| 44 |
+
bytes.Contains(data, []byte("\n\n"))
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
| 48 |
+
if rw.isStreaming {
|
| 49 |
+
return nil
|
| 50 |
+
}
|
| 51 |
+
rw.isStreaming = true
|
| 52 |
+
|
| 53 |
+
// Flush any previously buffered data to avoid reordering or data loss.
|
| 54 |
+
if rw.body != nil && rw.body.Len() > 0 {
|
| 55 |
+
buf := rw.body.Bytes()
|
| 56 |
+
// Copy before Reset() to keep bytes stable.
|
| 57 |
+
toFlush := make([]byte, len(buf))
|
| 58 |
+
copy(toFlush, buf)
|
| 59 |
+
rw.body.Reset()
|
| 60 |
+
|
| 61 |
+
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
|
| 62 |
+
return err
|
| 63 |
+
}
|
| 64 |
+
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
| 65 |
+
flusher.Flush()
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
|
| 70 |
+
return nil
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// Write intercepts response writes and buffers them for model name replacement
|
| 74 |
+
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
| 75 |
+
// Detect streaming on first write (header-based)
|
| 76 |
+
if !rw.isStreaming && rw.body.Len() == 0 {
|
| 77 |
+
contentType := rw.Header().Get("Content-Type")
|
| 78 |
+
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
| 79 |
+
strings.Contains(contentType, "stream")
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if !rw.isStreaming {
|
| 83 |
+
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
|
| 84 |
+
if looksLikeSSEChunk(data) {
|
| 85 |
+
if err := rw.enableStreaming("sse heuristic"); err != nil {
|
| 86 |
+
return 0, err
|
| 87 |
+
}
|
| 88 |
+
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
|
| 89 |
+
// Safety cap: avoid unbounded buffering on large responses.
|
| 90 |
+
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
|
| 91 |
+
if err := rw.enableStreaming("buffer limit"); err != nil {
|
| 92 |
+
return 0, err
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
if rw.isStreaming {
|
| 98 |
+
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
| 99 |
+
if err == nil {
|
| 100 |
+
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
| 101 |
+
flusher.Flush()
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
return n, err
|
| 105 |
+
}
|
| 106 |
+
return rw.body.Write(data)
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// Flush writes the buffered response with model names rewritten
|
| 110 |
+
func (rw *ResponseRewriter) Flush() {
|
| 111 |
+
if rw.isStreaming {
|
| 112 |
+
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
| 113 |
+
flusher.Flush()
|
| 114 |
+
}
|
| 115 |
+
return
|
| 116 |
+
}
|
| 117 |
+
if rw.body.Len() > 0 {
|
| 118 |
+
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
|
| 119 |
+
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// modelFieldPaths lists all JSON paths where model name may appear
|
| 125 |
+
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
| 126 |
+
|
| 127 |
+
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
| 128 |
+
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
| 129 |
+
if rw.originalModel == "" {
|
| 130 |
+
return data
|
| 131 |
+
}
|
| 132 |
+
for _, path := range modelFieldPaths {
|
| 133 |
+
if gjson.GetBytes(data, path).Exists() {
|
| 134 |
+
data, _ = sjson.SetBytes(data, path, rw.originalModel)
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
return data
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// rewriteStreamChunk rewrites model names in SSE stream chunks
|
| 141 |
+
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
| 142 |
+
if rw.originalModel == "" {
|
| 143 |
+
return chunk
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
// SSE format: "data: {json}\n\n"
|
| 147 |
+
lines := bytes.Split(chunk, []byte("\n"))
|
| 148 |
+
for i, line := range lines {
|
| 149 |
+
if bytes.HasPrefix(line, []byte("data: ")) {
|
| 150 |
+
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
| 151 |
+
if len(jsonData) > 0 && jsonData[0] == '{' {
|
| 152 |
+
// Rewrite JSON in the data line
|
| 153 |
+
rewritten := rw.rewriteModelInResponse(jsonData)
|
| 154 |
+
lines[i] = append([]byte("data: "), rewritten...)
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
return bytes.Join(lines, []byte("\n"))
|
| 160 |
+
}
|
internal/api/modules/amp/routes.go
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"errors"
|
| 6 |
+
"net"
|
| 7 |
+
"net/http"
|
| 8 |
+
"net/http/httputil"
|
| 9 |
+
"strings"
|
| 10 |
+
|
| 11 |
+
"github.com/gin-gonic/gin"
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
| 14 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
| 15 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
| 16 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
| 17 |
+
log "github.com/sirupsen/logrus"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
// clientAPIKeyContextKey is the context key used to pass the client API key
|
| 21 |
+
// from gin.Context to the request context for SecretSource lookup.
|
| 22 |
+
type clientAPIKeyContextKey struct{}
|
| 23 |
+
|
| 24 |
+
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
| 25 |
+
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
| 26 |
+
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
| 27 |
+
return func(c *gin.Context) {
|
| 28 |
+
// Extract the client API key from gin context (set by AuthMiddleware)
|
| 29 |
+
if apiKey, exists := c.Get("apiKey"); exists {
|
| 30 |
+
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
| 31 |
+
// Inject into request context for SecretSource.Get(ctx) to read
|
| 32 |
+
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
| 33 |
+
c.Request = c.Request.WithContext(ctx)
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
c.Next()
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
| 41 |
+
// Returns empty string if not present.
|
| 42 |
+
func getClientAPIKeyFromContext(ctx context.Context) string {
|
| 43 |
+
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
| 44 |
+
if keyStr, ok := val.(string); ok {
|
| 45 |
+
return keyStr
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
return ""
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
| 52 |
+
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
| 53 |
+
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
| 54 |
+
return func(c *gin.Context) {
|
| 55 |
+
// Check current setting (hot-reloadable)
|
| 56 |
+
if !m.IsRestrictedToLocalhost() {
|
| 57 |
+
c.Next()
|
| 58 |
+
return
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
| 62 |
+
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
| 63 |
+
remoteAddr := c.Request.RemoteAddr
|
| 64 |
+
|
| 65 |
+
// RemoteAddr format is "IP:port" or "[IPv6]:port", extract just the IP
|
| 66 |
+
host, _, err := net.SplitHostPort(remoteAddr)
|
| 67 |
+
if err != nil {
|
| 68 |
+
// Try parsing as raw IP (shouldn't happen with standard HTTP, but be defensive)
|
| 69 |
+
host = remoteAddr
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// Parse the IP to handle both IPv4 and IPv6
|
| 73 |
+
ip := net.ParseIP(host)
|
| 74 |
+
if ip == nil {
|
| 75 |
+
log.Warnf("amp management: invalid RemoteAddr %s, denying access", remoteAddr)
|
| 76 |
+
c.AbortWithStatusJSON(403, gin.H{
|
| 77 |
+
"error": "Access denied: management routes restricted to localhost",
|
| 78 |
+
})
|
| 79 |
+
return
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Check if IP is loopback (127.0.0.1 or ::1)
|
| 83 |
+
if !ip.IsLoopback() {
|
| 84 |
+
log.Warnf("amp management: non-localhost connection from %s attempted access, denying", remoteAddr)
|
| 85 |
+
c.AbortWithStatusJSON(403, gin.H{
|
| 86 |
+
"error": "Access denied: management routes restricted to localhost",
|
| 87 |
+
})
|
| 88 |
+
return
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
c.Next()
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// noCORSMiddleware disables CORS for management routes to prevent browser-based attacks.
|
| 96 |
+
// This overwrites any global CORS headers set by the server.
|
| 97 |
+
func noCORSMiddleware() gin.HandlerFunc {
|
| 98 |
+
return func(c *gin.Context) {
|
| 99 |
+
// Remove CORS headers to prevent cross-origin access from browsers
|
| 100 |
+
c.Header("Access-Control-Allow-Origin", "")
|
| 101 |
+
c.Header("Access-Control-Allow-Methods", "")
|
| 102 |
+
c.Header("Access-Control-Allow-Headers", "")
|
| 103 |
+
c.Header("Access-Control-Allow-Credentials", "")
|
| 104 |
+
|
| 105 |
+
// For OPTIONS preflight, deny with 403
|
| 106 |
+
if c.Request.Method == "OPTIONS" {
|
| 107 |
+
c.AbortWithStatus(403)
|
| 108 |
+
return
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
c.Next()
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// managementAvailabilityMiddleware short-circuits management routes when the upstream
|
| 116 |
+
// proxy is disabled, preventing noisy localhost warnings and accidental exposure.
|
| 117 |
+
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
| 118 |
+
return func(c *gin.Context) {
|
| 119 |
+
if m.getProxy() == nil {
|
| 120 |
+
logging.SkipGinRequestLogging(c)
|
| 121 |
+
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
|
| 122 |
+
"error": "amp upstream proxy not available",
|
| 123 |
+
})
|
| 124 |
+
return
|
| 125 |
+
}
|
| 126 |
+
c.Next()
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
| 131 |
+
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
| 132 |
+
return func(c *gin.Context) {
|
| 133 |
+
path := c.Request.URL.Path
|
| 134 |
+
for _, prefix := range prefixes {
|
| 135 |
+
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
| 136 |
+
c.Next()
|
| 137 |
+
return
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
auth(c)
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// registerManagementRoutes registers Amp management proxy routes
|
| 145 |
+
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
| 146 |
+
// Uses dynamic middleware and proxy getter for hot-reload support.
|
| 147 |
+
// The auth middleware validates Authorization header against configured API keys.
|
| 148 |
+
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
| 149 |
+
ampAPI := engine.Group("/api")
|
| 150 |
+
|
| 151 |
+
// Always disable CORS for management routes to prevent browser-based attacks
|
| 152 |
+
ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
|
| 153 |
+
|
| 154 |
+
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
| 155 |
+
ampAPI.Use(m.localhostOnlyMiddleware())
|
| 156 |
+
|
| 157 |
+
// Apply authentication middleware - requires valid API key in Authorization header
|
| 158 |
+
var authWithBypass gin.HandlerFunc
|
| 159 |
+
if auth != nil {
|
| 160 |
+
ampAPI.Use(auth)
|
| 161 |
+
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
// Inject client API key into request context for per-client upstream routing
|
| 165 |
+
ampAPI.Use(clientAPIKeyMiddleware())
|
| 166 |
+
|
| 167 |
+
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
| 168 |
+
proxyHandler := func(c *gin.Context) {
|
| 169 |
+
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
| 170 |
+
defer func() {
|
| 171 |
+
if rec := recover(); rec != nil {
|
| 172 |
+
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
| 173 |
+
// Upstream already wrote the status (often 404) before the client/stream ended.
|
| 174 |
+
return
|
| 175 |
+
}
|
| 176 |
+
panic(rec)
|
| 177 |
+
}
|
| 178 |
+
}()
|
| 179 |
+
|
| 180 |
+
proxy := m.getProxy()
|
| 181 |
+
if proxy == nil {
|
| 182 |
+
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
| 183 |
+
return
|
| 184 |
+
}
|
| 185 |
+
proxy.ServeHTTP(c.Writer, c.Request)
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Management routes - these are proxied directly to Amp upstream
|
| 189 |
+
ampAPI.Any("/internal", proxyHandler)
|
| 190 |
+
ampAPI.Any("/internal/*path", proxyHandler)
|
| 191 |
+
ampAPI.Any("/user", proxyHandler)
|
| 192 |
+
ampAPI.Any("/user/*path", proxyHandler)
|
| 193 |
+
ampAPI.Any("/auth", proxyHandler)
|
| 194 |
+
ampAPI.Any("/auth/*path", proxyHandler)
|
| 195 |
+
ampAPI.Any("/meta", proxyHandler)
|
| 196 |
+
ampAPI.Any("/meta/*path", proxyHandler)
|
| 197 |
+
ampAPI.Any("/ads", proxyHandler)
|
| 198 |
+
ampAPI.Any("/telemetry", proxyHandler)
|
| 199 |
+
ampAPI.Any("/telemetry/*path", proxyHandler)
|
| 200 |
+
ampAPI.Any("/threads", proxyHandler)
|
| 201 |
+
ampAPI.Any("/threads/*path", proxyHandler)
|
| 202 |
+
ampAPI.Any("/otel", proxyHandler)
|
| 203 |
+
ampAPI.Any("/otel/*path", proxyHandler)
|
| 204 |
+
ampAPI.Any("/tab", proxyHandler)
|
| 205 |
+
ampAPI.Any("/tab/*path", proxyHandler)
|
| 206 |
+
|
| 207 |
+
// Root-level routes that AMP CLI expects without /api prefix
|
| 208 |
+
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
| 209 |
+
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
| 210 |
+
if authWithBypass != nil {
|
| 211 |
+
rootMiddleware = append(rootMiddleware, authWithBypass)
|
| 212 |
+
}
|
| 213 |
+
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
| 214 |
+
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
| 215 |
+
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
| 216 |
+
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
| 217 |
+
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
| 218 |
+
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
| 219 |
+
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
|
| 220 |
+
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
|
| 221 |
+
|
| 222 |
+
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
| 223 |
+
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
| 224 |
+
|
| 225 |
+
// Root-level auth routes for CLI login flow
|
| 226 |
+
// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
|
| 227 |
+
// We proxy all /auth/* to support the complete OAuth flow
|
| 228 |
+
engine.Any("/auth", append(rootMiddleware, proxyHandler)...)
|
| 229 |
+
engine.Any("/auth/*path", append(rootMiddleware, proxyHandler)...)
|
| 230 |
+
|
| 231 |
+
// Google v1beta1 passthrough with OAuth fallback
|
| 232 |
+
// AMP CLI uses non-standard paths like /publishers/google/models/...
|
| 233 |
+
// We bridge these to our standard Gemini handler to enable local OAuth.
|
| 234 |
+
// If no local OAuth is available, falls back to ampcode.com proxy.
|
| 235 |
+
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
| 236 |
+
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
| 237 |
+
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
| 238 |
+
return m.getProxy()
|
| 239 |
+
}, m.modelMapper, m.forceModelMappings)
|
| 240 |
+
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
| 241 |
+
|
| 242 |
+
// Route POST model calls through Gemini bridge with FallbackHandler.
|
| 243 |
+
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
| 244 |
+
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
| 245 |
+
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
| 246 |
+
if c.Request.Method == "POST" {
|
| 247 |
+
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
| 248 |
+
// POST with /models/ path -> use Gemini bridge with fallback handler
|
| 249 |
+
// FallbackHandler will check provider/mapping and proxy if needed
|
| 250 |
+
geminiV1Beta1Handler(c)
|
| 251 |
+
return
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
// Non-POST or no local provider available -> proxy upstream
|
| 255 |
+
proxyHandler(c)
|
| 256 |
+
})
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// registerProviderAliases registers /api/provider/{provider}/... routes
|
| 260 |
+
// These allow Amp CLI to route requests like:
|
| 261 |
+
//
|
| 262 |
+
// /api/provider/openai/v1/chat/completions
|
| 263 |
+
// /api/provider/anthropic/v1/messages
|
| 264 |
+
// /api/provider/google/v1beta/models
|
| 265 |
+
func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
| 266 |
+
// Create handler instances for different providers
|
| 267 |
+
openaiHandlers := openai.NewOpenAIAPIHandler(baseHandler)
|
| 268 |
+
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
| 269 |
+
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
| 270 |
+
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
| 271 |
+
|
| 272 |
+
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
| 273 |
+
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
| 274 |
+
// Also includes model mapping support for routing unavailable models to alternatives
|
| 275 |
+
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
| 276 |
+
return m.getProxy()
|
| 277 |
+
}, m.modelMapper, m.forceModelMappings)
|
| 278 |
+
|
| 279 |
+
// Provider-specific routes under /api/provider/:provider
|
| 280 |
+
ampProviders := engine.Group("/api/provider")
|
| 281 |
+
if auth != nil {
|
| 282 |
+
ampProviders.Use(auth)
|
| 283 |
+
}
|
| 284 |
+
// Inject client API key into request context for per-client upstream routing
|
| 285 |
+
ampProviders.Use(clientAPIKeyMiddleware())
|
| 286 |
+
|
| 287 |
+
provider := ampProviders.Group("/:provider")
|
| 288 |
+
|
| 289 |
+
// Dynamic models handler - routes to appropriate provider based on path parameter
|
| 290 |
+
ampModelsHandler := func(c *gin.Context) {
|
| 291 |
+
providerName := strings.ToLower(c.Param("provider"))
|
| 292 |
+
|
| 293 |
+
switch providerName {
|
| 294 |
+
case "anthropic":
|
| 295 |
+
claudeCodeHandlers.ClaudeModels(c)
|
| 296 |
+
case "google":
|
| 297 |
+
geminiHandlers.GeminiModels(c)
|
| 298 |
+
default:
|
| 299 |
+
// Default to OpenAI-compatible (works for openai, groq, cerebras, etc.)
|
| 300 |
+
openaiHandlers.OpenAIModels(c)
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
| 305 |
+
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
| 306 |
+
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
| 307 |
+
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
| 308 |
+
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
| 309 |
+
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
| 310 |
+
|
| 311 |
+
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
| 312 |
+
v1Amp := provider.Group("/v1")
|
| 313 |
+
{
|
| 314 |
+
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
| 315 |
+
|
| 316 |
+
// OpenAI-compatible endpoints with fallback
|
| 317 |
+
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
| 318 |
+
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
| 319 |
+
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
| 320 |
+
|
| 321 |
+
// Claude/Anthropic-compatible endpoints with fallback
|
| 322 |
+
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
| 323 |
+
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
// /v1beta routes (Gemini native API)
|
| 327 |
+
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
| 328 |
+
v1betaAmp := provider.Group("/v1beta")
|
| 329 |
+
{
|
| 330 |
+
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
| 331 |
+
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
| 332 |
+
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
| 333 |
+
}
|
| 334 |
+
}
|
internal/api/modules/amp/routes_test.go
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
"net/http/httptest"
|
| 6 |
+
"testing"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
func TestRegisterManagementRoutes(t *testing.T) {
|
| 13 |
+
gin.SetMode(gin.TestMode)
|
| 14 |
+
r := gin.New()
|
| 15 |
+
|
| 16 |
+
// Create module with proxy for testing
|
| 17 |
+
m := &AmpModule{
|
| 18 |
+
restrictToLocalhost: false, // disable localhost restriction for tests
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
// Create a mock proxy that tracks calls
|
| 22 |
+
proxyCalled := false
|
| 23 |
+
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
| 24 |
+
proxyCalled = true
|
| 25 |
+
w.WriteHeader(200)
|
| 26 |
+
w.Write([]byte("proxied"))
|
| 27 |
+
}))
|
| 28 |
+
defer mockProxy.Close()
|
| 29 |
+
|
| 30 |
+
// Create real proxy to mock server
|
| 31 |
+
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
|
| 32 |
+
m.setProxy(proxy)
|
| 33 |
+
|
| 34 |
+
base := &handlers.BaseAPIHandler{}
|
| 35 |
+
m.registerManagementRoutes(r, base, nil)
|
| 36 |
+
srv := httptest.NewServer(r)
|
| 37 |
+
defer srv.Close()
|
| 38 |
+
|
| 39 |
+
managementPaths := []struct {
|
| 40 |
+
path string
|
| 41 |
+
method string
|
| 42 |
+
}{
|
| 43 |
+
{"/api/internal", http.MethodGet},
|
| 44 |
+
{"/api/internal/some/path", http.MethodGet},
|
| 45 |
+
{"/api/user", http.MethodGet},
|
| 46 |
+
{"/api/user/profile", http.MethodGet},
|
| 47 |
+
{"/api/auth", http.MethodGet},
|
| 48 |
+
{"/api/auth/login", http.MethodGet},
|
| 49 |
+
{"/api/meta", http.MethodGet},
|
| 50 |
+
{"/api/telemetry", http.MethodGet},
|
| 51 |
+
{"/api/threads", http.MethodGet},
|
| 52 |
+
{"/threads/", http.MethodGet},
|
| 53 |
+
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
| 54 |
+
{"/api/otel", http.MethodGet},
|
| 55 |
+
{"/api/tab", http.MethodGet},
|
| 56 |
+
{"/api/tab/some/path", http.MethodGet},
|
| 57 |
+
{"/auth", http.MethodGet}, // Root-level auth route
|
| 58 |
+
{"/auth/cli-login", http.MethodGet}, // CLI login flow
|
| 59 |
+
{"/auth/callback", http.MethodGet}, // OAuth callback
|
| 60 |
+
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
| 61 |
+
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
| 62 |
+
{"/api/provider/google/v1beta1/models", http.MethodPost},
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
for _, path := range managementPaths {
|
| 66 |
+
t.Run(path.path, func(t *testing.T) {
|
| 67 |
+
proxyCalled = false
|
| 68 |
+
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
| 69 |
+
if err != nil {
|
| 70 |
+
t.Fatalf("failed to build request: %v", err)
|
| 71 |
+
}
|
| 72 |
+
resp, err := http.DefaultClient.Do(req)
|
| 73 |
+
if err != nil {
|
| 74 |
+
t.Fatalf("request failed: %v", err)
|
| 75 |
+
}
|
| 76 |
+
defer resp.Body.Close()
|
| 77 |
+
|
| 78 |
+
if resp.StatusCode == http.StatusNotFound {
|
| 79 |
+
t.Fatalf("route %s not registered", path.path)
|
| 80 |
+
}
|
| 81 |
+
if !proxyCalled {
|
| 82 |
+
t.Fatalf("proxy handler not called for %s", path.path)
|
| 83 |
+
}
|
| 84 |
+
})
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
|
| 89 |
+
gin.SetMode(gin.TestMode)
|
| 90 |
+
r := gin.New()
|
| 91 |
+
|
| 92 |
+
// Minimal base handler setup (no need to initialize, just check routing)
|
| 93 |
+
base := &handlers.BaseAPIHandler{}
|
| 94 |
+
|
| 95 |
+
// Track if auth middleware was called
|
| 96 |
+
authCalled := false
|
| 97 |
+
authMiddleware := func(c *gin.Context) {
|
| 98 |
+
authCalled = true
|
| 99 |
+
c.Header("X-Auth", "ok")
|
| 100 |
+
// Abort with success to avoid calling the actual handler (which needs full setup)
|
| 101 |
+
c.AbortWithStatus(http.StatusOK)
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
m := &AmpModule{authMiddleware_: authMiddleware}
|
| 105 |
+
m.registerProviderAliases(r, base, authMiddleware)
|
| 106 |
+
|
| 107 |
+
paths := []struct {
|
| 108 |
+
path string
|
| 109 |
+
method string
|
| 110 |
+
}{
|
| 111 |
+
{"/api/provider/openai/models", http.MethodGet},
|
| 112 |
+
{"/api/provider/anthropic/models", http.MethodGet},
|
| 113 |
+
{"/api/provider/google/models", http.MethodGet},
|
| 114 |
+
{"/api/provider/groq/models", http.MethodGet},
|
| 115 |
+
{"/api/provider/openai/chat/completions", http.MethodPost},
|
| 116 |
+
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
| 117 |
+
{"/api/provider/google/v1beta/models", http.MethodGet},
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
for _, tc := range paths {
|
| 121 |
+
t.Run(tc.path, func(t *testing.T) {
|
| 122 |
+
authCalled = false
|
| 123 |
+
req := httptest.NewRequest(tc.method, tc.path, nil)
|
| 124 |
+
w := httptest.NewRecorder()
|
| 125 |
+
r.ServeHTTP(w, req)
|
| 126 |
+
|
| 127 |
+
if w.Code == http.StatusNotFound {
|
| 128 |
+
t.Fatalf("route %s %s not registered", tc.method, tc.path)
|
| 129 |
+
}
|
| 130 |
+
if !authCalled {
|
| 131 |
+
t.Fatalf("auth middleware not executed for %s", tc.path)
|
| 132 |
+
}
|
| 133 |
+
if w.Header().Get("X-Auth") != "ok" {
|
| 134 |
+
t.Fatalf("auth middleware header not set for %s", tc.path)
|
| 135 |
+
}
|
| 136 |
+
})
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
|
| 141 |
+
gin.SetMode(gin.TestMode)
|
| 142 |
+
r := gin.New()
|
| 143 |
+
|
| 144 |
+
base := &handlers.BaseAPIHandler{}
|
| 145 |
+
|
| 146 |
+
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
| 147 |
+
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
| 148 |
+
|
| 149 |
+
providers := []string{"openai", "anthropic", "google", "groq", "cerebras"}
|
| 150 |
+
|
| 151 |
+
for _, provider := range providers {
|
| 152 |
+
t.Run(provider, func(t *testing.T) {
|
| 153 |
+
path := "/api/provider/" + provider + "/models"
|
| 154 |
+
req := httptest.NewRequest(http.MethodGet, path, nil)
|
| 155 |
+
w := httptest.NewRecorder()
|
| 156 |
+
r.ServeHTTP(w, req)
|
| 157 |
+
|
| 158 |
+
// Should not 404
|
| 159 |
+
if w.Code == http.StatusNotFound {
|
| 160 |
+
t.Fatalf("models route not found for provider: %s", provider)
|
| 161 |
+
}
|
| 162 |
+
})
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
|
| 167 |
+
gin.SetMode(gin.TestMode)
|
| 168 |
+
r := gin.New()
|
| 169 |
+
|
| 170 |
+
base := &handlers.BaseAPIHandler{}
|
| 171 |
+
|
| 172 |
+
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
| 173 |
+
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
| 174 |
+
|
| 175 |
+
v1Paths := []struct {
|
| 176 |
+
path string
|
| 177 |
+
method string
|
| 178 |
+
}{
|
| 179 |
+
{"/api/provider/openai/v1/models", http.MethodGet},
|
| 180 |
+
{"/api/provider/openai/v1/chat/completions", http.MethodPost},
|
| 181 |
+
{"/api/provider/openai/v1/completions", http.MethodPost},
|
| 182 |
+
{"/api/provider/anthropic/v1/messages", http.MethodPost},
|
| 183 |
+
{"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
for _, tc := range v1Paths {
|
| 187 |
+
t.Run(tc.path, func(t *testing.T) {
|
| 188 |
+
req := httptest.NewRequest(tc.method, tc.path, nil)
|
| 189 |
+
w := httptest.NewRecorder()
|
| 190 |
+
r.ServeHTTP(w, req)
|
| 191 |
+
|
| 192 |
+
if w.Code == http.StatusNotFound {
|
| 193 |
+
t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
|
| 194 |
+
}
|
| 195 |
+
})
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
|
| 200 |
+
gin.SetMode(gin.TestMode)
|
| 201 |
+
r := gin.New()
|
| 202 |
+
|
| 203 |
+
base := &handlers.BaseAPIHandler{}
|
| 204 |
+
|
| 205 |
+
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
|
| 206 |
+
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
| 207 |
+
|
| 208 |
+
v1betaPaths := []struct {
|
| 209 |
+
path string
|
| 210 |
+
method string
|
| 211 |
+
}{
|
| 212 |
+
{"/api/provider/google/v1beta/models", http.MethodGet},
|
| 213 |
+
{"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
for _, tc := range v1betaPaths {
|
| 217 |
+
t.Run(tc.path, func(t *testing.T) {
|
| 218 |
+
req := httptest.NewRequest(tc.method, tc.path, nil)
|
| 219 |
+
w := httptest.NewRecorder()
|
| 220 |
+
r.ServeHTTP(w, req)
|
| 221 |
+
|
| 222 |
+
if w.Code == http.StatusNotFound {
|
| 223 |
+
t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
|
| 224 |
+
}
|
| 225 |
+
})
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
|
| 230 |
+
// Test that routes still register even if auth middleware is nil (fallback behavior)
|
| 231 |
+
gin.SetMode(gin.TestMode)
|
| 232 |
+
r := gin.New()
|
| 233 |
+
|
| 234 |
+
base := &handlers.BaseAPIHandler{}
|
| 235 |
+
|
| 236 |
+
m := &AmpModule{authMiddleware_: nil} // No auth middleware
|
| 237 |
+
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
|
| 238 |
+
|
| 239 |
+
req := httptest.NewRequest(http.MethodGet, "/api/provider/openai/models", nil)
|
| 240 |
+
w := httptest.NewRecorder()
|
| 241 |
+
r.ServeHTTP(w, req)
|
| 242 |
+
|
| 243 |
+
// Should still work (with fallback no-op auth)
|
| 244 |
+
if w.Code == http.StatusNotFound {
|
| 245 |
+
t.Fatal("routes should register even without auth middleware")
|
| 246 |
+
}
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
| 250 |
+
gin.SetMode(gin.TestMode)
|
| 251 |
+
r := gin.New()
|
| 252 |
+
|
| 253 |
+
// Create module with localhost restriction enabled
|
| 254 |
+
m := &AmpModule{
|
| 255 |
+
restrictToLocalhost: true,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
// Apply dynamic localhost-only middleware
|
| 259 |
+
r.Use(m.localhostOnlyMiddleware())
|
| 260 |
+
r.GET("/test", func(c *gin.Context) {
|
| 261 |
+
c.String(http.StatusOK, "ok")
|
| 262 |
+
})
|
| 263 |
+
|
| 264 |
+
tests := []struct {
|
| 265 |
+
name string
|
| 266 |
+
remoteAddr string
|
| 267 |
+
forwardedFor string
|
| 268 |
+
expectedStatus int
|
| 269 |
+
description string
|
| 270 |
+
}{
|
| 271 |
+
{
|
| 272 |
+
name: "spoofed_header_remote_connection",
|
| 273 |
+
remoteAddr: "192.168.1.100:12345",
|
| 274 |
+
forwardedFor: "127.0.0.1",
|
| 275 |
+
expectedStatus: http.StatusForbidden,
|
| 276 |
+
description: "Spoofed X-Forwarded-For header should be ignored",
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
name: "real_localhost_ipv4",
|
| 280 |
+
remoteAddr: "127.0.0.1:54321",
|
| 281 |
+
forwardedFor: "",
|
| 282 |
+
expectedStatus: http.StatusOK,
|
| 283 |
+
description: "Real localhost IPv4 connection should work",
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
name: "real_localhost_ipv6",
|
| 287 |
+
remoteAddr: "[::1]:54321",
|
| 288 |
+
forwardedFor: "",
|
| 289 |
+
expectedStatus: http.StatusOK,
|
| 290 |
+
description: "Real localhost IPv6 connection should work",
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
name: "remote_ipv4",
|
| 294 |
+
remoteAddr: "203.0.113.42:8080",
|
| 295 |
+
forwardedFor: "",
|
| 296 |
+
expectedStatus: http.StatusForbidden,
|
| 297 |
+
description: "Remote IPv4 connection should be blocked",
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
name: "remote_ipv6",
|
| 301 |
+
remoteAddr: "[2001:db8::1]:9090",
|
| 302 |
+
forwardedFor: "",
|
| 303 |
+
expectedStatus: http.StatusForbidden,
|
| 304 |
+
description: "Remote IPv6 connection should be blocked",
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
name: "spoofed_localhost_ipv6",
|
| 308 |
+
remoteAddr: "203.0.113.42:8080",
|
| 309 |
+
forwardedFor: "::1",
|
| 310 |
+
expectedStatus: http.StatusForbidden,
|
| 311 |
+
description: "Spoofed X-Forwarded-For with IPv6 localhost should be ignored",
|
| 312 |
+
},
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
for _, tt := range tests {
|
| 316 |
+
t.Run(tt.name, func(t *testing.T) {
|
| 317 |
+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
| 318 |
+
req.RemoteAddr = tt.remoteAddr
|
| 319 |
+
if tt.forwardedFor != "" {
|
| 320 |
+
req.Header.Set("X-Forwarded-For", tt.forwardedFor)
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
w := httptest.NewRecorder()
|
| 324 |
+
r.ServeHTTP(w, req)
|
| 325 |
+
|
| 326 |
+
if w.Code != tt.expectedStatus {
|
| 327 |
+
t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code)
|
| 328 |
+
}
|
| 329 |
+
})
|
| 330 |
+
}
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
|
| 334 |
+
gin.SetMode(gin.TestMode)
|
| 335 |
+
r := gin.New()
|
| 336 |
+
|
| 337 |
+
// Create module with localhost restriction initially enabled
|
| 338 |
+
m := &AmpModule{
|
| 339 |
+
restrictToLocalhost: true,
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// Apply dynamic localhost-only middleware
|
| 343 |
+
r.Use(m.localhostOnlyMiddleware())
|
| 344 |
+
r.GET("/test", func(c *gin.Context) {
|
| 345 |
+
c.String(http.StatusOK, "ok")
|
| 346 |
+
})
|
| 347 |
+
|
| 348 |
+
// Test 1: Remote IP should be blocked when restriction is enabled
|
| 349 |
+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
| 350 |
+
req.RemoteAddr = "192.168.1.100:12345"
|
| 351 |
+
w := httptest.NewRecorder()
|
| 352 |
+
r.ServeHTTP(w, req)
|
| 353 |
+
|
| 354 |
+
if w.Code != http.StatusForbidden {
|
| 355 |
+
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
// Test 2: Hot-reload - disable restriction
|
| 359 |
+
m.setRestrictToLocalhost(false)
|
| 360 |
+
|
| 361 |
+
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
| 362 |
+
req.RemoteAddr = "192.168.1.100:12345"
|
| 363 |
+
w = httptest.NewRecorder()
|
| 364 |
+
r.ServeHTTP(w, req)
|
| 365 |
+
|
| 366 |
+
if w.Code != http.StatusOK {
|
| 367 |
+
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// Test 3: Hot-reload - re-enable restriction
|
| 371 |
+
m.setRestrictToLocalhost(true)
|
| 372 |
+
|
| 373 |
+
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
| 374 |
+
req.RemoteAddr = "192.168.1.100:12345"
|
| 375 |
+
w = httptest.NewRecorder()
|
| 376 |
+
r.ServeHTTP(w, req)
|
| 377 |
+
|
| 378 |
+
if w.Code != http.StatusForbidden {
|
| 379 |
+
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
|
| 380 |
+
}
|
| 381 |
+
}
|
internal/api/modules/amp/secret.go
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"os"
|
| 8 |
+
"path/filepath"
|
| 9 |
+
"strings"
|
| 10 |
+
"sync"
|
| 11 |
+
"time"
|
| 12 |
+
|
| 13 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 14 |
+
log "github.com/sirupsen/logrus"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
// SecretSource provides Amp API keys with configurable precedence and caching
|
| 18 |
+
type SecretSource interface {
|
| 19 |
+
Get(ctx context.Context) (string, error)
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
// cachedSecret holds a secret value with expiration
|
| 23 |
+
type cachedSecret struct {
|
| 24 |
+
value string
|
| 25 |
+
expiresAt time.Time
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// MultiSourceSecret implements precedence-based secret lookup:
|
| 29 |
+
// 1. Explicit config value (highest priority)
|
| 30 |
+
// 2. Environment variable AMP_API_KEY
|
| 31 |
+
// 3. File-based secret (lowest priority)
|
| 32 |
+
type MultiSourceSecret struct {
|
| 33 |
+
explicitKey string
|
| 34 |
+
envKey string
|
| 35 |
+
filePath string
|
| 36 |
+
cacheTTL time.Duration
|
| 37 |
+
|
| 38 |
+
mu sync.RWMutex
|
| 39 |
+
cache *cachedSecret
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// NewMultiSourceSecret creates a secret source with precedence and caching
|
| 43 |
+
func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret {
|
| 44 |
+
if cacheTTL == 0 {
|
| 45 |
+
cacheTTL = 5 * time.Minute // Default 5 minute cache
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
home, _ := os.UserHomeDir()
|
| 49 |
+
filePath := filepath.Join(home, ".local", "share", "amp", "secrets.json")
|
| 50 |
+
|
| 51 |
+
return &MultiSourceSecret{
|
| 52 |
+
explicitKey: strings.TrimSpace(explicitKey),
|
| 53 |
+
envKey: "AMP_API_KEY",
|
| 54 |
+
filePath: filePath,
|
| 55 |
+
cacheTTL: cacheTTL,
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// NewMultiSourceSecretWithPath creates a secret source with a custom file path (for testing)
|
| 60 |
+
func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret {
|
| 61 |
+
if cacheTTL == 0 {
|
| 62 |
+
cacheTTL = 5 * time.Minute
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
return &MultiSourceSecret{
|
| 66 |
+
explicitKey: strings.TrimSpace(explicitKey),
|
| 67 |
+
envKey: "AMP_API_KEY",
|
| 68 |
+
filePath: filePath,
|
| 69 |
+
cacheTTL: cacheTTL,
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// Get retrieves the Amp API key using precedence: config > env > file
|
| 74 |
+
// Results are cached for cacheTTL duration to avoid excessive file reads
|
| 75 |
+
func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) {
|
| 76 |
+
// Precedence 1: Explicit config key (highest priority, no caching needed)
|
| 77 |
+
if s.explicitKey != "" {
|
| 78 |
+
return s.explicitKey, nil
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Precedence 2: Environment variable
|
| 82 |
+
if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" {
|
| 83 |
+
return envValue, nil
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Precedence 3: File-based secret (lowest priority, cached)
|
| 87 |
+
// Check cache first
|
| 88 |
+
s.mu.RLock()
|
| 89 |
+
if s.cache != nil && time.Now().Before(s.cache.expiresAt) {
|
| 90 |
+
value := s.cache.value
|
| 91 |
+
s.mu.RUnlock()
|
| 92 |
+
return value, nil
|
| 93 |
+
}
|
| 94 |
+
s.mu.RUnlock()
|
| 95 |
+
|
| 96 |
+
// Cache miss or expired - read from file
|
| 97 |
+
key, err := s.readFromFile()
|
| 98 |
+
if err != nil {
|
| 99 |
+
// Cache empty result to avoid repeated file reads on missing files
|
| 100 |
+
s.updateCache("")
|
| 101 |
+
return "", err
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Cache the result
|
| 105 |
+
s.updateCache(key)
|
| 106 |
+
return key, nil
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// readFromFile reads the Amp API key from the secrets file
|
| 110 |
+
func (s *MultiSourceSecret) readFromFile() (string, error) {
|
| 111 |
+
content, err := os.ReadFile(s.filePath)
|
| 112 |
+
if err != nil {
|
| 113 |
+
if os.IsNotExist(err) {
|
| 114 |
+
return "", nil // Missing file is not an error, just no key available
|
| 115 |
+
}
|
| 116 |
+
return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err)
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
var secrets map[string]string
|
| 120 |
+
if err := json.Unmarshal(content, &secrets); err != nil {
|
| 121 |
+
return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err)
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
key := strings.TrimSpace(secrets["apiKey@https://ampcode.com/"])
|
| 125 |
+
return key, nil
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// updateCache updates the cached secret value
|
| 129 |
+
func (s *MultiSourceSecret) updateCache(value string) {
|
| 130 |
+
s.mu.Lock()
|
| 131 |
+
defer s.mu.Unlock()
|
| 132 |
+
s.cache = &cachedSecret{
|
| 133 |
+
value: value,
|
| 134 |
+
expiresAt: time.Now().Add(s.cacheTTL),
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// InvalidateCache clears the cached secret, forcing a fresh read on next Get
|
| 139 |
+
func (s *MultiSourceSecret) InvalidateCache() {
|
| 140 |
+
s.mu.Lock()
|
| 141 |
+
defer s.mu.Unlock()
|
| 142 |
+
s.cache = nil
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
// UpdateExplicitKey refreshes the config-provided key and clears cache.
|
| 146 |
+
func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
|
| 147 |
+
if s == nil {
|
| 148 |
+
return
|
| 149 |
+
}
|
| 150 |
+
s.mu.Lock()
|
| 151 |
+
s.explicitKey = strings.TrimSpace(key)
|
| 152 |
+
s.cache = nil
|
| 153 |
+
s.mu.Unlock()
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// StaticSecretSource returns a fixed API key (for testing)
|
| 157 |
+
type StaticSecretSource struct {
|
| 158 |
+
key string
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// NewStaticSecretSource creates a secret source with a fixed key
|
| 162 |
+
func NewStaticSecretSource(key string) *StaticSecretSource {
|
| 163 |
+
return &StaticSecretSource{key: strings.TrimSpace(key)}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// Get returns the static API key
|
| 167 |
+
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
| 168 |
+
return s.key, nil
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
| 172 |
+
// When a request context contains a client API key that matches a configured mapping,
|
| 173 |
+
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
| 174 |
+
type MappedSecretSource struct {
|
| 175 |
+
defaultSource SecretSource
|
| 176 |
+
mu sync.RWMutex
|
| 177 |
+
lookup map[string]string // clientKey -> upstreamKey
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
| 181 |
+
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
| 182 |
+
return &MappedSecretSource{
|
| 183 |
+
defaultSource: defaultSource,
|
| 184 |
+
lookup: make(map[string]string),
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Get retrieves the Amp API key, checking per-client mappings first.
|
| 189 |
+
// If the request context contains a client API key that matches a configured mapping,
|
| 190 |
+
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
| 191 |
+
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
| 192 |
+
// Try to get client API key from request context
|
| 193 |
+
clientKey := getClientAPIKeyFromContext(ctx)
|
| 194 |
+
if clientKey != "" {
|
| 195 |
+
s.mu.RLock()
|
| 196 |
+
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
| 197 |
+
s.mu.RUnlock()
|
| 198 |
+
return upstreamKey, nil
|
| 199 |
+
}
|
| 200 |
+
s.mu.RUnlock()
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
// Fall back to default source
|
| 204 |
+
return s.defaultSource.Get(ctx)
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
| 208 |
+
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
| 209 |
+
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
| 210 |
+
newLookup := make(map[string]string)
|
| 211 |
+
|
| 212 |
+
for _, entry := range entries {
|
| 213 |
+
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
| 214 |
+
if upstreamKey == "" {
|
| 215 |
+
continue
|
| 216 |
+
}
|
| 217 |
+
for _, clientKey := range entry.APIKeys {
|
| 218 |
+
trimmedKey := strings.TrimSpace(clientKey)
|
| 219 |
+
if trimmedKey == "" {
|
| 220 |
+
continue
|
| 221 |
+
}
|
| 222 |
+
if _, exists := newLookup[trimmedKey]; exists {
|
| 223 |
+
// Log warning for duplicate client key, first one wins
|
| 224 |
+
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
| 225 |
+
continue
|
| 226 |
+
}
|
| 227 |
+
newLookup[trimmedKey] = upstreamKey
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
s.mu.Lock()
|
| 232 |
+
s.lookup = newLookup
|
| 233 |
+
s.mu.Unlock()
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
| 237 |
+
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
| 238 |
+
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
| 239 |
+
ms.UpdateExplicitKey(key)
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
| 244 |
+
func (s *MappedSecretSource) InvalidateCache() {
|
| 245 |
+
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
| 246 |
+
ms.InvalidateCache()
|
| 247 |
+
}
|
| 248 |
+
}
|
internal/api/modules/amp/secret_test.go
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package amp
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"os"
|
| 7 |
+
"path/filepath"
|
| 8 |
+
"sync"
|
| 9 |
+
"testing"
|
| 10 |
+
"time"
|
| 11 |
+
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 13 |
+
log "github.com/sirupsen/logrus"
|
| 14 |
+
"github.com/sirupsen/logrus/hooks/test"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
| 18 |
+
ctx := context.Background()
|
| 19 |
+
|
| 20 |
+
cases := []struct {
|
| 21 |
+
name string
|
| 22 |
+
configKey string
|
| 23 |
+
envKey string
|
| 24 |
+
fileJSON string
|
| 25 |
+
want string
|
| 26 |
+
}{
|
| 27 |
+
{"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"},
|
| 28 |
+
{"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
| 29 |
+
{"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
| 30 |
+
{"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
|
| 31 |
+
{"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
|
| 32 |
+
{"missing_file_returns_empty", "", "", "", ""},
|
| 33 |
+
{"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
for _, tc := range cases {
|
| 37 |
+
tc := tc // capture range variable
|
| 38 |
+
t.Run(tc.name, func(t *testing.T) {
|
| 39 |
+
tmpDir := t.TempDir()
|
| 40 |
+
secretsPath := filepath.Join(tmpDir, "secrets.json")
|
| 41 |
+
|
| 42 |
+
if tc.fileJSON != "" {
|
| 43 |
+
if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil {
|
| 44 |
+
t.Fatal(err)
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
t.Setenv("AMP_API_KEY", tc.envKey)
|
| 49 |
+
|
| 50 |
+
s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond)
|
| 51 |
+
got, err := s.Get(ctx)
|
| 52 |
+
if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) {
|
| 53 |
+
t.Fatalf("unexpected error: %v", err)
|
| 54 |
+
}
|
| 55 |
+
if got != tc.want {
|
| 56 |
+
t.Fatalf("want %q, got %q", tc.want, got)
|
| 57 |
+
}
|
| 58 |
+
})
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
func TestMultiSourceSecret_CacheBehavior(t *testing.T) {
|
| 63 |
+
ctx := context.Background()
|
| 64 |
+
tmpDir := t.TempDir()
|
| 65 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 66 |
+
|
| 67 |
+
// Initial value
|
| 68 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
|
| 69 |
+
t.Fatal(err)
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond)
|
| 73 |
+
|
| 74 |
+
// First read - should return v1
|
| 75 |
+
got1, err := s.Get(ctx)
|
| 76 |
+
if err != nil {
|
| 77 |
+
t.Fatalf("Get failed: %v", err)
|
| 78 |
+
}
|
| 79 |
+
if got1 != "v1" {
|
| 80 |
+
t.Fatalf("expected v1, got %s", got1)
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// Change file; within TTL we should still see v1 (cached)
|
| 84 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil {
|
| 85 |
+
t.Fatal(err)
|
| 86 |
+
}
|
| 87 |
+
got2, _ := s.Get(ctx)
|
| 88 |
+
if got2 != "v1" {
|
| 89 |
+
t.Fatalf("cache hit expected v1, got %s", got2)
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
// After TTL expires, should see v2
|
| 93 |
+
time.Sleep(60 * time.Millisecond)
|
| 94 |
+
got3, _ := s.Get(ctx)
|
| 95 |
+
if got3 != "v2" {
|
| 96 |
+
t.Fatalf("cache miss expected v2, got %s", got3)
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// Invalidate forces re-read immediately
|
| 100 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil {
|
| 101 |
+
t.Fatal(err)
|
| 102 |
+
}
|
| 103 |
+
s.InvalidateCache()
|
| 104 |
+
got4, _ := s.Get(ctx)
|
| 105 |
+
if got4 != "v3" {
|
| 106 |
+
t.Fatalf("invalidate expected v3, got %s", got4)
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
func TestMultiSourceSecret_FileHandling(t *testing.T) {
|
| 111 |
+
ctx := context.Background()
|
| 112 |
+
|
| 113 |
+
t.Run("missing_file_no_error", func(t *testing.T) {
|
| 114 |
+
s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond)
|
| 115 |
+
got, err := s.Get(ctx)
|
| 116 |
+
if err != nil {
|
| 117 |
+
t.Fatalf("expected no error for missing file, got: %v", err)
|
| 118 |
+
}
|
| 119 |
+
if got != "" {
|
| 120 |
+
t.Fatalf("expected empty string, got %q", got)
|
| 121 |
+
}
|
| 122 |
+
})
|
| 123 |
+
|
| 124 |
+
t.Run("invalid_json", func(t *testing.T) {
|
| 125 |
+
tmpDir := t.TempDir()
|
| 126 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 127 |
+
if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil {
|
| 128 |
+
t.Fatal(err)
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
| 132 |
+
_, err := s.Get(ctx)
|
| 133 |
+
if err == nil {
|
| 134 |
+
t.Fatal("expected error for invalid JSON")
|
| 135 |
+
}
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
+
t.Run("missing_key_in_json", func(t *testing.T) {
|
| 139 |
+
tmpDir := t.TempDir()
|
| 140 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 141 |
+
if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil {
|
| 142 |
+
t.Fatal(err)
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
| 146 |
+
got, err := s.Get(ctx)
|
| 147 |
+
if err != nil {
|
| 148 |
+
t.Fatalf("unexpected error: %v", err)
|
| 149 |
+
}
|
| 150 |
+
if got != "" {
|
| 151 |
+
t.Fatalf("expected empty string for missing key, got %q", got)
|
| 152 |
+
}
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
t.Run("empty_key_value", func(t *testing.T) {
|
| 156 |
+
tmpDir := t.TempDir()
|
| 157 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 158 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil {
|
| 159 |
+
t.Fatal(err)
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
| 163 |
+
got, _ := s.Get(ctx)
|
| 164 |
+
if got != "" {
|
| 165 |
+
t.Fatalf("expected empty after trim, got %q", got)
|
| 166 |
+
}
|
| 167 |
+
})
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
func TestMultiSourceSecret_Concurrency(t *testing.T) {
|
| 171 |
+
tmpDir := t.TempDir()
|
| 172 |
+
p := filepath.Join(tmpDir, "secrets.json")
|
| 173 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil {
|
| 174 |
+
t.Fatal(err)
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
s := NewMultiSourceSecretWithPath("", p, 5*time.Second)
|
| 178 |
+
ctx := context.Background()
|
| 179 |
+
|
| 180 |
+
// Spawn many goroutines calling Get concurrently
|
| 181 |
+
const goroutines = 50
|
| 182 |
+
const iterations = 100
|
| 183 |
+
|
| 184 |
+
var wg sync.WaitGroup
|
| 185 |
+
errors := make(chan error, goroutines)
|
| 186 |
+
|
| 187 |
+
for i := 0; i < goroutines; i++ {
|
| 188 |
+
wg.Add(1)
|
| 189 |
+
go func() {
|
| 190 |
+
defer wg.Done()
|
| 191 |
+
for j := 0; j < iterations; j++ {
|
| 192 |
+
val, err := s.Get(ctx)
|
| 193 |
+
if err != nil {
|
| 194 |
+
errors <- err
|
| 195 |
+
return
|
| 196 |
+
}
|
| 197 |
+
if val != "concurrent" {
|
| 198 |
+
errors <- err
|
| 199 |
+
return
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
}()
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
wg.Wait()
|
| 206 |
+
close(errors)
|
| 207 |
+
|
| 208 |
+
for err := range errors {
|
| 209 |
+
t.Errorf("concurrency error: %v", err)
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
func TestStaticSecretSource(t *testing.T) {
|
| 214 |
+
ctx := context.Background()
|
| 215 |
+
|
| 216 |
+
t.Run("returns_provided_key", func(t *testing.T) {
|
| 217 |
+
s := NewStaticSecretSource("test-key-123")
|
| 218 |
+
got, err := s.Get(ctx)
|
| 219 |
+
if err != nil {
|
| 220 |
+
t.Fatalf("unexpected error: %v", err)
|
| 221 |
+
}
|
| 222 |
+
if got != "test-key-123" {
|
| 223 |
+
t.Fatalf("want test-key-123, got %q", got)
|
| 224 |
+
}
|
| 225 |
+
})
|
| 226 |
+
|
| 227 |
+
t.Run("trims_whitespace", func(t *testing.T) {
|
| 228 |
+
s := NewStaticSecretSource(" test-key ")
|
| 229 |
+
got, err := s.Get(ctx)
|
| 230 |
+
if err != nil {
|
| 231 |
+
t.Fatalf("unexpected error: %v", err)
|
| 232 |
+
}
|
| 233 |
+
if got != "test-key" {
|
| 234 |
+
t.Fatalf("want test-key, got %q", got)
|
| 235 |
+
}
|
| 236 |
+
})
|
| 237 |
+
|
| 238 |
+
t.Run("empty_string", func(t *testing.T) {
|
| 239 |
+
s := NewStaticSecretSource("")
|
| 240 |
+
got, err := s.Get(ctx)
|
| 241 |
+
if err != nil {
|
| 242 |
+
t.Fatalf("unexpected error: %v", err)
|
| 243 |
+
}
|
| 244 |
+
if got != "" {
|
| 245 |
+
t.Fatalf("want empty string, got %q", got)
|
| 246 |
+
}
|
| 247 |
+
})
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
| 251 |
+
// Test that missing file results are cached to avoid repeated file reads
|
| 252 |
+
tmpDir := t.TempDir()
|
| 253 |
+
p := filepath.Join(tmpDir, "nonexistent.json")
|
| 254 |
+
|
| 255 |
+
s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
|
| 256 |
+
ctx := context.Background()
|
| 257 |
+
|
| 258 |
+
// First call - file doesn't exist, should cache empty result
|
| 259 |
+
got1, err := s.Get(ctx)
|
| 260 |
+
if err != nil {
|
| 261 |
+
t.Fatalf("expected no error for missing file, got: %v", err)
|
| 262 |
+
}
|
| 263 |
+
if got1 != "" {
|
| 264 |
+
t.Fatalf("expected empty string, got %q", got1)
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
// Create the file now
|
| 268 |
+
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil {
|
| 269 |
+
t.Fatal(err)
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
// Second call - should still return empty (cached), not read the new file
|
| 273 |
+
got2, _ := s.Get(ctx)
|
| 274 |
+
if got2 != "" {
|
| 275 |
+
t.Fatalf("cache should return empty, got %q", got2)
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
// After TTL expires, should see the new value
|
| 279 |
+
time.Sleep(110 * time.Millisecond)
|
| 280 |
+
got3, _ := s.Get(ctx)
|
| 281 |
+
if got3 != "new-value" {
|
| 282 |
+
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
| 287 |
+
defaultSource := NewStaticSecretSource("default")
|
| 288 |
+
s := NewMappedSecretSource(defaultSource)
|
| 289 |
+
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
| 290 |
+
{
|
| 291 |
+
UpstreamAPIKey: "u1",
|
| 292 |
+
APIKeys: []string{"k1"},
|
| 293 |
+
},
|
| 294 |
+
})
|
| 295 |
+
|
| 296 |
+
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
| 297 |
+
got, err := s.Get(ctx)
|
| 298 |
+
if err != nil {
|
| 299 |
+
t.Fatalf("unexpected error: %v", err)
|
| 300 |
+
}
|
| 301 |
+
if got != "u1" {
|
| 302 |
+
t.Fatalf("want u1, got %q", got)
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
| 306 |
+
got, err = s.Get(ctx)
|
| 307 |
+
if err != nil {
|
| 308 |
+
t.Fatalf("unexpected error: %v", err)
|
| 309 |
+
}
|
| 310 |
+
if got != "default" {
|
| 311 |
+
t.Fatalf("want default fallback, got %q", got)
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
| 316 |
+
defaultSource := NewStaticSecretSource("default")
|
| 317 |
+
s := NewMappedSecretSource(defaultSource)
|
| 318 |
+
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
| 319 |
+
{
|
| 320 |
+
UpstreamAPIKey: "u1",
|
| 321 |
+
APIKeys: []string{"k1"},
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
UpstreamAPIKey: "u2",
|
| 325 |
+
APIKeys: []string{"k1"},
|
| 326 |
+
},
|
| 327 |
+
})
|
| 328 |
+
|
| 329 |
+
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
| 330 |
+
got, err := s.Get(ctx)
|
| 331 |
+
if err != nil {
|
| 332 |
+
t.Fatalf("unexpected error: %v", err)
|
| 333 |
+
}
|
| 334 |
+
if got != "u1" {
|
| 335 |
+
t.Fatalf("want u1 (first wins), got %q", got)
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
| 340 |
+
hook := test.NewLocal(log.StandardLogger())
|
| 341 |
+
defer hook.Reset()
|
| 342 |
+
|
| 343 |
+
defaultSource := NewStaticSecretSource("default")
|
| 344 |
+
s := NewMappedSecretSource(defaultSource)
|
| 345 |
+
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
| 346 |
+
{
|
| 347 |
+
UpstreamAPIKey: "u1",
|
| 348 |
+
APIKeys: []string{"k1"},
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
UpstreamAPIKey: "u2",
|
| 352 |
+
APIKeys: []string{"k1"},
|
| 353 |
+
},
|
| 354 |
+
})
|
| 355 |
+
|
| 356 |
+
foundWarning := false
|
| 357 |
+
for _, entry := range hook.AllEntries() {
|
| 358 |
+
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
| 359 |
+
foundWarning = true
|
| 360 |
+
break
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
+
if !foundWarning {
|
| 364 |
+
t.Fatal("expected warning log for duplicate client key, but none was found")
|
| 365 |
+
}
|
| 366 |
+
}
|
internal/api/modules/modules.go
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package modules provides a pluggable routing module system for extending
|
| 2 |
+
// the API server with optional features without modifying core routing logic.
|
| 3 |
+
package modules
|
| 4 |
+
|
| 5 |
+
import (
|
| 6 |
+
"fmt"
|
| 7 |
+
|
| 8 |
+
"github.com/gin-gonic/gin"
|
| 9 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 10 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
// Context encapsulates the dependencies exposed to routing modules during
|
| 14 |
+
// registration. Modules can use the Gin engine to attach routes, the shared
|
| 15 |
+
// BaseAPIHandler for constructing SDK-specific handlers, and the resolved
|
| 16 |
+
// authentication middleware for protecting routes that require API keys.
|
| 17 |
+
type Context struct {
|
| 18 |
+
Engine *gin.Engine
|
| 19 |
+
BaseHandler *handlers.BaseAPIHandler
|
| 20 |
+
Config *config.Config
|
| 21 |
+
AuthMiddleware gin.HandlerFunc
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// RouteModule represents a pluggable routing module that can register routes
|
| 25 |
+
// and handle configuration updates independently of the core server.
|
| 26 |
+
//
|
| 27 |
+
// DEPRECATED: Use RouteModuleV2 for new modules. This interface is kept for
|
| 28 |
+
// backwards compatibility and will be removed in a future version.
|
| 29 |
+
type RouteModule interface {
|
| 30 |
+
// Name returns a human-readable identifier for the module
|
| 31 |
+
Name() string
|
| 32 |
+
|
| 33 |
+
// Register sets up routes and handlers for this module.
|
| 34 |
+
// It receives the Gin engine, base handlers, and current configuration.
|
| 35 |
+
// Returns an error if registration fails (errors are logged but don't stop the server).
|
| 36 |
+
Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error
|
| 37 |
+
|
| 38 |
+
// OnConfigUpdated is called when the configuration is reloaded.
|
| 39 |
+
// Modules can respond to configuration changes here.
|
| 40 |
+
// Returns an error if the update cannot be applied.
|
| 41 |
+
OnConfigUpdated(cfg *config.Config) error
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
// RouteModuleV2 represents a pluggable bundle of routes that can integrate with
|
| 45 |
+
// the API server without modifying its core routing logic. Implementations can
|
| 46 |
+
// attach routes during Register and react to configuration updates via
|
| 47 |
+
// OnConfigUpdated.
|
| 48 |
+
//
|
| 49 |
+
// This is the preferred interface for new modules. It uses Context for cleaner
|
| 50 |
+
// dependency injection and supports idempotent registration.
|
| 51 |
+
type RouteModuleV2 interface {
|
| 52 |
+
// Name returns a unique identifier for logging and diagnostics.
|
| 53 |
+
Name() string
|
| 54 |
+
|
| 55 |
+
// Register wires the module's routes into the provided Gin engine. Modules
|
| 56 |
+
// should treat multiple calls as idempotent and avoid duplicate route
|
| 57 |
+
// registration when invoked more than once.
|
| 58 |
+
Register(ctx Context) error
|
| 59 |
+
|
| 60 |
+
// OnConfigUpdated notifies the module when the server configuration changes
|
| 61 |
+
// via hot reload. Implementations can refresh cached state or emit warnings.
|
| 62 |
+
OnConfigUpdated(cfg *config.Config) error
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// RegisterModule is a helper that registers a module using either the V1 or V2
|
| 66 |
+
// interface. This allows gradual migration from V1 to V2 without breaking
|
| 67 |
+
// existing modules.
|
| 68 |
+
//
|
| 69 |
+
// Example usage:
|
| 70 |
+
//
|
| 71 |
+
// ctx := modules.Context{
|
| 72 |
+
// Engine: engine,
|
| 73 |
+
// BaseHandler: baseHandler,
|
| 74 |
+
// Config: cfg,
|
| 75 |
+
// AuthMiddleware: authMiddleware,
|
| 76 |
+
// }
|
| 77 |
+
// if err := modules.RegisterModule(ctx, ampModule); err != nil {
|
| 78 |
+
// log.Errorf("Failed to register module: %v", err)
|
| 79 |
+
// }
|
| 80 |
+
func RegisterModule(ctx Context, mod interface{}) error {
|
| 81 |
+
// Try V2 interface first (preferred)
|
| 82 |
+
if v2, ok := mod.(RouteModuleV2); ok {
|
| 83 |
+
return v2.Register(ctx)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Fall back to V1 interface for backwards compatibility
|
| 87 |
+
if v1, ok := mod.(RouteModule); ok {
|
| 88 |
+
return v1.Register(ctx.Engine, ctx.BaseHandler, ctx.Config)
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod)
|
| 92 |
+
}
|
internal/api/server.go
ADDED
|
@@ -0,0 +1,1056 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package api provides the HTTP API server implementation for the CLI Proxy API.
|
| 2 |
+
// It includes the main server struct, routing setup, middleware for CORS and authentication,
|
| 3 |
+
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
|
| 4 |
+
// The server supports hot-reloading of clients and configuration.
|
| 5 |
+
package api
|
| 6 |
+
|
| 7 |
+
import (
|
| 8 |
+
"context"
|
| 9 |
+
"crypto/subtle"
|
| 10 |
+
"errors"
|
| 11 |
+
"fmt"
|
| 12 |
+
"net/http"
|
| 13 |
+
"os"
|
| 14 |
+
"path/filepath"
|
| 15 |
+
"strings"
|
| 16 |
+
"sync"
|
| 17 |
+
"sync/atomic"
|
| 18 |
+
"time"
|
| 19 |
+
|
| 20 |
+
"github.com/gin-gonic/gin"
|
| 21 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/access"
|
| 22 |
+
managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
|
| 23 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
| 24 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
| 25 |
+
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
| 26 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 27 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
| 28 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
| 29 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
| 30 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 31 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 32 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
| 33 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
| 34 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
| 35 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
| 36 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 37 |
+
log "github.com/sirupsen/logrus"
|
| 38 |
+
"gopkg.in/yaml.v3"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
const oauthCallbackSuccessHTML = `<html><head><meta charset="utf-8"><title>Authentication successful</title><script>setTimeout(function(){window.close();},5000);</script></head><body><h1>Authentication successful!</h1><p>You can close this window.</p><p>This window will close automatically in 5 seconds.</p></body></html>`
|
| 42 |
+
|
| 43 |
+
type serverOptionConfig struct {
|
| 44 |
+
extraMiddleware []gin.HandlerFunc
|
| 45 |
+
engineConfigurator func(*gin.Engine)
|
| 46 |
+
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
|
| 47 |
+
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
|
| 48 |
+
localPassword string
|
| 49 |
+
keepAliveEnabled bool
|
| 50 |
+
keepAliveTimeout time.Duration
|
| 51 |
+
keepAliveOnTimeout func()
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
// ServerOption customises HTTP server construction.
|
| 55 |
+
type ServerOption func(*serverOptionConfig)
|
| 56 |
+
|
| 57 |
+
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
| 58 |
+
configDir := filepath.Dir(configPath)
|
| 59 |
+
if base := util.WritablePath(); base != "" {
|
| 60 |
+
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
|
| 61 |
+
}
|
| 62 |
+
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
// WithMiddleware appends additional Gin middleware during server construction.
|
| 66 |
+
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
|
| 67 |
+
return func(cfg *serverOptionConfig) {
|
| 68 |
+
cfg.extraMiddleware = append(cfg.extraMiddleware, mw...)
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
|
| 73 |
+
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
|
| 74 |
+
return func(cfg *serverOptionConfig) {
|
| 75 |
+
cfg.engineConfigurator = fn
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
// WithRouterConfigurator appends a callback after default routes are registered.
|
| 80 |
+
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
|
| 81 |
+
return func(cfg *serverOptionConfig) {
|
| 82 |
+
cfg.routerConfigurator = fn
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
|
| 87 |
+
func WithLocalManagementPassword(password string) ServerOption {
|
| 88 |
+
return func(cfg *serverOptionConfig) {
|
| 89 |
+
cfg.localPassword = password
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
| 94 |
+
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
| 95 |
+
return func(cfg *serverOptionConfig) {
|
| 96 |
+
if timeout <= 0 || onTimeout == nil {
|
| 97 |
+
return
|
| 98 |
+
}
|
| 99 |
+
cfg.keepAliveEnabled = true
|
| 100 |
+
cfg.keepAliveTimeout = timeout
|
| 101 |
+
cfg.keepAliveOnTimeout = onTimeout
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// WithRequestLoggerFactory customises request logger creation.
|
| 106 |
+
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
| 107 |
+
return func(cfg *serverOptionConfig) {
|
| 108 |
+
cfg.requestLoggerFactory = factory
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
// Server represents the main API server.
|
| 113 |
+
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
| 114 |
+
type Server struct {
|
| 115 |
+
// engine is the Gin web framework engine instance.
|
| 116 |
+
engine *gin.Engine
|
| 117 |
+
|
| 118 |
+
// server is the underlying HTTP server.
|
| 119 |
+
server *http.Server
|
| 120 |
+
|
| 121 |
+
// handlers contains the API handlers for processing requests.
|
| 122 |
+
handlers *handlers.BaseAPIHandler
|
| 123 |
+
|
| 124 |
+
// cfg holds the current server configuration.
|
| 125 |
+
cfg *config.Config
|
| 126 |
+
|
| 127 |
+
// oldConfigYaml stores a YAML snapshot of the previous configuration for change detection.
|
| 128 |
+
// This prevents issues when the config object is modified in place by Management API.
|
| 129 |
+
oldConfigYaml []byte
|
| 130 |
+
|
| 131 |
+
// accessManager handles request authentication providers.
|
| 132 |
+
accessManager *sdkaccess.Manager
|
| 133 |
+
|
| 134 |
+
// requestLogger is the request logger instance for dynamic configuration updates.
|
| 135 |
+
requestLogger logging.RequestLogger
|
| 136 |
+
loggerToggle func(bool)
|
| 137 |
+
|
| 138 |
+
// configFilePath is the absolute path to the YAML config file for persistence.
|
| 139 |
+
configFilePath string
|
| 140 |
+
|
| 141 |
+
// currentPath is the absolute path to the current working directory.
|
| 142 |
+
currentPath string
|
| 143 |
+
|
| 144 |
+
// wsRoutes tracks registered websocket upgrade paths.
|
| 145 |
+
wsRouteMu sync.Mutex
|
| 146 |
+
wsRoutes map[string]struct{}
|
| 147 |
+
wsAuthChanged func(bool, bool)
|
| 148 |
+
wsAuthEnabled atomic.Bool
|
| 149 |
+
|
| 150 |
+
// management handler
|
| 151 |
+
mgmt *managementHandlers.Handler
|
| 152 |
+
|
| 153 |
+
// ampModule is the Amp routing module for model mapping hot-reload
|
| 154 |
+
ampModule *ampmodule.AmpModule
|
| 155 |
+
|
| 156 |
+
// managementRoutesRegistered tracks whether the management routes have been attached to the engine.
|
| 157 |
+
managementRoutesRegistered atomic.Bool
|
| 158 |
+
// managementRoutesEnabled controls whether management endpoints serve real handlers.
|
| 159 |
+
managementRoutesEnabled atomic.Bool
|
| 160 |
+
|
| 161 |
+
// envManagementSecret indicates whether MANAGEMENT_PASSWORD is configured.
|
| 162 |
+
envManagementSecret bool
|
| 163 |
+
|
| 164 |
+
localPassword string
|
| 165 |
+
|
| 166 |
+
keepAliveEnabled bool
|
| 167 |
+
keepAliveTimeout time.Duration
|
| 168 |
+
keepAliveOnTimeout func()
|
| 169 |
+
keepAliveHeartbeat chan struct{}
|
| 170 |
+
keepAliveStop chan struct{}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// NewServer creates and initializes a new API server instance.
|
| 174 |
+
// It sets up the Gin engine, middleware, routes, and handlers.
|
| 175 |
+
//
|
| 176 |
+
// Parameters:
|
| 177 |
+
// - cfg: The server configuration
|
| 178 |
+
// - authManager: core runtime auth manager
|
| 179 |
+
// - accessManager: request authentication manager
|
| 180 |
+
//
|
| 181 |
+
// Returns:
|
| 182 |
+
// - *Server: A new server instance
|
| 183 |
+
func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server {
|
| 184 |
+
optionState := &serverOptionConfig{
|
| 185 |
+
requestLoggerFactory: defaultRequestLoggerFactory,
|
| 186 |
+
}
|
| 187 |
+
for i := range opts {
|
| 188 |
+
opts[i](optionState)
|
| 189 |
+
}
|
| 190 |
+
// Set gin mode
|
| 191 |
+
if !cfg.Debug {
|
| 192 |
+
gin.SetMode(gin.ReleaseMode)
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// Create gin engine
|
| 196 |
+
engine := gin.New()
|
| 197 |
+
if optionState.engineConfigurator != nil {
|
| 198 |
+
optionState.engineConfigurator(engine)
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Add middleware
|
| 202 |
+
engine.Use(logging.GinLogrusLogger())
|
| 203 |
+
engine.Use(logging.GinLogrusRecovery())
|
| 204 |
+
for _, mw := range optionState.extraMiddleware {
|
| 205 |
+
engine.Use(mw)
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// Add request logging middleware (positioned after recovery, before auth)
|
| 209 |
+
// Resolve logs directory relative to the configuration file directory.
|
| 210 |
+
var requestLogger logging.RequestLogger
|
| 211 |
+
var toggle func(bool)
|
| 212 |
+
if !cfg.CommercialMode {
|
| 213 |
+
if optionState.requestLoggerFactory != nil {
|
| 214 |
+
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
| 215 |
+
}
|
| 216 |
+
if requestLogger != nil {
|
| 217 |
+
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
| 218 |
+
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
| 219 |
+
toggle = setter.SetEnabled
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
engine.Use(corsMiddleware())
|
| 225 |
+
wd, err := os.Getwd()
|
| 226 |
+
if err != nil {
|
| 227 |
+
wd = configFilePath
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
envAdminPassword, envAdminPasswordSet := os.LookupEnv("MANAGEMENT_PASSWORD")
|
| 231 |
+
envAdminPassword = strings.TrimSpace(envAdminPassword)
|
| 232 |
+
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
| 233 |
+
|
| 234 |
+
// Create server instance
|
| 235 |
+
s := &Server{
|
| 236 |
+
engine: engine,
|
| 237 |
+
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
| 238 |
+
cfg: cfg,
|
| 239 |
+
accessManager: accessManager,
|
| 240 |
+
requestLogger: requestLogger,
|
| 241 |
+
loggerToggle: toggle,
|
| 242 |
+
configFilePath: configFilePath,
|
| 243 |
+
currentPath: wd,
|
| 244 |
+
envManagementSecret: envManagementSecret,
|
| 245 |
+
wsRoutes: make(map[string]struct{}),
|
| 246 |
+
}
|
| 247 |
+
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
| 248 |
+
// Save initial YAML snapshot
|
| 249 |
+
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
| 250 |
+
s.applyAccessConfig(nil, cfg)
|
| 251 |
+
if authManager != nil {
|
| 252 |
+
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
| 253 |
+
}
|
| 254 |
+
managementasset.SetCurrentConfig(cfg)
|
| 255 |
+
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
| 256 |
+
// Initialize management handler
|
| 257 |
+
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
| 258 |
+
if optionState.localPassword != "" {
|
| 259 |
+
s.mgmt.SetLocalPassword(optionState.localPassword)
|
| 260 |
+
}
|
| 261 |
+
logDir := filepath.Join(s.currentPath, "logs")
|
| 262 |
+
if base := util.WritablePath(); base != "" {
|
| 263 |
+
logDir = filepath.Join(base, "logs")
|
| 264 |
+
}
|
| 265 |
+
s.mgmt.SetLogDirectory(logDir)
|
| 266 |
+
s.localPassword = optionState.localPassword
|
| 267 |
+
|
| 268 |
+
// Setup routes
|
| 269 |
+
s.setupRoutes()
|
| 270 |
+
|
| 271 |
+
// Register Amp module using V2 interface with Context
|
| 272 |
+
s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
|
| 273 |
+
ctx := modules.Context{
|
| 274 |
+
Engine: engine,
|
| 275 |
+
BaseHandler: s.handlers,
|
| 276 |
+
Config: cfg,
|
| 277 |
+
AuthMiddleware: AuthMiddleware(accessManager),
|
| 278 |
+
}
|
| 279 |
+
if err := modules.RegisterModule(ctx, s.ampModule); err != nil {
|
| 280 |
+
log.Errorf("Failed to register Amp module: %v", err)
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
// Apply additional router configurators from options
|
| 284 |
+
if optionState.routerConfigurator != nil {
|
| 285 |
+
optionState.routerConfigurator(engine, s.handlers, cfg)
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// Register management routes when configuration or environment secrets are available.
|
| 289 |
+
hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret
|
| 290 |
+
s.managementRoutesEnabled.Store(hasManagementSecret)
|
| 291 |
+
if hasManagementSecret {
|
| 292 |
+
s.registerManagementRoutes()
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
if optionState.keepAliveEnabled {
|
| 296 |
+
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
// Create HTTP server
|
| 300 |
+
s.server = &http.Server{
|
| 301 |
+
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
| 302 |
+
Handler: engine,
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
return s
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
// setupRoutes configures the API routes for the server.
|
| 309 |
+
// It defines the endpoints and associates them with their respective handlers.
|
| 310 |
+
func (s *Server) setupRoutes() {
|
| 311 |
+
s.engine.GET("/management.html", s.serveManagementControlPanel)
|
| 312 |
+
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
|
| 313 |
+
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
|
| 314 |
+
geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers)
|
| 315 |
+
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers)
|
| 316 |
+
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers)
|
| 317 |
+
|
| 318 |
+
// OpenAI compatible API routes
|
| 319 |
+
v1 := s.engine.Group("/v1")
|
| 320 |
+
v1.Use(AuthMiddleware(s.accessManager))
|
| 321 |
+
{
|
| 322 |
+
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
|
| 323 |
+
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
| 324 |
+
v1.POST("/completions", openaiHandlers.Completions)
|
| 325 |
+
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
| 326 |
+
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
| 327 |
+
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// Gemini compatible API routes
|
| 331 |
+
v1beta := s.engine.Group("/v1beta")
|
| 332 |
+
v1beta.Use(AuthMiddleware(s.accessManager))
|
| 333 |
+
{
|
| 334 |
+
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
| 335 |
+
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
| 336 |
+
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
// Root endpoint
|
| 340 |
+
s.engine.GET("/", func(c *gin.Context) {
|
| 341 |
+
c.JSON(http.StatusOK, gin.H{
|
| 342 |
+
"message": "CLI Proxy API Server",
|
| 343 |
+
"endpoints": []string{
|
| 344 |
+
"POST /v1/chat/completions",
|
| 345 |
+
"POST /v1/completions",
|
| 346 |
+
"GET /v1/models",
|
| 347 |
+
},
|
| 348 |
+
})
|
| 349 |
+
})
|
| 350 |
+
|
| 351 |
+
// Event logging endpoint - handles Claude Code telemetry requests
|
| 352 |
+
// Returns 200 OK to prevent 404 errors in logs
|
| 353 |
+
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
|
| 354 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 355 |
+
})
|
| 356 |
+
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
| 357 |
+
|
| 358 |
+
// OAuth callback endpoints (reuse main server port)
|
| 359 |
+
// These endpoints receive provider redirects and persist
|
| 360 |
+
// the short-lived code/state for the waiting goroutine.
|
| 361 |
+
s.engine.GET("/anthropic/callback", func(c *gin.Context) {
|
| 362 |
+
code := c.Query("code")
|
| 363 |
+
state := c.Query("state")
|
| 364 |
+
errStr := c.Query("error")
|
| 365 |
+
if errStr == "" {
|
| 366 |
+
errStr = c.Query("error_description")
|
| 367 |
+
}
|
| 368 |
+
if state != "" {
|
| 369 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
| 370 |
+
}
|
| 371 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 372 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 373 |
+
})
|
| 374 |
+
|
| 375 |
+
s.engine.GET("/codex/callback", func(c *gin.Context) {
|
| 376 |
+
code := c.Query("code")
|
| 377 |
+
state := c.Query("state")
|
| 378 |
+
errStr := c.Query("error")
|
| 379 |
+
if errStr == "" {
|
| 380 |
+
errStr = c.Query("error_description")
|
| 381 |
+
}
|
| 382 |
+
if state != "" {
|
| 383 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
| 384 |
+
}
|
| 385 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 386 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 387 |
+
})
|
| 388 |
+
|
| 389 |
+
s.engine.GET("/google/callback", func(c *gin.Context) {
|
| 390 |
+
code := c.Query("code")
|
| 391 |
+
state := c.Query("state")
|
| 392 |
+
errStr := c.Query("error")
|
| 393 |
+
if errStr == "" {
|
| 394 |
+
errStr = c.Query("error_description")
|
| 395 |
+
}
|
| 396 |
+
if state != "" {
|
| 397 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
| 398 |
+
}
|
| 399 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 400 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 401 |
+
})
|
| 402 |
+
|
| 403 |
+
s.engine.GET("/iflow/callback", func(c *gin.Context) {
|
| 404 |
+
code := c.Query("code")
|
| 405 |
+
state := c.Query("state")
|
| 406 |
+
errStr := c.Query("error")
|
| 407 |
+
if errStr == "" {
|
| 408 |
+
errStr = c.Query("error_description")
|
| 409 |
+
}
|
| 410 |
+
if state != "" {
|
| 411 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
| 412 |
+
}
|
| 413 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 414 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 415 |
+
})
|
| 416 |
+
|
| 417 |
+
s.engine.GET("/antigravity/callback", func(c *gin.Context) {
|
| 418 |
+
code := c.Query("code")
|
| 419 |
+
state := c.Query("state")
|
| 420 |
+
errStr := c.Query("error")
|
| 421 |
+
if errStr == "" {
|
| 422 |
+
errStr = c.Query("error_description")
|
| 423 |
+
}
|
| 424 |
+
if state != "" {
|
| 425 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
| 426 |
+
}
|
| 427 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 428 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 429 |
+
})
|
| 430 |
+
|
| 431 |
+
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
| 432 |
+
code := c.Query("code")
|
| 433 |
+
state := c.Query("state")
|
| 434 |
+
errStr := c.Query("error")
|
| 435 |
+
if errStr == "" {
|
| 436 |
+
errStr = c.Query("error_description")
|
| 437 |
+
}
|
| 438 |
+
if state != "" {
|
| 439 |
+
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
|
| 440 |
+
}
|
| 441 |
+
c.Header("Content-Type", "text/html; charset=utf-8")
|
| 442 |
+
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
| 443 |
+
})
|
| 444 |
+
|
| 445 |
+
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
|
| 449 |
+
// The handler is served as-is without additional middleware beyond the standard stack already configured.
|
| 450 |
+
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
|
| 451 |
+
if s == nil || s.engine == nil || handler == nil {
|
| 452 |
+
return
|
| 453 |
+
}
|
| 454 |
+
trimmed := strings.TrimSpace(path)
|
| 455 |
+
if trimmed == "" {
|
| 456 |
+
trimmed = "/v1/ws"
|
| 457 |
+
}
|
| 458 |
+
if !strings.HasPrefix(trimmed, "/") {
|
| 459 |
+
trimmed = "/" + trimmed
|
| 460 |
+
}
|
| 461 |
+
s.wsRouteMu.Lock()
|
| 462 |
+
if _, exists := s.wsRoutes[trimmed]; exists {
|
| 463 |
+
s.wsRouteMu.Unlock()
|
| 464 |
+
return
|
| 465 |
+
}
|
| 466 |
+
s.wsRoutes[trimmed] = struct{}{}
|
| 467 |
+
s.wsRouteMu.Unlock()
|
| 468 |
+
|
| 469 |
+
authMiddleware := AuthMiddleware(s.accessManager)
|
| 470 |
+
conditionalAuth := func(c *gin.Context) {
|
| 471 |
+
if !s.wsAuthEnabled.Load() {
|
| 472 |
+
c.Next()
|
| 473 |
+
return
|
| 474 |
+
}
|
| 475 |
+
authMiddleware(c)
|
| 476 |
+
}
|
| 477 |
+
finalHandler := func(c *gin.Context) {
|
| 478 |
+
handler.ServeHTTP(c.Writer, c.Request)
|
| 479 |
+
c.Abort()
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
s.engine.GET(trimmed, conditionalAuth, finalHandler)
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
func (s *Server) registerManagementRoutes() {
|
| 486 |
+
if s == nil || s.engine == nil || s.mgmt == nil {
|
| 487 |
+
return
|
| 488 |
+
}
|
| 489 |
+
if !s.managementRoutesRegistered.CompareAndSwap(false, true) {
|
| 490 |
+
return
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
log.Info("management routes registered after secret key configuration")
|
| 494 |
+
|
| 495 |
+
mgmt := s.engine.Group("/v0/management")
|
| 496 |
+
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
| 497 |
+
{
|
| 498 |
+
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
| 499 |
+
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
| 500 |
+
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
| 501 |
+
mgmt.GET("/config", s.mgmt.GetConfig)
|
| 502 |
+
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
| 503 |
+
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
| 504 |
+
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
| 505 |
+
|
| 506 |
+
mgmt.GET("/debug", s.mgmt.GetDebug)
|
| 507 |
+
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
| 508 |
+
mgmt.PATCH("/debug", s.mgmt.PutDebug)
|
| 509 |
+
|
| 510 |
+
mgmt.GET("/logging-to-file", s.mgmt.GetLoggingToFile)
|
| 511 |
+
mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
|
| 512 |
+
mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
|
| 513 |
+
|
| 514 |
+
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
| 515 |
+
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
| 516 |
+
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
| 517 |
+
|
| 518 |
+
mgmt.GET("/proxy-url", s.mgmt.GetProxyURL)
|
| 519 |
+
mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL)
|
| 520 |
+
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
|
| 521 |
+
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
|
| 522 |
+
|
| 523 |
+
mgmt.POST("/api-call", s.mgmt.APICall)
|
| 524 |
+
|
| 525 |
+
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
|
| 526 |
+
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
| 527 |
+
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
|
| 528 |
+
|
| 529 |
+
mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel)
|
| 530 |
+
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
| 531 |
+
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
|
| 532 |
+
|
| 533 |
+
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
|
| 534 |
+
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
| 535 |
+
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
| 536 |
+
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
|
| 537 |
+
|
| 538 |
+
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
|
| 539 |
+
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
|
| 540 |
+
mgmt.PATCH("/gemini-api-key", s.mgmt.PatchGeminiKey)
|
| 541 |
+
mgmt.DELETE("/gemini-api-key", s.mgmt.DeleteGeminiKey)
|
| 542 |
+
|
| 543 |
+
mgmt.GET("/logs", s.mgmt.GetLogs)
|
| 544 |
+
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
| 545 |
+
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
| 546 |
+
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
| 547 |
+
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
| 548 |
+
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
| 549 |
+
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
| 550 |
+
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
| 551 |
+
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
|
| 552 |
+
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
| 553 |
+
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
| 554 |
+
|
| 555 |
+
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
|
| 556 |
+
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
|
| 557 |
+
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
| 558 |
+
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
|
| 559 |
+
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
|
| 560 |
+
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
|
| 561 |
+
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
| 562 |
+
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
|
| 563 |
+
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
|
| 564 |
+
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
|
| 565 |
+
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
| 566 |
+
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
|
| 567 |
+
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
|
| 568 |
+
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
|
| 569 |
+
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
|
| 570 |
+
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
|
| 571 |
+
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
| 572 |
+
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
| 573 |
+
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
| 574 |
+
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
| 575 |
+
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
| 576 |
+
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
| 577 |
+
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
| 578 |
+
|
| 579 |
+
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
| 580 |
+
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
| 581 |
+
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
| 582 |
+
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
|
| 583 |
+
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
| 584 |
+
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
| 585 |
+
|
| 586 |
+
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
| 587 |
+
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
| 588 |
+
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
|
| 589 |
+
mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey)
|
| 590 |
+
|
| 591 |
+
mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys)
|
| 592 |
+
mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys)
|
| 593 |
+
mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey)
|
| 594 |
+
mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey)
|
| 595 |
+
|
| 596 |
+
mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat)
|
| 597 |
+
mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat)
|
| 598 |
+
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
| 599 |
+
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
| 600 |
+
|
| 601 |
+
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
| 602 |
+
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
| 603 |
+
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
| 604 |
+
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
| 605 |
+
|
| 606 |
+
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
| 607 |
+
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
| 608 |
+
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
| 609 |
+
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
| 610 |
+
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
| 611 |
+
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
| 612 |
+
|
| 613 |
+
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
| 614 |
+
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
| 615 |
+
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
| 616 |
+
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
| 617 |
+
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
| 618 |
+
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
| 619 |
+
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
| 620 |
+
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
| 621 |
+
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
| 622 |
+
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
| 623 |
+
}
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
|
| 627 |
+
return func(c *gin.Context) {
|
| 628 |
+
if !s.managementRoutesEnabled.Load() {
|
| 629 |
+
c.AbortWithStatus(http.StatusNotFound)
|
| 630 |
+
return
|
| 631 |
+
}
|
| 632 |
+
c.Next()
|
| 633 |
+
}
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
| 637 |
+
cfg := s.cfg
|
| 638 |
+
if cfg == nil || cfg.RemoteManagement.DisableControlPanel {
|
| 639 |
+
c.AbortWithStatus(http.StatusNotFound)
|
| 640 |
+
return
|
| 641 |
+
}
|
| 642 |
+
filePath := managementasset.FilePath(s.configFilePath)
|
| 643 |
+
if strings.TrimSpace(filePath) == "" {
|
| 644 |
+
c.AbortWithStatus(http.StatusNotFound)
|
| 645 |
+
return
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
if _, err := os.Stat(filePath); err != nil {
|
| 649 |
+
if os.IsNotExist(err) {
|
| 650 |
+
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
| 651 |
+
c.AbortWithStatus(http.StatusNotFound)
|
| 652 |
+
return
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
log.WithError(err).Error("failed to stat management control panel asset")
|
| 656 |
+
c.AbortWithStatus(http.StatusInternalServerError)
|
| 657 |
+
return
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
c.File(filePath)
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
|
| 664 |
+
if timeout <= 0 || onTimeout == nil {
|
| 665 |
+
return
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
s.keepAliveEnabled = true
|
| 669 |
+
s.keepAliveTimeout = timeout
|
| 670 |
+
s.keepAliveOnTimeout = onTimeout
|
| 671 |
+
s.keepAliveHeartbeat = make(chan struct{}, 1)
|
| 672 |
+
s.keepAliveStop = make(chan struct{}, 1)
|
| 673 |
+
|
| 674 |
+
s.engine.GET("/keep-alive", s.handleKeepAlive)
|
| 675 |
+
|
| 676 |
+
go s.watchKeepAlive()
|
| 677 |
+
}
|
| 678 |
+
|
| 679 |
+
func (s *Server) handleKeepAlive(c *gin.Context) {
|
| 680 |
+
if s.localPassword != "" {
|
| 681 |
+
provided := strings.TrimSpace(c.GetHeader("Authorization"))
|
| 682 |
+
if provided != "" {
|
| 683 |
+
parts := strings.SplitN(provided, " ", 2)
|
| 684 |
+
if len(parts) == 2 && strings.EqualFold(parts[0], "bearer") {
|
| 685 |
+
provided = parts[1]
|
| 686 |
+
}
|
| 687 |
+
}
|
| 688 |
+
if provided == "" {
|
| 689 |
+
provided = strings.TrimSpace(c.GetHeader("X-Local-Password"))
|
| 690 |
+
}
|
| 691 |
+
if subtle.ConstantTimeCompare([]byte(provided), []byte(s.localPassword)) != 1 {
|
| 692 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
|
| 693 |
+
return
|
| 694 |
+
}
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
s.signalKeepAlive()
|
| 698 |
+
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
func (s *Server) signalKeepAlive() {
|
| 702 |
+
if !s.keepAliveEnabled {
|
| 703 |
+
return
|
| 704 |
+
}
|
| 705 |
+
select {
|
| 706 |
+
case s.keepAliveHeartbeat <- struct{}{}:
|
| 707 |
+
default:
|
| 708 |
+
}
|
| 709 |
+
}
|
| 710 |
+
|
| 711 |
+
func (s *Server) watchKeepAlive() {
|
| 712 |
+
if !s.keepAliveEnabled {
|
| 713 |
+
return
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
timer := time.NewTimer(s.keepAliveTimeout)
|
| 717 |
+
defer timer.Stop()
|
| 718 |
+
|
| 719 |
+
for {
|
| 720 |
+
select {
|
| 721 |
+
case <-timer.C:
|
| 722 |
+
log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
|
| 723 |
+
if s.keepAliveOnTimeout != nil {
|
| 724 |
+
s.keepAliveOnTimeout()
|
| 725 |
+
}
|
| 726 |
+
return
|
| 727 |
+
case <-s.keepAliveHeartbeat:
|
| 728 |
+
if !timer.Stop() {
|
| 729 |
+
select {
|
| 730 |
+
case <-timer.C:
|
| 731 |
+
default:
|
| 732 |
+
}
|
| 733 |
+
}
|
| 734 |
+
timer.Reset(s.keepAliveTimeout)
|
| 735 |
+
case <-s.keepAliveStop:
|
| 736 |
+
return
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
|
| 742 |
+
// that routes to different handlers based on the User-Agent header.
|
| 743 |
+
// If User-Agent starts with "claude-cli", it routes to Claude handler,
|
| 744 |
+
// otherwise it routes to OpenAI handler.
|
| 745 |
+
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
|
| 746 |
+
return func(c *gin.Context) {
|
| 747 |
+
userAgent := c.GetHeader("User-Agent")
|
| 748 |
+
|
| 749 |
+
// Route to Claude handler if User-Agent starts with "claude-cli"
|
| 750 |
+
if strings.HasPrefix(userAgent, "claude-cli") {
|
| 751 |
+
// log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
|
| 752 |
+
claudeHandler.ClaudeModels(c)
|
| 753 |
+
} else {
|
| 754 |
+
// log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
|
| 755 |
+
openaiHandler.OpenAIModels(c)
|
| 756 |
+
}
|
| 757 |
+
}
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
// Start begins listening for and serving HTTP or HTTPS requests.
|
| 761 |
+
// It's a blocking call and will only return on an unrecoverable error.
|
| 762 |
+
//
|
| 763 |
+
// Returns:
|
| 764 |
+
// - error: An error if the server fails to start
|
| 765 |
+
func (s *Server) Start() error {
|
| 766 |
+
if s == nil || s.server == nil {
|
| 767 |
+
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
| 771 |
+
if useTLS {
|
| 772 |
+
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
| 773 |
+
key := strings.TrimSpace(s.cfg.TLS.Key)
|
| 774 |
+
if cert == "" || key == "" {
|
| 775 |
+
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
| 776 |
+
}
|
| 777 |
+
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
| 778 |
+
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
| 779 |
+
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
| 780 |
+
}
|
| 781 |
+
return nil
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
log.Debugf("Starting API server on %s", s.server.Addr)
|
| 785 |
+
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
| 786 |
+
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
return nil
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
// Stop gracefully shuts down the API server without interrupting any
|
| 793 |
+
// active connections.
|
| 794 |
+
//
|
| 795 |
+
// Parameters:
|
| 796 |
+
// - ctx: The context for graceful shutdown
|
| 797 |
+
//
|
| 798 |
+
// Returns:
|
| 799 |
+
// - error: An error if the server fails to stop
|
| 800 |
+
func (s *Server) Stop(ctx context.Context) error {
|
| 801 |
+
log.Debug("Stopping API server...")
|
| 802 |
+
|
| 803 |
+
if s.keepAliveEnabled {
|
| 804 |
+
select {
|
| 805 |
+
case s.keepAliveStop <- struct{}{}:
|
| 806 |
+
default:
|
| 807 |
+
}
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
// Shutdown the HTTP server.
|
| 811 |
+
if err := s.server.Shutdown(ctx); err != nil {
|
| 812 |
+
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
log.Debug("API server stopped")
|
| 816 |
+
return nil
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
| 820 |
+
// to every response, allowing cross-origin requests.
|
| 821 |
+
//
|
| 822 |
+
// Returns:
|
| 823 |
+
// - gin.HandlerFunc: The CORS middleware handler
|
| 824 |
+
func corsMiddleware() gin.HandlerFunc {
|
| 825 |
+
return func(c *gin.Context) {
|
| 826 |
+
c.Header("Access-Control-Allow-Origin", "*")
|
| 827 |
+
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
| 828 |
+
c.Header("Access-Control-Allow-Headers", "*")
|
| 829 |
+
|
| 830 |
+
if c.Request.Method == "OPTIONS" {
|
| 831 |
+
c.AbortWithStatus(http.StatusNoContent)
|
| 832 |
+
return
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
c.Next()
|
| 836 |
+
}
|
| 837 |
+
}
|
| 838 |
+
|
| 839 |
+
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
|
| 840 |
+
if s == nil || s.accessManager == nil || newCfg == nil {
|
| 841 |
+
return
|
| 842 |
+
}
|
| 843 |
+
if _, err := access.ApplyAccessProviders(s.accessManager, oldCfg, newCfg); err != nil {
|
| 844 |
+
return
|
| 845 |
+
}
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
// UpdateClients updates the server's client list and configuration.
|
| 849 |
+
// This method is called when the configuration or authentication tokens change.
|
| 850 |
+
//
|
| 851 |
+
// Parameters:
|
| 852 |
+
// - clients: The new slice of AI service clients
|
| 853 |
+
// - cfg: The new application configuration
|
| 854 |
+
func (s *Server) UpdateClients(cfg *config.Config) {
|
| 855 |
+
// Reconstruct old config from YAML snapshot to avoid reference sharing issues
|
| 856 |
+
var oldCfg *config.Config
|
| 857 |
+
if len(s.oldConfigYaml) > 0 {
|
| 858 |
+
_ = yaml.Unmarshal(s.oldConfigYaml, &oldCfg)
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
// Update request logger enabled state if it has changed
|
| 862 |
+
previousRequestLog := false
|
| 863 |
+
if oldCfg != nil {
|
| 864 |
+
previousRequestLog = oldCfg.RequestLog
|
| 865 |
+
}
|
| 866 |
+
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
|
| 867 |
+
if s.loggerToggle != nil {
|
| 868 |
+
s.loggerToggle(cfg.RequestLog)
|
| 869 |
+
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
| 870 |
+
toggler.SetEnabled(cfg.RequestLog)
|
| 871 |
+
}
|
| 872 |
+
if oldCfg != nil {
|
| 873 |
+
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
|
| 874 |
+
} else {
|
| 875 |
+
log.Debugf("request logging toggled to %t", cfg.RequestLog)
|
| 876 |
+
}
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
| 880 |
+
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
| 881 |
+
log.Errorf("failed to reconfigure log output: %v", err)
|
| 882 |
+
} else {
|
| 883 |
+
if oldCfg == nil {
|
| 884 |
+
log.Debug("log output configuration refreshed")
|
| 885 |
+
} else {
|
| 886 |
+
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
| 887 |
+
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
| 888 |
+
}
|
| 889 |
+
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
| 890 |
+
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
| 891 |
+
}
|
| 892 |
+
}
|
| 893 |
+
}
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
| 897 |
+
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
| 898 |
+
if oldCfg != nil {
|
| 899 |
+
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
| 900 |
+
} else {
|
| 901 |
+
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
|
| 902 |
+
}
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
| 906 |
+
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
| 907 |
+
if oldCfg != nil {
|
| 908 |
+
log.Debugf("disable_cooling updated from %t to %t", oldCfg.DisableCooling, cfg.DisableCooling)
|
| 909 |
+
} else {
|
| 910 |
+
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
| 911 |
+
}
|
| 912 |
+
}
|
| 913 |
+
if s.handlers != nil && s.handlers.AuthManager != nil {
|
| 914 |
+
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
// Update log level dynamically when debug flag changes
|
| 918 |
+
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
| 919 |
+
util.SetLogLevel(cfg)
|
| 920 |
+
if oldCfg != nil {
|
| 921 |
+
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
|
| 922 |
+
} else {
|
| 923 |
+
log.Debugf("debug mode toggled to %t", cfg.Debug)
|
| 924 |
+
}
|
| 925 |
+
}
|
| 926 |
+
|
| 927 |
+
prevSecretEmpty := true
|
| 928 |
+
if oldCfg != nil {
|
| 929 |
+
prevSecretEmpty = oldCfg.RemoteManagement.SecretKey == ""
|
| 930 |
+
}
|
| 931 |
+
newSecretEmpty := cfg.RemoteManagement.SecretKey == ""
|
| 932 |
+
if s.envManagementSecret {
|
| 933 |
+
s.registerManagementRoutes()
|
| 934 |
+
if s.managementRoutesEnabled.CompareAndSwap(false, true) {
|
| 935 |
+
log.Info("management routes enabled via MANAGEMENT_PASSWORD")
|
| 936 |
+
} else {
|
| 937 |
+
s.managementRoutesEnabled.Store(true)
|
| 938 |
+
}
|
| 939 |
+
} else {
|
| 940 |
+
switch {
|
| 941 |
+
case prevSecretEmpty && !newSecretEmpty:
|
| 942 |
+
s.registerManagementRoutes()
|
| 943 |
+
if s.managementRoutesEnabled.CompareAndSwap(false, true) {
|
| 944 |
+
log.Info("management routes enabled after secret key update")
|
| 945 |
+
} else {
|
| 946 |
+
s.managementRoutesEnabled.Store(true)
|
| 947 |
+
}
|
| 948 |
+
case !prevSecretEmpty && newSecretEmpty:
|
| 949 |
+
if s.managementRoutesEnabled.CompareAndSwap(true, false) {
|
| 950 |
+
log.Info("management routes disabled after secret key removal")
|
| 951 |
+
} else {
|
| 952 |
+
s.managementRoutesEnabled.Store(false)
|
| 953 |
+
}
|
| 954 |
+
default:
|
| 955 |
+
s.managementRoutesEnabled.Store(!newSecretEmpty)
|
| 956 |
+
}
|
| 957 |
+
}
|
| 958 |
+
|
| 959 |
+
s.applyAccessConfig(oldCfg, cfg)
|
| 960 |
+
s.cfg = cfg
|
| 961 |
+
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
| 962 |
+
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
|
| 963 |
+
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
|
| 964 |
+
}
|
| 965 |
+
managementasset.SetCurrentConfig(cfg)
|
| 966 |
+
// Save YAML snapshot for next comparison
|
| 967 |
+
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
| 968 |
+
|
| 969 |
+
s.handlers.UpdateClients(&cfg.SDKConfig)
|
| 970 |
+
|
| 971 |
+
if !cfg.RemoteManagement.DisableControlPanel {
|
| 972 |
+
staticDir := managementasset.StaticDir(s.configFilePath)
|
| 973 |
+
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
| 974 |
+
}
|
| 975 |
+
if s.mgmt != nil {
|
| 976 |
+
s.mgmt.SetConfig(cfg)
|
| 977 |
+
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
+
// Notify Amp module of config changes (for model mapping hot-reload)
|
| 981 |
+
if s.ampModule != nil {
|
| 982 |
+
log.Debugf("triggering amp module config update")
|
| 983 |
+
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
| 984 |
+
log.Errorf("failed to update Amp module config: %v", err)
|
| 985 |
+
}
|
| 986 |
+
} else {
|
| 987 |
+
log.Warnf("amp module is nil, skipping config update")
|
| 988 |
+
}
|
| 989 |
+
|
| 990 |
+
// Count client sources from configuration and auth directory
|
| 991 |
+
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
| 992 |
+
geminiAPIKeyCount := len(cfg.GeminiKey)
|
| 993 |
+
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
| 994 |
+
codexAPIKeyCount := len(cfg.CodexKey)
|
| 995 |
+
vertexAICompatCount := len(cfg.VertexCompatAPIKey)
|
| 996 |
+
openAICompatCount := 0
|
| 997 |
+
for i := range cfg.OpenAICompatibility {
|
| 998 |
+
entry := cfg.OpenAICompatibility[i]
|
| 999 |
+
openAICompatCount += len(entry.APIKeyEntries)
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
| 1003 |
+
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
| 1004 |
+
total,
|
| 1005 |
+
authFiles,
|
| 1006 |
+
geminiAPIKeyCount,
|
| 1007 |
+
claudeAPIKeyCount,
|
| 1008 |
+
codexAPIKeyCount,
|
| 1009 |
+
vertexAICompatCount,
|
| 1010 |
+
openAICompatCount,
|
| 1011 |
+
)
|
| 1012 |
+
}
|
| 1013 |
+
|
| 1014 |
+
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
|
| 1015 |
+
if s == nil {
|
| 1016 |
+
return
|
| 1017 |
+
}
|
| 1018 |
+
s.wsAuthChanged = fn
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
// (management handlers moved to internal/api/handlers/management)
|
| 1022 |
+
|
| 1023 |
+
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
| 1024 |
+
// using the configured authentication providers. When no providers are available,
|
| 1025 |
+
// it allows all requests (legacy behaviour).
|
| 1026 |
+
func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
| 1027 |
+
return func(c *gin.Context) {
|
| 1028 |
+
if manager == nil {
|
| 1029 |
+
c.Next()
|
| 1030 |
+
return
|
| 1031 |
+
}
|
| 1032 |
+
|
| 1033 |
+
result, err := manager.Authenticate(c.Request.Context(), c.Request)
|
| 1034 |
+
if err == nil {
|
| 1035 |
+
if result != nil {
|
| 1036 |
+
c.Set("apiKey", result.Principal)
|
| 1037 |
+
c.Set("accessProvider", result.Provider)
|
| 1038 |
+
if len(result.Metadata) > 0 {
|
| 1039 |
+
c.Set("accessMetadata", result.Metadata)
|
| 1040 |
+
}
|
| 1041 |
+
}
|
| 1042 |
+
c.Next()
|
| 1043 |
+
return
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
+
switch {
|
| 1047 |
+
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
| 1048 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
| 1049 |
+
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
| 1050 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
| 1051 |
+
default:
|
| 1052 |
+
log.Errorf("authentication middleware error: %v", err)
|
| 1053 |
+
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
| 1054 |
+
}
|
| 1055 |
+
}
|
| 1056 |
+
}
|
internal/api/server_test.go
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package api
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"net/http"
|
| 5 |
+
"net/http/httptest"
|
| 6 |
+
"os"
|
| 7 |
+
"path/filepath"
|
| 8 |
+
"strings"
|
| 9 |
+
"testing"
|
| 10 |
+
|
| 11 |
+
gin "github.com/gin-gonic/gin"
|
| 12 |
+
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 13 |
+
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
| 14 |
+
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
| 15 |
+
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
func newTestServer(t *testing.T) *Server {
|
| 19 |
+
t.Helper()
|
| 20 |
+
|
| 21 |
+
gin.SetMode(gin.TestMode)
|
| 22 |
+
|
| 23 |
+
tmpDir := t.TempDir()
|
| 24 |
+
authDir := filepath.Join(tmpDir, "auth")
|
| 25 |
+
if err := os.MkdirAll(authDir, 0o700); err != nil {
|
| 26 |
+
t.Fatalf("failed to create auth dir: %v", err)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
cfg := &proxyconfig.Config{
|
| 30 |
+
SDKConfig: sdkconfig.SDKConfig{
|
| 31 |
+
APIKeys: []string{"test-key"},
|
| 32 |
+
},
|
| 33 |
+
Port: 0,
|
| 34 |
+
AuthDir: authDir,
|
| 35 |
+
Debug: true,
|
| 36 |
+
LoggingToFile: false,
|
| 37 |
+
UsageStatisticsEnabled: false,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
authManager := auth.NewManager(nil, nil, nil)
|
| 41 |
+
accessManager := sdkaccess.NewManager()
|
| 42 |
+
|
| 43 |
+
configPath := filepath.Join(tmpDir, "config.yaml")
|
| 44 |
+
return NewServer(cfg, authManager, accessManager, configPath)
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
func TestAmpProviderModelRoutes(t *testing.T) {
|
| 48 |
+
testCases := []struct {
|
| 49 |
+
name string
|
| 50 |
+
path string
|
| 51 |
+
wantStatus int
|
| 52 |
+
wantContains string
|
| 53 |
+
}{
|
| 54 |
+
{
|
| 55 |
+
name: "openai root models",
|
| 56 |
+
path: "/api/provider/openai/models",
|
| 57 |
+
wantStatus: http.StatusOK,
|
| 58 |
+
wantContains: `"object":"list"`,
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
name: "groq root models",
|
| 62 |
+
path: "/api/provider/groq/models",
|
| 63 |
+
wantStatus: http.StatusOK,
|
| 64 |
+
wantContains: `"object":"list"`,
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
name: "openai models",
|
| 68 |
+
path: "/api/provider/openai/v1/models",
|
| 69 |
+
wantStatus: http.StatusOK,
|
| 70 |
+
wantContains: `"object":"list"`,
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
name: "anthropic models",
|
| 74 |
+
path: "/api/provider/anthropic/v1/models",
|
| 75 |
+
wantStatus: http.StatusOK,
|
| 76 |
+
wantContains: `"data"`,
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
name: "google models v1",
|
| 80 |
+
path: "/api/provider/google/v1/models",
|
| 81 |
+
wantStatus: http.StatusOK,
|
| 82 |
+
wantContains: `"models"`,
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
name: "google models v1beta",
|
| 86 |
+
path: "/api/provider/google/v1beta/models",
|
| 87 |
+
wantStatus: http.StatusOK,
|
| 88 |
+
wantContains: `"models"`,
|
| 89 |
+
},
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
for _, tc := range testCases {
|
| 93 |
+
tc := tc
|
| 94 |
+
t.Run(tc.name, func(t *testing.T) {
|
| 95 |
+
server := newTestServer(t)
|
| 96 |
+
|
| 97 |
+
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
| 98 |
+
req.Header.Set("Authorization", "Bearer test-key")
|
| 99 |
+
|
| 100 |
+
rr := httptest.NewRecorder()
|
| 101 |
+
server.engine.ServeHTTP(rr, req)
|
| 102 |
+
|
| 103 |
+
if rr.Code != tc.wantStatus {
|
| 104 |
+
t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", tc.path, rr.Code, tc.wantStatus, rr.Body.String())
|
| 105 |
+
}
|
| 106 |
+
if body := rr.Body.String(); !strings.Contains(body, tc.wantContains) {
|
| 107 |
+
t.Fatalf("response body for %s missing %q: %s", tc.path, tc.wantContains, body)
|
| 108 |
+
}
|
| 109 |
+
})
|
| 110 |
+
}
|
| 111 |
+
}
|
internal/auth/claude/anthropic.go
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package claude
|
| 2 |
+
|
| 3 |
+
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
| 4 |
+
type PKCECodes struct {
|
| 5 |
+
// CodeVerifier is the cryptographically random string used to correlate
|
| 6 |
+
// the authorization request to the token request
|
| 7 |
+
CodeVerifier string `json:"code_verifier"`
|
| 8 |
+
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
| 9 |
+
CodeChallenge string `json:"code_challenge"`
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
// ClaudeTokenData holds OAuth token information from Anthropic
|
| 13 |
+
type ClaudeTokenData struct {
|
| 14 |
+
// AccessToken is the OAuth2 access token for API access
|
| 15 |
+
AccessToken string `json:"access_token"`
|
| 16 |
+
// RefreshToken is used to obtain new access tokens
|
| 17 |
+
RefreshToken string `json:"refresh_token"`
|
| 18 |
+
// Email is the Anthropic account email
|
| 19 |
+
Email string `json:"email"`
|
| 20 |
+
// Expire is the timestamp of the token expire
|
| 21 |
+
Expire string `json:"expired"`
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// ClaudeAuthBundle aggregates authentication data after OAuth flow completion
|
| 25 |
+
type ClaudeAuthBundle struct {
|
| 26 |
+
// APIKey is the Anthropic API key obtained from token exchange
|
| 27 |
+
APIKey string `json:"api_key"`
|
| 28 |
+
// TokenData contains the OAuth tokens from the authentication flow
|
| 29 |
+
TokenData ClaudeTokenData `json:"token_data"`
|
| 30 |
+
// LastRefresh is the timestamp of the last token refresh
|
| 31 |
+
LastRefresh string `json:"last_refresh"`
|
| 32 |
+
}
|
internal/auth/claude/anthropic_auth.go
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides OAuth2 authentication functionality for Anthropic's Claude API.
|
| 2 |
+
// This package implements the complete OAuth2 flow with PKCE (Proof Key for Code Exchange)
|
| 3 |
+
// for secure authentication with Claude API, including token exchange, refresh, and storage.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"context"
|
| 8 |
+
"encoding/json"
|
| 9 |
+
"fmt"
|
| 10 |
+
"io"
|
| 11 |
+
"net/http"
|
| 12 |
+
"net/url"
|
| 13 |
+
"strings"
|
| 14 |
+
"time"
|
| 15 |
+
|
| 16 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
| 17 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
| 18 |
+
log "github.com/sirupsen/logrus"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
const (
|
| 22 |
+
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
| 23 |
+
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
| 24 |
+
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
| 25 |
+
redirectURI = "http://localhost:54545/callback"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
| 29 |
+
// It contains access token, refresh token, and associated user/organization information.
|
| 30 |
+
type tokenResponse struct {
|
| 31 |
+
AccessToken string `json:"access_token"`
|
| 32 |
+
RefreshToken string `json:"refresh_token"`
|
| 33 |
+
TokenType string `json:"token_type"`
|
| 34 |
+
ExpiresIn int `json:"expires_in"`
|
| 35 |
+
Organization struct {
|
| 36 |
+
UUID string `json:"uuid"`
|
| 37 |
+
Name string `json:"name"`
|
| 38 |
+
} `json:"organization"`
|
| 39 |
+
Account struct {
|
| 40 |
+
UUID string `json:"uuid"`
|
| 41 |
+
EmailAddress string `json:"email_address"`
|
| 42 |
+
} `json:"account"`
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
// ClaudeAuth handles Anthropic OAuth2 authentication flow.
|
| 46 |
+
// It provides methods for generating authorization URLs, exchanging codes for tokens,
|
| 47 |
+
// and refreshing expired tokens using PKCE for enhanced security.
|
| 48 |
+
type ClaudeAuth struct {
|
| 49 |
+
httpClient *http.Client
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// NewClaudeAuth creates a new Anthropic authentication service.
|
| 53 |
+
// It initializes the HTTP client with proxy settings from the configuration.
|
| 54 |
+
//
|
| 55 |
+
// Parameters:
|
| 56 |
+
// - cfg: The application configuration containing proxy settings
|
| 57 |
+
//
|
| 58 |
+
// Returns:
|
| 59 |
+
// - *ClaudeAuth: A new Claude authentication service instance
|
| 60 |
+
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
|
| 61 |
+
return &ClaudeAuth{
|
| 62 |
+
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// GenerateAuthURL creates the OAuth authorization URL with PKCE.
|
| 67 |
+
// This method generates a secure authorization URL including PKCE challenge codes
|
| 68 |
+
// for the OAuth2 flow with Anthropic's API.
|
| 69 |
+
//
|
| 70 |
+
// Parameters:
|
| 71 |
+
// - state: A random state parameter for CSRF protection
|
| 72 |
+
// - pkceCodes: The PKCE codes for secure code exchange
|
| 73 |
+
//
|
| 74 |
+
// Returns:
|
| 75 |
+
// - string: The complete authorization URL
|
| 76 |
+
// - string: The state parameter for verification
|
| 77 |
+
// - error: An error if PKCE codes are missing or URL generation fails
|
| 78 |
+
func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, string, error) {
|
| 79 |
+
if pkceCodes == nil {
|
| 80 |
+
return "", "", fmt.Errorf("PKCE codes are required")
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
params := url.Values{
|
| 84 |
+
"code": {"true"},
|
| 85 |
+
"client_id": {anthropicClientID},
|
| 86 |
+
"response_type": {"code"},
|
| 87 |
+
"redirect_uri": {redirectURI},
|
| 88 |
+
"scope": {"org:create_api_key user:profile user:inference"},
|
| 89 |
+
"code_challenge": {pkceCodes.CodeChallenge},
|
| 90 |
+
"code_challenge_method": {"S256"},
|
| 91 |
+
"state": {state},
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
| 95 |
+
return authURL, state, nil
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
// parseCodeAndState extracts the authorization code and state from the callback response.
|
| 99 |
+
// It handles the parsing of the code parameter which may contain additional fragments.
|
| 100 |
+
//
|
| 101 |
+
// Parameters:
|
| 102 |
+
// - code: The raw code parameter from the OAuth callback
|
| 103 |
+
//
|
| 104 |
+
// Returns:
|
| 105 |
+
// - parsedCode: The extracted authorization code
|
| 106 |
+
// - parsedState: The extracted state parameter if present
|
| 107 |
+
func (c *ClaudeAuth) parseCodeAndState(code string) (parsedCode, parsedState string) {
|
| 108 |
+
splits := strings.Split(code, "#")
|
| 109 |
+
parsedCode = splits[0]
|
| 110 |
+
if len(splits) > 1 {
|
| 111 |
+
parsedState = splits[1]
|
| 112 |
+
}
|
| 113 |
+
return
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// ExchangeCodeForTokens exchanges authorization code for access tokens.
|
| 117 |
+
// This method implements the OAuth2 token exchange flow using PKCE for security.
|
| 118 |
+
// It sends the authorization code along with PKCE verifier to get access and refresh tokens.
|
| 119 |
+
//
|
| 120 |
+
// Parameters:
|
| 121 |
+
// - ctx: The context for the request
|
| 122 |
+
// - code: The authorization code received from OAuth callback
|
| 123 |
+
// - state: The state parameter for verification
|
| 124 |
+
// - pkceCodes: The PKCE codes for secure verification
|
| 125 |
+
//
|
| 126 |
+
// Returns:
|
| 127 |
+
// - *ClaudeAuthBundle: The complete authentication bundle with tokens
|
| 128 |
+
// - error: An error if token exchange fails
|
| 129 |
+
func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state string, pkceCodes *PKCECodes) (*ClaudeAuthBundle, error) {
|
| 130 |
+
if pkceCodes == nil {
|
| 131 |
+
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
| 132 |
+
}
|
| 133 |
+
newCode, newState := o.parseCodeAndState(code)
|
| 134 |
+
|
| 135 |
+
// Prepare token exchange request
|
| 136 |
+
reqBody := map[string]interface{}{
|
| 137 |
+
"code": newCode,
|
| 138 |
+
"state": state,
|
| 139 |
+
"grant_type": "authorization_code",
|
| 140 |
+
"client_id": anthropicClientID,
|
| 141 |
+
"redirect_uri": redirectURI,
|
| 142 |
+
"code_verifier": pkceCodes.CodeVerifier,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
// Include state if present
|
| 146 |
+
if newState != "" {
|
| 147 |
+
reqBody["state"] = newState
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
jsonBody, err := json.Marshal(reqBody)
|
| 151 |
+
if err != nil {
|
| 152 |
+
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
| 156 |
+
|
| 157 |
+
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
| 158 |
+
if err != nil {
|
| 159 |
+
return nil, fmt.Errorf("failed to create token request: %w", err)
|
| 160 |
+
}
|
| 161 |
+
req.Header.Set("Content-Type", "application/json")
|
| 162 |
+
req.Header.Set("Accept", "application/json")
|
| 163 |
+
|
| 164 |
+
resp, err := o.httpClient.Do(req)
|
| 165 |
+
if err != nil {
|
| 166 |
+
return nil, fmt.Errorf("token exchange request failed: %w", err)
|
| 167 |
+
}
|
| 168 |
+
defer func() {
|
| 169 |
+
if errClose := resp.Body.Close(); errClose != nil {
|
| 170 |
+
log.Errorf("failed to close response body: %v", errClose)
|
| 171 |
+
}
|
| 172 |
+
}()
|
| 173 |
+
|
| 174 |
+
body, err := io.ReadAll(resp.Body)
|
| 175 |
+
if err != nil {
|
| 176 |
+
return nil, fmt.Errorf("failed to read token response: %w", err)
|
| 177 |
+
}
|
| 178 |
+
// log.Debugf("Token response: %s", string(body))
|
| 179 |
+
|
| 180 |
+
if resp.StatusCode != http.StatusOK {
|
| 181 |
+
return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
| 182 |
+
}
|
| 183 |
+
// log.Debugf("Token response: %s", string(body))
|
| 184 |
+
|
| 185 |
+
var tokenResp tokenResponse
|
| 186 |
+
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
| 187 |
+
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
// Create token data
|
| 191 |
+
tokenData := ClaudeTokenData{
|
| 192 |
+
AccessToken: tokenResp.AccessToken,
|
| 193 |
+
RefreshToken: tokenResp.RefreshToken,
|
| 194 |
+
Email: tokenResp.Account.EmailAddress,
|
| 195 |
+
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Create auth bundle
|
| 199 |
+
bundle := &ClaudeAuthBundle{
|
| 200 |
+
TokenData: tokenData,
|
| 201 |
+
LastRefresh: time.Now().Format(time.RFC3339),
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
return bundle, nil
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
// RefreshTokens refreshes the access token using the refresh token.
|
| 208 |
+
// This method exchanges a valid refresh token for a new access token,
|
| 209 |
+
// extending the user's authenticated session.
|
| 210 |
+
//
|
| 211 |
+
// Parameters:
|
| 212 |
+
// - ctx: The context for the request
|
| 213 |
+
// - refreshToken: The refresh token to use for getting new access token
|
| 214 |
+
//
|
| 215 |
+
// Returns:
|
| 216 |
+
// - *ClaudeTokenData: The new token data with updated access token
|
| 217 |
+
// - error: An error if token refresh fails
|
| 218 |
+
func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
|
| 219 |
+
if refreshToken == "" {
|
| 220 |
+
return nil, fmt.Errorf("refresh token is required")
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
reqBody := map[string]interface{}{
|
| 224 |
+
"client_id": anthropicClientID,
|
| 225 |
+
"grant_type": "refresh_token",
|
| 226 |
+
"refresh_token": refreshToken,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
jsonBody, err := json.Marshal(reqBody)
|
| 230 |
+
if err != nil {
|
| 231 |
+
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
| 235 |
+
if err != nil {
|
| 236 |
+
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
req.Header.Set("Content-Type", "application/json")
|
| 240 |
+
req.Header.Set("Accept", "application/json")
|
| 241 |
+
|
| 242 |
+
resp, err := o.httpClient.Do(req)
|
| 243 |
+
if err != nil {
|
| 244 |
+
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
| 245 |
+
}
|
| 246 |
+
defer func() {
|
| 247 |
+
_ = resp.Body.Close()
|
| 248 |
+
}()
|
| 249 |
+
|
| 250 |
+
body, err := io.ReadAll(resp.Body)
|
| 251 |
+
if err != nil {
|
| 252 |
+
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
if resp.StatusCode != http.StatusOK {
|
| 256 |
+
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// log.Debugf("Token response: %s", string(body))
|
| 260 |
+
|
| 261 |
+
var tokenResp tokenResponse
|
| 262 |
+
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
| 263 |
+
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
// Create token data
|
| 267 |
+
return &ClaudeTokenData{
|
| 268 |
+
AccessToken: tokenResp.AccessToken,
|
| 269 |
+
RefreshToken: tokenResp.RefreshToken,
|
| 270 |
+
Email: tokenResp.Account.EmailAddress,
|
| 271 |
+
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
| 272 |
+
}, nil
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// CreateTokenStorage creates a new ClaudeTokenStorage from auth bundle and user info.
|
| 276 |
+
// This method converts the authentication bundle into a token storage structure
|
| 277 |
+
// suitable for persistence and later use.
|
| 278 |
+
//
|
| 279 |
+
// Parameters:
|
| 280 |
+
// - bundle: The authentication bundle containing token data
|
| 281 |
+
//
|
| 282 |
+
// Returns:
|
| 283 |
+
// - *ClaudeTokenStorage: A new token storage instance
|
| 284 |
+
func (o *ClaudeAuth) CreateTokenStorage(bundle *ClaudeAuthBundle) *ClaudeTokenStorage {
|
| 285 |
+
storage := &ClaudeTokenStorage{
|
| 286 |
+
AccessToken: bundle.TokenData.AccessToken,
|
| 287 |
+
RefreshToken: bundle.TokenData.RefreshToken,
|
| 288 |
+
LastRefresh: bundle.LastRefresh,
|
| 289 |
+
Email: bundle.TokenData.Email,
|
| 290 |
+
Expire: bundle.TokenData.Expire,
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
return storage
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
// RefreshTokensWithRetry refreshes tokens with automatic retry logic.
|
| 297 |
+
// This method implements exponential backoff retry logic for token refresh operations,
|
| 298 |
+
// providing resilience against temporary network or service issues.
|
| 299 |
+
//
|
| 300 |
+
// Parameters:
|
| 301 |
+
// - ctx: The context for the request
|
| 302 |
+
// - refreshToken: The refresh token to use
|
| 303 |
+
// - maxRetries: The maximum number of retry attempts
|
| 304 |
+
//
|
| 305 |
+
// Returns:
|
| 306 |
+
// - *ClaudeTokenData: The refreshed token data
|
| 307 |
+
// - error: An error if all retry attempts fail
|
| 308 |
+
func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*ClaudeTokenData, error) {
|
| 309 |
+
var lastErr error
|
| 310 |
+
|
| 311 |
+
for attempt := 0; attempt < maxRetries; attempt++ {
|
| 312 |
+
if attempt > 0 {
|
| 313 |
+
// Wait before retry
|
| 314 |
+
select {
|
| 315 |
+
case <-ctx.Done():
|
| 316 |
+
return nil, ctx.Err()
|
| 317 |
+
case <-time.After(time.Duration(attempt) * time.Second):
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
| 322 |
+
if err == nil {
|
| 323 |
+
return tokenData, nil
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
lastErr = err
|
| 327 |
+
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// UpdateTokenStorage updates an existing token storage with new token data.
|
| 334 |
+
// This method refreshes the token storage with newly obtained access and refresh tokens,
|
| 335 |
+
// updating timestamps and expiration information.
|
| 336 |
+
//
|
| 337 |
+
// Parameters:
|
| 338 |
+
// - storage: The existing token storage to update
|
| 339 |
+
// - tokenData: The new token data to apply
|
| 340 |
+
func (o *ClaudeAuth) UpdateTokenStorage(storage *ClaudeTokenStorage, tokenData *ClaudeTokenData) {
|
| 341 |
+
storage.AccessToken = tokenData.AccessToken
|
| 342 |
+
storage.RefreshToken = tokenData.RefreshToken
|
| 343 |
+
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
| 344 |
+
storage.Email = tokenData.Email
|
| 345 |
+
storage.Expire = tokenData.Expire
|
| 346 |
+
}
|
internal/auth/claude/errors.go
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides authentication and token management functionality
|
| 2 |
+
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
| 3 |
+
// and retrieval for maintaining authenticated sessions with the Claude API.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"errors"
|
| 8 |
+
"fmt"
|
| 9 |
+
"net/http"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
// OAuthError represents an OAuth-specific error.
|
| 13 |
+
type OAuthError struct {
|
| 14 |
+
// Code is the OAuth error code.
|
| 15 |
+
Code string `json:"error"`
|
| 16 |
+
// Description is a human-readable description of the error.
|
| 17 |
+
Description string `json:"error_description,omitempty"`
|
| 18 |
+
// URI is a URI identifying a human-readable web page with information about the error.
|
| 19 |
+
URI string `json:"error_uri,omitempty"`
|
| 20 |
+
// StatusCode is the HTTP status code associated with the error.
|
| 21 |
+
StatusCode int `json:"-"`
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// Error returns a string representation of the OAuth error.
|
| 25 |
+
func (e *OAuthError) Error() string {
|
| 26 |
+
if e.Description != "" {
|
| 27 |
+
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
|
| 28 |
+
}
|
| 29 |
+
return fmt.Sprintf("OAuth error: %s", e.Code)
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
|
| 33 |
+
func NewOAuthError(code, description string, statusCode int) *OAuthError {
|
| 34 |
+
return &OAuthError{
|
| 35 |
+
Code: code,
|
| 36 |
+
Description: description,
|
| 37 |
+
StatusCode: statusCode,
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
// AuthenticationError represents authentication-related errors.
|
| 42 |
+
type AuthenticationError struct {
|
| 43 |
+
// Type is the type of authentication error.
|
| 44 |
+
Type string `json:"type"`
|
| 45 |
+
// Message is a human-readable message describing the error.
|
| 46 |
+
Message string `json:"message"`
|
| 47 |
+
// Code is the HTTP status code associated with the error.
|
| 48 |
+
Code int `json:"code"`
|
| 49 |
+
// Cause is the underlying error that caused this authentication error.
|
| 50 |
+
Cause error `json:"-"`
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
// Error returns a string representation of the authentication error.
|
| 54 |
+
func (e *AuthenticationError) Error() string {
|
| 55 |
+
if e.Cause != nil {
|
| 56 |
+
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
| 57 |
+
}
|
| 58 |
+
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
// Common authentication error types.
|
| 62 |
+
var (
|
| 63 |
+
// ErrTokenExpired = &AuthenticationError{
|
| 64 |
+
// Type: "token_expired",
|
| 65 |
+
// Message: "Access token has expired",
|
| 66 |
+
// Code: http.StatusUnauthorized,
|
| 67 |
+
// }
|
| 68 |
+
|
| 69 |
+
// ErrInvalidState represents an error for invalid OAuth state parameter.
|
| 70 |
+
ErrInvalidState = &AuthenticationError{
|
| 71 |
+
Type: "invalid_state",
|
| 72 |
+
Message: "OAuth state parameter is invalid",
|
| 73 |
+
Code: http.StatusBadRequest,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// ErrCodeExchangeFailed represents an error when exchanging authorization code for tokens fails.
|
| 77 |
+
ErrCodeExchangeFailed = &AuthenticationError{
|
| 78 |
+
Type: "code_exchange_failed",
|
| 79 |
+
Message: "Failed to exchange authorization code for tokens",
|
| 80 |
+
Code: http.StatusBadRequest,
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// ErrServerStartFailed represents an error when starting the OAuth callback server fails.
|
| 84 |
+
ErrServerStartFailed = &AuthenticationError{
|
| 85 |
+
Type: "server_start_failed",
|
| 86 |
+
Message: "Failed to start OAuth callback server",
|
| 87 |
+
Code: http.StatusInternalServerError,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// ErrPortInUse represents an error when the OAuth callback port is already in use.
|
| 91 |
+
ErrPortInUse = &AuthenticationError{
|
| 92 |
+
Type: "port_in_use",
|
| 93 |
+
Message: "OAuth callback port is already in use",
|
| 94 |
+
Code: 13, // Special exit code for port-in-use
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
// ErrCallbackTimeout represents an error when waiting for OAuth callback times out.
|
| 98 |
+
ErrCallbackTimeout = &AuthenticationError{
|
| 99 |
+
Type: "callback_timeout",
|
| 100 |
+
Message: "Timeout waiting for OAuth callback",
|
| 101 |
+
Code: http.StatusRequestTimeout,
|
| 102 |
+
}
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
|
| 106 |
+
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
| 107 |
+
return &AuthenticationError{
|
| 108 |
+
Type: baseErr.Type,
|
| 109 |
+
Message: baseErr.Message,
|
| 110 |
+
Code: baseErr.Code,
|
| 111 |
+
Cause: cause,
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// IsAuthenticationError checks if an error is an authentication error.
|
| 116 |
+
func IsAuthenticationError(err error) bool {
|
| 117 |
+
var authenticationError *AuthenticationError
|
| 118 |
+
ok := errors.As(err, &authenticationError)
|
| 119 |
+
return ok
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// IsOAuthError checks if an error is an OAuth error.
|
| 123 |
+
func IsOAuthError(err error) bool {
|
| 124 |
+
var oAuthError *OAuthError
|
| 125 |
+
ok := errors.As(err, &oAuthError)
|
| 126 |
+
return ok
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
|
| 130 |
+
func GetUserFriendlyMessage(err error) string {
|
| 131 |
+
switch {
|
| 132 |
+
case IsAuthenticationError(err):
|
| 133 |
+
var authErr *AuthenticationError
|
| 134 |
+
errors.As(err, &authErr)
|
| 135 |
+
switch authErr.Type {
|
| 136 |
+
case "token_expired":
|
| 137 |
+
return "Your authentication has expired. Please log in again."
|
| 138 |
+
case "token_invalid":
|
| 139 |
+
return "Your authentication is invalid. Please log in again."
|
| 140 |
+
case "authentication_required":
|
| 141 |
+
return "Please log in to continue."
|
| 142 |
+
case "port_in_use":
|
| 143 |
+
return "The required port is already in use. Please close any applications using port 3000 and try again."
|
| 144 |
+
case "callback_timeout":
|
| 145 |
+
return "Authentication timed out. Please try again."
|
| 146 |
+
case "browser_open_failed":
|
| 147 |
+
return "Could not open your browser automatically. Please copy and paste the URL manually."
|
| 148 |
+
default:
|
| 149 |
+
return "Authentication failed. Please try again."
|
| 150 |
+
}
|
| 151 |
+
case IsOAuthError(err):
|
| 152 |
+
var oauthErr *OAuthError
|
| 153 |
+
errors.As(err, &oauthErr)
|
| 154 |
+
switch oauthErr.Code {
|
| 155 |
+
case "access_denied":
|
| 156 |
+
return "Authentication was cancelled or denied."
|
| 157 |
+
case "invalid_request":
|
| 158 |
+
return "Invalid authentication request. Please try again."
|
| 159 |
+
case "server_error":
|
| 160 |
+
return "Authentication server error. Please try again later."
|
| 161 |
+
default:
|
| 162 |
+
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
| 163 |
+
}
|
| 164 |
+
default:
|
| 165 |
+
return "An unexpected error occurred. Please try again."
|
| 166 |
+
}
|
| 167 |
+
}
|
internal/auth/claude/html_templates.go
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides authentication and token management functionality
|
| 2 |
+
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
| 3 |
+
// and retrieval for maintaining authenticated sessions with the Claude API.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
// LoginSuccessHtml is the HTML template displayed to users after successful OAuth authentication.
|
| 7 |
+
// This template provides a user-friendly success page with options to close the window
|
| 8 |
+
// or navigate to the Claude platform. It includes automatic window closing functionality
|
| 9 |
+
// and keyboard accessibility features.
|
| 10 |
+
const LoginSuccessHtml = `<!DOCTYPE html>
|
| 11 |
+
<html lang="en">
|
| 12 |
+
<head>
|
| 13 |
+
<meta charset="UTF-8">
|
| 14 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 15 |
+
<title>Authentication Successful - Claude</title>
|
| 16 |
+
<link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%2310b981'%3E%3Cpath d='M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E">
|
| 17 |
+
<style>
|
| 18 |
+
* {
|
| 19 |
+
box-sizing: border-box;
|
| 20 |
+
}
|
| 21 |
+
body {
|
| 22 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
| 23 |
+
display: flex;
|
| 24 |
+
justify-content: center;
|
| 25 |
+
align-items: center;
|
| 26 |
+
min-height: 100vh;
|
| 27 |
+
margin: 0;
|
| 28 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 29 |
+
padding: 1rem;
|
| 30 |
+
}
|
| 31 |
+
.container {
|
| 32 |
+
text-align: center;
|
| 33 |
+
background: white;
|
| 34 |
+
padding: 2.5rem;
|
| 35 |
+
border-radius: 12px;
|
| 36 |
+
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
|
| 37 |
+
max-width: 480px;
|
| 38 |
+
width: 100%;
|
| 39 |
+
animation: slideIn 0.3s ease-out;
|
| 40 |
+
}
|
| 41 |
+
@keyframes slideIn {
|
| 42 |
+
from {
|
| 43 |
+
opacity: 0;
|
| 44 |
+
transform: translateY(-20px);
|
| 45 |
+
}
|
| 46 |
+
to {
|
| 47 |
+
opacity: 1;
|
| 48 |
+
transform: translateY(0);
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
.success-icon {
|
| 52 |
+
width: 64px;
|
| 53 |
+
height: 64px;
|
| 54 |
+
margin: 0 auto 1.5rem;
|
| 55 |
+
background: #10b981;
|
| 56 |
+
border-radius: 50%;
|
| 57 |
+
display: flex;
|
| 58 |
+
align-items: center;
|
| 59 |
+
justify-content: center;
|
| 60 |
+
color: white;
|
| 61 |
+
font-size: 2rem;
|
| 62 |
+
font-weight: bold;
|
| 63 |
+
}
|
| 64 |
+
h1 {
|
| 65 |
+
color: #1f2937;
|
| 66 |
+
margin-bottom: 1rem;
|
| 67 |
+
font-size: 1.75rem;
|
| 68 |
+
font-weight: 600;
|
| 69 |
+
}
|
| 70 |
+
.subtitle {
|
| 71 |
+
color: #6b7280;
|
| 72 |
+
margin-bottom: 1.5rem;
|
| 73 |
+
font-size: 1rem;
|
| 74 |
+
line-height: 1.5;
|
| 75 |
+
}
|
| 76 |
+
.setup-notice {
|
| 77 |
+
background: #fef3c7;
|
| 78 |
+
border: 1px solid #f59e0b;
|
| 79 |
+
border-radius: 6px;
|
| 80 |
+
padding: 1rem;
|
| 81 |
+
margin: 1rem 0;
|
| 82 |
+
}
|
| 83 |
+
.setup-notice h3 {
|
| 84 |
+
color: #92400e;
|
| 85 |
+
margin: 0 0 0.5rem 0;
|
| 86 |
+
font-size: 1rem;
|
| 87 |
+
}
|
| 88 |
+
.setup-notice p {
|
| 89 |
+
color: #92400e;
|
| 90 |
+
margin: 0;
|
| 91 |
+
font-size: 0.875rem;
|
| 92 |
+
}
|
| 93 |
+
.setup-notice a {
|
| 94 |
+
color: #1d4ed8;
|
| 95 |
+
text-decoration: none;
|
| 96 |
+
}
|
| 97 |
+
.setup-notice a:hover {
|
| 98 |
+
text-decoration: underline;
|
| 99 |
+
}
|
| 100 |
+
.actions {
|
| 101 |
+
display: flex;
|
| 102 |
+
gap: 1rem;
|
| 103 |
+
justify-content: center;
|
| 104 |
+
flex-wrap: wrap;
|
| 105 |
+
margin-top: 2rem;
|
| 106 |
+
}
|
| 107 |
+
.button {
|
| 108 |
+
padding: 0.75rem 1.5rem;
|
| 109 |
+
border-radius: 8px;
|
| 110 |
+
font-size: 0.875rem;
|
| 111 |
+
font-weight: 500;
|
| 112 |
+
text-decoration: none;
|
| 113 |
+
transition: all 0.2s;
|
| 114 |
+
cursor: pointer;
|
| 115 |
+
border: none;
|
| 116 |
+
display: inline-flex;
|
| 117 |
+
align-items: center;
|
| 118 |
+
gap: 0.5rem;
|
| 119 |
+
}
|
| 120 |
+
.button-primary {
|
| 121 |
+
background: #3b82f6;
|
| 122 |
+
color: white;
|
| 123 |
+
}
|
| 124 |
+
.button-primary:hover {
|
| 125 |
+
background: #2563eb;
|
| 126 |
+
transform: translateY(-1px);
|
| 127 |
+
}
|
| 128 |
+
.button-secondary {
|
| 129 |
+
background: #f3f4f6;
|
| 130 |
+
color: #374151;
|
| 131 |
+
border: 1px solid #d1d5db;
|
| 132 |
+
}
|
| 133 |
+
.button-secondary:hover {
|
| 134 |
+
background: #e5e7eb;
|
| 135 |
+
}
|
| 136 |
+
.countdown {
|
| 137 |
+
color: #9ca3af;
|
| 138 |
+
font-size: 0.75rem;
|
| 139 |
+
margin-top: 1rem;
|
| 140 |
+
}
|
| 141 |
+
.footer {
|
| 142 |
+
margin-top: 2rem;
|
| 143 |
+
padding-top: 1.5rem;
|
| 144 |
+
border-top: 1px solid #e5e7eb;
|
| 145 |
+
color: #9ca3af;
|
| 146 |
+
font-size: 0.75rem;
|
| 147 |
+
}
|
| 148 |
+
.footer a {
|
| 149 |
+
color: #3b82f6;
|
| 150 |
+
text-decoration: none;
|
| 151 |
+
}
|
| 152 |
+
.footer a:hover {
|
| 153 |
+
text-decoration: underline;
|
| 154 |
+
}
|
| 155 |
+
</style>
|
| 156 |
+
</head>
|
| 157 |
+
<body>
|
| 158 |
+
<div class="container">
|
| 159 |
+
<div class="success-icon">✓</div>
|
| 160 |
+
<h1>Authentication Successful!</h1>
|
| 161 |
+
<p class="subtitle">You have successfully authenticated with Claude. You can now close this window and return to your terminal to continue.</p>
|
| 162 |
+
|
| 163 |
+
{{SETUP_NOTICE}}
|
| 164 |
+
|
| 165 |
+
<div class="actions">
|
| 166 |
+
<button class="button button-primary" onclick="window.close()">
|
| 167 |
+
<span>Close Window</span>
|
| 168 |
+
</button>
|
| 169 |
+
<a href="{{PLATFORM_URL}}" target="_blank" class="button button-secondary">
|
| 170 |
+
<span>Open Platform</span>
|
| 171 |
+
<span>↗</span>
|
| 172 |
+
</a>
|
| 173 |
+
</div>
|
| 174 |
+
|
| 175 |
+
<div class="countdown">
|
| 176 |
+
This window will close automatically in <span id="countdown">10</span> seconds
|
| 177 |
+
</div>
|
| 178 |
+
|
| 179 |
+
<div class="footer">
|
| 180 |
+
<p>Powered by <a href="https://chatgpt.com" target="_blank">ChatGPT</a></p>
|
| 181 |
+
</div>
|
| 182 |
+
</div>
|
| 183 |
+
|
| 184 |
+
<script>
|
| 185 |
+
let countdown = 10;
|
| 186 |
+
const countdownElement = document.getElementById('countdown');
|
| 187 |
+
|
| 188 |
+
const timer = setInterval(() => {
|
| 189 |
+
countdown--;
|
| 190 |
+
countdownElement.textContent = countdown;
|
| 191 |
+
|
| 192 |
+
if (countdown <= 0) {
|
| 193 |
+
clearInterval(timer);
|
| 194 |
+
window.close();
|
| 195 |
+
}
|
| 196 |
+
}, 1000);
|
| 197 |
+
|
| 198 |
+
// Close window when user presses Escape
|
| 199 |
+
document.addEventListener('keydown', (e) => {
|
| 200 |
+
if (e.key === 'Escape') {
|
| 201 |
+
window.close();
|
| 202 |
+
}
|
| 203 |
+
});
|
| 204 |
+
|
| 205 |
+
// Focus the close button for keyboard accessibility
|
| 206 |
+
document.querySelector('.button-primary').focus();
|
| 207 |
+
</script>
|
| 208 |
+
</body>
|
| 209 |
+
</html>`
|
| 210 |
+
|
| 211 |
+
// SetupNoticeHtml is the HTML template for the setup notice section.
|
| 212 |
+
// This template is embedded within the success page to inform users about
|
| 213 |
+
// additional setup steps required to complete their Claude account configuration.
|
| 214 |
+
const SetupNoticeHtml = `
|
| 215 |
+
<div class="setup-notice">
|
| 216 |
+
<h3>Additional Setup Required</h3>
|
| 217 |
+
<p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}" target="_blank">Claude</a> to configure your account.</p>
|
| 218 |
+
</div>`
|
internal/auth/claude/oauth_server.go
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides authentication and token management functionality
|
| 2 |
+
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
| 3 |
+
// and retrieval for maintaining authenticated sessions with the Claude API.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"context"
|
| 8 |
+
"errors"
|
| 9 |
+
"fmt"
|
| 10 |
+
"net"
|
| 11 |
+
"net/http"
|
| 12 |
+
"strings"
|
| 13 |
+
"sync"
|
| 14 |
+
"time"
|
| 15 |
+
|
| 16 |
+
log "github.com/sirupsen/logrus"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
// OAuthServer handles the local HTTP server for OAuth callbacks.
|
| 20 |
+
// It listens for the authorization code response from the OAuth provider
|
| 21 |
+
// and captures the necessary parameters to complete the authentication flow.
|
| 22 |
+
type OAuthServer struct {
|
| 23 |
+
// server is the underlying HTTP server instance
|
| 24 |
+
server *http.Server
|
| 25 |
+
// port is the port number on which the server listens
|
| 26 |
+
port int
|
| 27 |
+
// resultChan is a channel for sending OAuth results
|
| 28 |
+
resultChan chan *OAuthResult
|
| 29 |
+
// errorChan is a channel for sending OAuth errors
|
| 30 |
+
errorChan chan error
|
| 31 |
+
// mu is a mutex for protecting server state
|
| 32 |
+
mu sync.Mutex
|
| 33 |
+
// running indicates whether the server is currently running
|
| 34 |
+
running bool
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// OAuthResult contains the result of the OAuth callback.
|
| 38 |
+
// It holds either the authorization code and state for successful authentication
|
| 39 |
+
// or an error message if the authentication failed.
|
| 40 |
+
type OAuthResult struct {
|
| 41 |
+
// Code is the authorization code received from the OAuth provider
|
| 42 |
+
Code string
|
| 43 |
+
// State is the state parameter used to prevent CSRF attacks
|
| 44 |
+
State string
|
| 45 |
+
// Error contains any error message if the OAuth flow failed
|
| 46 |
+
Error string
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// NewOAuthServer creates a new OAuth callback server.
|
| 50 |
+
// It initializes the server with the specified port and creates channels
|
| 51 |
+
// for handling OAuth results and errors.
|
| 52 |
+
//
|
| 53 |
+
// Parameters:
|
| 54 |
+
// - port: The port number on which the server should listen
|
| 55 |
+
//
|
| 56 |
+
// Returns:
|
| 57 |
+
// - *OAuthServer: A new OAuthServer instance
|
| 58 |
+
func NewOAuthServer(port int) *OAuthServer {
|
| 59 |
+
return &OAuthServer{
|
| 60 |
+
port: port,
|
| 61 |
+
resultChan: make(chan *OAuthResult, 1),
|
| 62 |
+
errorChan: make(chan error, 1),
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// Start starts the OAuth callback server.
|
| 67 |
+
// It sets up the HTTP handlers for the callback and success endpoints,
|
| 68 |
+
// and begins listening on the specified port.
|
| 69 |
+
//
|
| 70 |
+
// Returns:
|
| 71 |
+
// - error: An error if the server fails to start
|
| 72 |
+
func (s *OAuthServer) Start() error {
|
| 73 |
+
s.mu.Lock()
|
| 74 |
+
defer s.mu.Unlock()
|
| 75 |
+
|
| 76 |
+
if s.running {
|
| 77 |
+
return fmt.Errorf("server is already running")
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// Check if port is available
|
| 81 |
+
if !s.isPortAvailable() {
|
| 82 |
+
return fmt.Errorf("port %d is already in use", s.port)
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
mux := http.NewServeMux()
|
| 86 |
+
mux.HandleFunc("/callback", s.handleCallback)
|
| 87 |
+
mux.HandleFunc("/success", s.handleSuccess)
|
| 88 |
+
|
| 89 |
+
s.server = &http.Server{
|
| 90 |
+
Addr: fmt.Sprintf(":%d", s.port),
|
| 91 |
+
Handler: mux,
|
| 92 |
+
ReadTimeout: 10 * time.Second,
|
| 93 |
+
WriteTimeout: 10 * time.Second,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
s.running = true
|
| 97 |
+
|
| 98 |
+
// Start server in goroutine
|
| 99 |
+
go func() {
|
| 100 |
+
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
| 101 |
+
s.errorChan <- fmt.Errorf("server failed to start: %w", err)
|
| 102 |
+
}
|
| 103 |
+
}()
|
| 104 |
+
|
| 105 |
+
// Give server a moment to start
|
| 106 |
+
time.Sleep(100 * time.Millisecond)
|
| 107 |
+
|
| 108 |
+
return nil
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Stop gracefully stops the OAuth callback server.
|
| 112 |
+
// It performs a graceful shutdown of the HTTP server with a timeout.
|
| 113 |
+
//
|
| 114 |
+
// Parameters:
|
| 115 |
+
// - ctx: The context for controlling the shutdown process
|
| 116 |
+
//
|
| 117 |
+
// Returns:
|
| 118 |
+
// - error: An error if the server fails to stop gracefully
|
| 119 |
+
func (s *OAuthServer) Stop(ctx context.Context) error {
|
| 120 |
+
s.mu.Lock()
|
| 121 |
+
defer s.mu.Unlock()
|
| 122 |
+
|
| 123 |
+
if !s.running || s.server == nil {
|
| 124 |
+
return nil
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
log.Debug("Stopping OAuth callback server")
|
| 128 |
+
|
| 129 |
+
// Create a context with timeout for shutdown
|
| 130 |
+
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
| 131 |
+
defer cancel()
|
| 132 |
+
|
| 133 |
+
err := s.server.Shutdown(shutdownCtx)
|
| 134 |
+
s.running = false
|
| 135 |
+
s.server = nil
|
| 136 |
+
|
| 137 |
+
return err
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// WaitForCallback waits for the OAuth callback with a timeout.
|
| 141 |
+
// It blocks until either an OAuth result is received, an error occurs,
|
| 142 |
+
// or the specified timeout is reached.
|
| 143 |
+
//
|
| 144 |
+
// Parameters:
|
| 145 |
+
// - timeout: The maximum time to wait for the callback
|
| 146 |
+
//
|
| 147 |
+
// Returns:
|
| 148 |
+
// - *OAuthResult: The OAuth result if successful
|
| 149 |
+
// - error: An error if the callback times out or an error occurs
|
| 150 |
+
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
| 151 |
+
select {
|
| 152 |
+
case result := <-s.resultChan:
|
| 153 |
+
return result, nil
|
| 154 |
+
case err := <-s.errorChan:
|
| 155 |
+
return nil, err
|
| 156 |
+
case <-time.After(timeout):
|
| 157 |
+
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// handleCallback handles the OAuth callback endpoint.
|
| 162 |
+
// It extracts the authorization code and state from the callback URL,
|
| 163 |
+
// validates the parameters, and sends the result to the waiting channel.
|
| 164 |
+
//
|
| 165 |
+
// Parameters:
|
| 166 |
+
// - w: The HTTP response writer
|
| 167 |
+
// - r: The HTTP request
|
| 168 |
+
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
| 169 |
+
log.Debug("Received OAuth callback")
|
| 170 |
+
|
| 171 |
+
// Validate request method
|
| 172 |
+
if r.Method != http.MethodGet {
|
| 173 |
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
| 174 |
+
return
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// Extract parameters
|
| 178 |
+
query := r.URL.Query()
|
| 179 |
+
code := query.Get("code")
|
| 180 |
+
state := query.Get("state")
|
| 181 |
+
errorParam := query.Get("error")
|
| 182 |
+
|
| 183 |
+
// Validate required parameters
|
| 184 |
+
if errorParam != "" {
|
| 185 |
+
log.Errorf("OAuth error received: %s", errorParam)
|
| 186 |
+
result := &OAuthResult{
|
| 187 |
+
Error: errorParam,
|
| 188 |
+
}
|
| 189 |
+
s.sendResult(result)
|
| 190 |
+
http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
|
| 191 |
+
return
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
if code == "" {
|
| 195 |
+
log.Error("No authorization code received")
|
| 196 |
+
result := &OAuthResult{
|
| 197 |
+
Error: "no_code",
|
| 198 |
+
}
|
| 199 |
+
s.sendResult(result)
|
| 200 |
+
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
| 201 |
+
return
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
if state == "" {
|
| 205 |
+
log.Error("No state parameter received")
|
| 206 |
+
result := &OAuthResult{
|
| 207 |
+
Error: "no_state",
|
| 208 |
+
}
|
| 209 |
+
s.sendResult(result)
|
| 210 |
+
http.Error(w, "No state parameter received", http.StatusBadRequest)
|
| 211 |
+
return
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// Send successful result
|
| 215 |
+
result := &OAuthResult{
|
| 216 |
+
Code: code,
|
| 217 |
+
State: state,
|
| 218 |
+
}
|
| 219 |
+
s.sendResult(result)
|
| 220 |
+
|
| 221 |
+
// Redirect to success page
|
| 222 |
+
http.Redirect(w, r, "/success", http.StatusFound)
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
// handleSuccess handles the success page endpoint.
|
| 226 |
+
// It serves a user-friendly HTML page indicating that authentication was successful.
|
| 227 |
+
//
|
| 228 |
+
// Parameters:
|
| 229 |
+
// - w: The HTTP response writer
|
| 230 |
+
// - r: The HTTP request
|
| 231 |
+
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
| 232 |
+
log.Debug("Serving success page")
|
| 233 |
+
|
| 234 |
+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
| 235 |
+
w.WriteHeader(http.StatusOK)
|
| 236 |
+
|
| 237 |
+
// Parse query parameters for customization
|
| 238 |
+
query := r.URL.Query()
|
| 239 |
+
setupRequired := query.Get("setup_required") == "true"
|
| 240 |
+
platformURL := query.Get("platform_url")
|
| 241 |
+
if platformURL == "" {
|
| 242 |
+
platformURL = "https://console.anthropic.com/"
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// Validate platformURL to prevent XSS - only allow http/https URLs
|
| 246 |
+
if !isValidURL(platformURL) {
|
| 247 |
+
platformURL = "https://console.anthropic.com/"
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
// Generate success page HTML with dynamic content
|
| 251 |
+
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
| 252 |
+
|
| 253 |
+
_, err := w.Write([]byte(successHTML))
|
| 254 |
+
if err != nil {
|
| 255 |
+
log.Errorf("Failed to write success page: %v", err)
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
| 260 |
+
func isValidURL(urlStr string) bool {
|
| 261 |
+
urlStr = strings.TrimSpace(urlStr)
|
| 262 |
+
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// generateSuccessHTML creates the HTML content for the success page.
|
| 266 |
+
// It customizes the page based on whether additional setup is required
|
| 267 |
+
// and includes a link to the platform.
|
| 268 |
+
//
|
| 269 |
+
// Parameters:
|
| 270 |
+
// - setupRequired: Whether additional setup is required after authentication
|
| 271 |
+
// - platformURL: The URL to the platform for additional setup
|
| 272 |
+
//
|
| 273 |
+
// Returns:
|
| 274 |
+
// - string: The HTML content for the success page
|
| 275 |
+
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
|
| 276 |
+
html := LoginSuccessHtml
|
| 277 |
+
|
| 278 |
+
// Replace platform URL placeholder
|
| 279 |
+
html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
|
| 280 |
+
|
| 281 |
+
// Add setup notice if required
|
| 282 |
+
if setupRequired {
|
| 283 |
+
setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
|
| 284 |
+
html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
|
| 285 |
+
} else {
|
| 286 |
+
html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
return html
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
// sendResult sends the OAuth result to the waiting channel.
|
| 293 |
+
// It ensures that the result is sent without blocking the handler.
|
| 294 |
+
//
|
| 295 |
+
// Parameters:
|
| 296 |
+
// - result: The OAuth result to send
|
| 297 |
+
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
| 298 |
+
select {
|
| 299 |
+
case s.resultChan <- result:
|
| 300 |
+
log.Debug("OAuth result sent to channel")
|
| 301 |
+
default:
|
| 302 |
+
log.Warn("OAuth result channel is full, result dropped")
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
// isPortAvailable checks if the specified port is available.
|
| 307 |
+
// It attempts to listen on the port to determine availability.
|
| 308 |
+
//
|
| 309 |
+
// Returns:
|
| 310 |
+
// - bool: True if the port is available, false otherwise
|
| 311 |
+
func (s *OAuthServer) isPortAvailable() bool {
|
| 312 |
+
addr := fmt.Sprintf(":%d", s.port)
|
| 313 |
+
listener, err := net.Listen("tcp", addr)
|
| 314 |
+
if err != nil {
|
| 315 |
+
return false
|
| 316 |
+
}
|
| 317 |
+
defer func() {
|
| 318 |
+
_ = listener.Close()
|
| 319 |
+
}()
|
| 320 |
+
return true
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
// IsRunning returns whether the server is currently running.
|
| 324 |
+
//
|
| 325 |
+
// Returns:
|
| 326 |
+
// - bool: True if the server is running, false otherwise
|
| 327 |
+
func (s *OAuthServer) IsRunning() bool {
|
| 328 |
+
s.mu.Lock()
|
| 329 |
+
defer s.mu.Unlock()
|
| 330 |
+
return s.running
|
| 331 |
+
}
|
internal/auth/claude/pkce.go
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides authentication and token management functionality
|
| 2 |
+
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
| 3 |
+
// and retrieval for maintaining authenticated sessions with the Claude API.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"crypto/rand"
|
| 8 |
+
"crypto/sha256"
|
| 9 |
+
"encoding/base64"
|
| 10 |
+
"fmt"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
// GeneratePKCECodes generates a PKCE code verifier and challenge pair
|
| 14 |
+
// following RFC 7636 specifications for OAuth 2.0 PKCE extension.
|
| 15 |
+
// This provides additional security for the OAuth flow by ensuring that
|
| 16 |
+
// only the client that initiated the request can exchange the authorization code.
|
| 17 |
+
//
|
| 18 |
+
// Returns:
|
| 19 |
+
// - *PKCECodes: A struct containing the code verifier and challenge
|
| 20 |
+
// - error: An error if the generation fails, nil otherwise
|
| 21 |
+
func GeneratePKCECodes() (*PKCECodes, error) {
|
| 22 |
+
// Generate code verifier: 43-128 characters, URL-safe
|
| 23 |
+
codeVerifier, err := generateCodeVerifier()
|
| 24 |
+
if err != nil {
|
| 25 |
+
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// Generate code challenge using S256 method
|
| 29 |
+
codeChallenge := generateCodeChallenge(codeVerifier)
|
| 30 |
+
|
| 31 |
+
return &PKCECodes{
|
| 32 |
+
CodeVerifier: codeVerifier,
|
| 33 |
+
CodeChallenge: codeChallenge,
|
| 34 |
+
}, nil
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// generateCodeVerifier creates a cryptographically random string
|
| 38 |
+
// of 128 characters using URL-safe base64 encoding
|
| 39 |
+
func generateCodeVerifier() (string, error) {
|
| 40 |
+
// Generate 96 random bytes (will result in 128 base64 characters)
|
| 41 |
+
bytes := make([]byte, 96)
|
| 42 |
+
_, err := rand.Read(bytes)
|
| 43 |
+
if err != nil {
|
| 44 |
+
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// Encode to URL-safe base64 without padding
|
| 48 |
+
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes), nil
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// generateCodeChallenge creates a SHA256 hash of the code verifier
|
| 52 |
+
// and encodes it using URL-safe base64 encoding without padding
|
| 53 |
+
func generateCodeChallenge(codeVerifier string) string {
|
| 54 |
+
hash := sha256.Sum256([]byte(codeVerifier))
|
| 55 |
+
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
| 56 |
+
}
|
internal/auth/claude/token.go
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Package claude provides authentication and token management functionality
|
| 2 |
+
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
|
| 3 |
+
// and retrieval for maintaining authenticated sessions with the Claude API.
|
| 4 |
+
package claude
|
| 5 |
+
|
| 6 |
+
import (
|
| 7 |
+
"encoding/json"
|
| 8 |
+
"fmt"
|
| 9 |
+
"os"
|
| 10 |
+
"path/filepath"
|
| 11 |
+
|
| 12 |
+
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
// ClaudeTokenStorage stores OAuth2 token information for Anthropic Claude API authentication.
|
| 16 |
+
// It maintains compatibility with the existing auth system while adding Claude-specific fields
|
| 17 |
+
// for managing access tokens, refresh tokens, and user account information.
|
| 18 |
+
type ClaudeTokenStorage struct {
|
| 19 |
+
// IDToken is the JWT ID token containing user claims and identity information.
|
| 20 |
+
IDToken string `json:"id_token"`
|
| 21 |
+
|
| 22 |
+
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
| 23 |
+
AccessToken string `json:"access_token"`
|
| 24 |
+
|
| 25 |
+
// RefreshToken is used to obtain new access tokens when the current one expires.
|
| 26 |
+
RefreshToken string `json:"refresh_token"`
|
| 27 |
+
|
| 28 |
+
// LastRefresh is the timestamp of the last token refresh operation.
|
| 29 |
+
LastRefresh string `json:"last_refresh"`
|
| 30 |
+
|
| 31 |
+
// Email is the Anthropic account email address associated with this token.
|
| 32 |
+
Email string `json:"email"`
|
| 33 |
+
|
| 34 |
+
// Type indicates the authentication provider type, always "claude" for this storage.
|
| 35 |
+
Type string `json:"type"`
|
| 36 |
+
|
| 37 |
+
// Expire is the timestamp when the current access token expires.
|
| 38 |
+
Expire string `json:"expired"`
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
| 42 |
+
// This method creates the necessary directory structure and writes the token
|
| 43 |
+
// data in JSON format to the specified file path for persistent storage.
|
| 44 |
+
//
|
| 45 |
+
// Parameters:
|
| 46 |
+
// - authFilePath: The full path where the token file should be saved
|
| 47 |
+
//
|
| 48 |
+
// Returns:
|
| 49 |
+
// - error: An error if the operation fails, nil otherwise
|
| 50 |
+
func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
| 51 |
+
misc.LogSavingCredentials(authFilePath)
|
| 52 |
+
ts.Type = "claude"
|
| 53 |
+
|
| 54 |
+
// Create directory structure if it doesn't exist
|
| 55 |
+
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
| 56 |
+
return fmt.Errorf("failed to create directory: %v", err)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// Create the token file
|
| 60 |
+
f, err := os.Create(authFilePath)
|
| 61 |
+
if err != nil {
|
| 62 |
+
return fmt.Errorf("failed to create token file: %w", err)
|
| 63 |
+
}
|
| 64 |
+
defer func() {
|
| 65 |
+
_ = f.Close()
|
| 66 |
+
}()
|
| 67 |
+
|
| 68 |
+
// Encode and write the token data as JSON
|
| 69 |
+
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
| 70 |
+
return fmt.Errorf("failed to write token to file: %w", err)
|
| 71 |
+
}
|
| 72 |
+
return nil
|
| 73 |
+
}
|