augment2api / main.go
github-actions[bot]
Update from GitHub Actions
191a47b
package main
import (
"augment2api/api"
"augment2api/config"
"augment2api/middleware"
"augment2api/pkg/logger"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const clientID = "v"
// OAuthState 存储OAuth状态信息
type OAuthState struct {
CodeVerifier string `json:"code_verifier"`
CodeChallenge string `json:"code_challenge"`
State string `json:"state"`
CreationTime time.Time `json:"creation_time"`
}
// 全局变量存储OAuth状态
var (
globalOAuthState OAuthState
)
// base64URLEncode 编码Buffer为base64 URL安全格式
func base64URLEncode(data []byte) string {
encoded := base64.StdEncoding.EncodeToString(data)
encoded = strings.ReplaceAll(encoded, "+", "-")
encoded = strings.ReplaceAll(encoded, "/", "_")
encoded = strings.ReplaceAll(encoded, "=", "")
return encoded
}
// sha256Hash 计算SHA256哈希
func sha256Hash(input []byte) []byte {
hash := sha256.Sum256(input)
return hash[:]
}
// createOAuthState 创建OAuth状态
func createOAuthState() OAuthState {
codeVerifierBytes := make([]byte, 32)
_, err := rand.Read(codeVerifierBytes)
if err != nil {
log.Fatalf("生成随机字节失败: %v", err)
}
codeVerifier := base64URLEncode(codeVerifierBytes)
codeChallenge := base64URLEncode(sha256Hash([]byte(codeVerifier)))
stateBytes := make([]byte, 8)
_, err = rand.Read(stateBytes)
if err != nil {
log.Fatalf("生成随机状态失败: %v", err)
}
state := base64URLEncode(stateBytes)
return OAuthState{
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
State: state,
CreationTime: time.Now(),
}
}
// generateAuthorizeURL 生成授权URL
func generateAuthorizeURL(oauthState OAuthState) string {
params := url.Values{}
params.Add("response_type", "code")
params.Add("code_challenge", oauthState.CodeChallenge)
params.Add("client_id", clientID)
params.Add("state", oauthState.State)
params.Add("prompt", "login")
authorizeURL := fmt.Sprintf("https://auth.augmentcode.com/authorize?%s", params.Encode())
return authorizeURL
}
// getAccessToken 获取访问令牌
func getAccessToken(tenantURL, codeVerifier, code string) (string, error) {
data := map[string]string{
"grant_type": "authorization_code",
"client_id": clientID,
"code_verifier": codeVerifier,
"redirect_uri": "",
"code": code,
}
jsonData, err := json.Marshal(data)
if err != nil {
return "", fmt.Errorf("序列化数据失败: %v", err)
}
resp, err := http.Post(tenantURL+"token", "application/json", strings.NewReader(string(jsonData)))
if err != nil {
return "", fmt.Errorf("请求令牌失败: %v", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("解析响应失败: %v", err)
}
token, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("响应中没有访问令牌")
}
return token, nil
}
// 初始化路由
func setupRouter() *gin.Engine {
r := gin.Default()
// 跨域
r.Use(middleware.CORS())
// 初始化OAuth状态
globalOAuthState = createOAuthState()
// 静态文件服务
r.Static("/static", "./static")
r.LoadHTMLGlob("templates/*")
// 登录页面
r.GET("/login", func(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", gin.H{})
})
// 登录
r.POST("/api/login", api.LoginHandler)
// 登出
r.POST("/api/logout", api.LogoutHandler)
// 管理页面 - 需要会话验证
r.GET("/", func(c *gin.Context) {
// 如果设置了访问密码,检查是否已登录
if config.AppConfig.AccessPwd != "" {
// 从查询参数或Cookie中获取会话令牌
token := c.Query("token")
if token == "" {
// 尝试从Cookie获取
token, _ = c.Cookie("auth_token")
}
// 从请求头获取
if token == "" {
token = c.GetHeader("X-Auth-Token")
}
// 验证会话令牌
if !api.ValidateToken(token) {
c.Redirect(http.StatusFound, "/login")
return
}
}
c.HTML(http.StatusOK, "admin.html", gin.H{})
})
// 管理页面 - 需要会话验证
r.GET("/admin", api.AuthTokenMiddleware(), func(c *gin.Context) {
c.HTML(http.StatusOK, "admin.html", gin.H{})
})
// 授权端点 - 需要会话验证
r.GET("/auth", api.AuthTokenMiddleware(), func(c *gin.Context) {
authorizeURL := generateAuthorizeURL(globalOAuthState)
api.AuthHandler(c, authorizeURL)
})
// 获取token - 需要会话验证
r.GET("/api/tokens", api.AuthTokenMiddleware(), api.GetRedisTokenHandler)
// 删除token - 需要会话验证
r.DELETE("/api/token/:token", api.AuthTokenMiddleware(), api.DeleteTokenHandler)
// 更新token备注 - 需要会话验证
r.PUT("/api/token/:token/remark", api.AuthTokenMiddleware(), api.UpdateTokenRemark)
// 批量检测token - 需要会话验证
r.GET("/api/check-tokens", api.AuthTokenMiddleware(), api.CheckAllTokensHandler)
// 回调端点,用于处理授权码 - 需要会话验证
r.POST("/callback", api.AuthTokenMiddleware(), func(c *gin.Context) {
api.CallbackHandler(c, func(tenantURL, _, code string) (string, error) {
return getAccessToken(tenantURL, globalOAuthState.CodeVerifier, code)
})
})
// 鉴权路由组
authGroup := r.Group(ProcessPath(config.AppConfig.RoutePrefix))
authGroup.Use(api.AuthMiddleware())
{
// OpenAI兼容的聊天端点
chatGroup := authGroup.Group("/")
// 并发控制
chatGroup.Use(middleware.TokenConcurrencyMiddleware())
{
chatGroup.POST("/v1/chat/completions", api.ChatCompletionsHandler)
chatGroup.POST("/v1", api.ChatCompletionsHandler)
chatGroup.POST("/v1/chat", api.ChatCompletionsHandler)
}
authGroup.GET("/v1/models", api.ModelsHandler)
authGroup.POST("/api/add/tokens", api.AddTokenHandler)
}
return r
}
func ProcessPath(path string) string {
// 判断字符串是否为空
if path == "" {
return ""
}
// 判断开头是否为/,不是则添加
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// 判断结尾是否为/,是则去掉
if strings.HasSuffix(path, "/") {
path = path[:len(path)-1]
}
return path
}
func main() {
// 设置全局时区为东八区(CST)
time.Local = time.FixedZone("CST", 8*3600)
// 设置 Gin 为发布模式
gin.SetMode(gin.ReleaseMode)
// 初始化日志
logger.Init()
// 初始化配置
err := config.InitConfig()
if err != nil {
logger.Log.Fatalln("failed to initialize config: " + err.Error())
return
}
// 初始化Redis
err = config.InitRedisClient()
if err != nil {
logger.Log.Fatalln("failed to initialize Redis: " + err.Error())
}
// token备注字段迁移
err = api.MigrateTokensRemark()
if err != nil {
logger.Log.Error("Token备注字段迁移失败: %v", err)
}
// 启动token使用次数重置调度器
go api.StartTokenUsageResetScheduler()
r := setupRouter()
// 启动服务器
if err := r.Run(":7860"); err != nil {
logger.Log.Fatalf("启动服务失败: %v", err)
}
logger.Log.WithFields(map[string]interface{}{
"port": 7860,
"mode": gin.Mode(),
}).Info("Augment2API 服务启动成功")
}