package main import ( "augment2api/api" "augment2api/config" "augment2api/middleware" "augment2api/pkg/logger" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "log" "net/http" "net/url" "strings" "time" "github.com/gin-gonic/gin" ) const clientID = "v" // OAuthState 存储OAuth状态信息 type OAuthState struct { CodeVerifier string `json:"code_verifier"` CodeChallenge string `json:"code_challenge"` State string `json:"state"` CreationTime time.Time `json:"creation_time"` } // 全局变量存储OAuth状态 var ( globalOAuthState OAuthState ) // base64URLEncode 编码Buffer为base64 URL安全格式 func base64URLEncode(data []byte) string { encoded := base64.StdEncoding.EncodeToString(data) encoded = strings.ReplaceAll(encoded, "+", "-") encoded = strings.ReplaceAll(encoded, "/", "_") encoded = strings.ReplaceAll(encoded, "=", "") return encoded } // sha256Hash 计算SHA256哈希 func sha256Hash(input []byte) []byte { hash := sha256.Sum256(input) return hash[:] } // createOAuthState 创建OAuth状态 func createOAuthState() OAuthState { codeVerifierBytes := make([]byte, 32) _, err := rand.Read(codeVerifierBytes) if err != nil { log.Fatalf("生成随机字节失败: %v", err) } codeVerifier := base64URLEncode(codeVerifierBytes) codeChallenge := base64URLEncode(sha256Hash([]byte(codeVerifier))) stateBytes := make([]byte, 8) _, err = rand.Read(stateBytes) if err != nil { log.Fatalf("生成随机状态失败: %v", err) } state := base64URLEncode(stateBytes) return OAuthState{ CodeVerifier: codeVerifier, CodeChallenge: codeChallenge, State: state, CreationTime: time.Now(), } } // generateAuthorizeURL 生成授权URL func generateAuthorizeURL(oauthState OAuthState) string { params := url.Values{} params.Add("response_type", "code") params.Add("code_challenge", oauthState.CodeChallenge) params.Add("client_id", clientID) params.Add("state", oauthState.State) params.Add("prompt", "login") authorizeURL := fmt.Sprintf("https://auth.augmentcode.com/authorize?%s", params.Encode()) return authorizeURL } // getAccessToken 获取访问令牌 func getAccessToken(tenantURL, codeVerifier, code string) (string, error) { data := map[string]string{ "grant_type": "authorization_code", "client_id": clientID, "code_verifier": codeVerifier, "redirect_uri": "", "code": code, } jsonData, err := json.Marshal(data) if err != nil { return "", fmt.Errorf("序列化数据失败: %v", err) } resp, err := http.Post(tenantURL+"token", "application/json", strings.NewReader(string(jsonData))) if err != nil { return "", fmt.Errorf("请求令牌失败: %v", err) } defer resp.Body.Close() var result map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return "", fmt.Errorf("解析响应失败: %v", err) } token, ok := result["access_token"].(string) if !ok { return "", fmt.Errorf("响应中没有访问令牌") } return token, nil } // 初始化路由 func setupRouter() *gin.Engine { r := gin.Default() // 跨域 r.Use(middleware.CORS()) // 初始化OAuth状态 globalOAuthState = createOAuthState() // 静态文件服务 r.Static("/static", "./static") r.LoadHTMLGlob("templates/*") // 登录页面 r.GET("/login", func(c *gin.Context) { c.HTML(http.StatusOK, "login.html", gin.H{}) }) // 登录 r.POST("/api/login", api.LoginHandler) // 登出 r.POST("/api/logout", api.LogoutHandler) // 管理页面 - 需要会话验证 r.GET("/", func(c *gin.Context) { // 如果设置了访问密码,检查是否已登录 if config.AppConfig.AccessPwd != "" { // 从查询参数或Cookie中获取会话令牌 token := c.Query("token") if token == "" { // 尝试从Cookie获取 token, _ = c.Cookie("auth_token") } // 从请求头获取 if token == "" { token = c.GetHeader("X-Auth-Token") } // 验证会话令牌 if !api.ValidateToken(token) { c.Redirect(http.StatusFound, "/login") return } } c.HTML(http.StatusOK, "admin.html", gin.H{}) }) // 管理页面 - 需要会话验证 r.GET("/admin", api.AuthTokenMiddleware(), func(c *gin.Context) { c.HTML(http.StatusOK, "admin.html", gin.H{}) }) // 授权端点 - 需要会话验证 r.GET("/auth", api.AuthTokenMiddleware(), func(c *gin.Context) { authorizeURL := generateAuthorizeURL(globalOAuthState) api.AuthHandler(c, authorizeURL) }) // 获取token - 需要会话验证 r.GET("/api/tokens", api.AuthTokenMiddleware(), api.GetRedisTokenHandler) // 删除token - 需要会话验证 r.DELETE("/api/token/:token", api.AuthTokenMiddleware(), api.DeleteTokenHandler) // 更新token备注 - 需要会话验证 r.PUT("/api/token/:token/remark", api.AuthTokenMiddleware(), api.UpdateTokenRemark) // 批量检测token - 需要会话验证 r.GET("/api/check-tokens", api.AuthTokenMiddleware(), api.CheckAllTokensHandler) // 回调端点,用于处理授权码 - 需要会话验证 r.POST("/callback", api.AuthTokenMiddleware(), func(c *gin.Context) { api.CallbackHandler(c, func(tenantURL, _, code string) (string, error) { return getAccessToken(tenantURL, globalOAuthState.CodeVerifier, code) }) }) // 鉴权路由组 authGroup := r.Group(ProcessPath(config.AppConfig.RoutePrefix)) authGroup.Use(api.AuthMiddleware()) { // OpenAI兼容的聊天端点 chatGroup := authGroup.Group("/") // 并发控制 chatGroup.Use(middleware.TokenConcurrencyMiddleware()) { chatGroup.POST("/v1/chat/completions", api.ChatCompletionsHandler) chatGroup.POST("/v1", api.ChatCompletionsHandler) chatGroup.POST("/v1/chat", api.ChatCompletionsHandler) } authGroup.GET("/v1/models", api.ModelsHandler) authGroup.POST("/api/add/tokens", api.AddTokenHandler) } return r } func ProcessPath(path string) string { // 判断字符串是否为空 if path == "" { return "" } // 判断开头是否为/,不是则添加 if !strings.HasPrefix(path, "/") { path = "/" + path } // 判断结尾是否为/,是则去掉 if strings.HasSuffix(path, "/") { path = path[:len(path)-1] } return path } func main() { // 设置全局时区为东八区(CST) time.Local = time.FixedZone("CST", 8*3600) // 设置 Gin 为发布模式 gin.SetMode(gin.ReleaseMode) // 初始化日志 logger.Init() // 初始化配置 err := config.InitConfig() if err != nil { logger.Log.Fatalln("failed to initialize config: " + err.Error()) return } // 初始化Redis err = config.InitRedisClient() if err != nil { logger.Log.Fatalln("failed to initialize Redis: " + err.Error()) } // token备注字段迁移 err = api.MigrateTokensRemark() if err != nil { logger.Log.Error("Token备注字段迁移失败: %v", err) } // 启动token使用次数重置调度器 go api.StartTokenUsageResetScheduler() r := setupRouter() // 启动服务器 if err := r.Run(":7860"); err != nil { logger.Log.Fatalf("启动服务失败: %v", err) } logger.Log.WithFields(map[string]interface{}{ "port": 7860, "mode": gin.Mode(), }).Info("Augment2API 服务启动成功") }