aaxaxax's picture
force update files
1de7911
package handler
import (
	"bytes"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)
// PKCESession holds the PKCE state of one in-flight OAuth authorization.
type PKCESession struct {
// CodeVerifier is the PKCE code_verifier that is later sent to the token
// endpoint together with the authorization code.
CodeVerifier string
// CreatedAt is the time the session was stored; the cleanup loop uses it
// to decide when the session has expired.
CreatedAt time.Time
}
// PKCESessionStore keeps PKCE sessions in memory, keyed by session ID.
// The embedded RWMutex guards the sessions map; code in this file takes the
// lock around every read or write of the map.
type PKCESessionStore struct {
sync.RWMutex
// sessions maps a session ID to its PKCE session.
sessions map[string]*PKCESession
}
// pkceStore is the package-wide PKCE session store. It starts out with an
// empty, non-nil map so sessions can be inserted right away.
var pkceStore = &PKCESessionStore{
	sessions: map[string]*PKCESession{},
}
// OAuthHandler groups the HTTP handlers of the OAuth flow. It carries no
// state of its own; sessions live in the package-level pkceStore.
type OAuthHandler struct {
}
// pkceCleanupOnce ensures the expired-session sweeper is started at most once
// per process, no matter how many handlers are constructed.
var pkceCleanupOnce sync.Once

// NewOAuthHandler returns a new OAuthHandler.
//
// The first call also starts the background goroutine that sweeps expired
// PKCE sessions. Later calls reuse that goroutine instead of starting another
// one, because the sweeper never exits and every extra copy would be leaked.
func NewOAuthHandler() *OAuthHandler {
	pkceCleanupOnce.Do(func() {
		go cleanupExpiredSessions()
	})
	return &OAuthHandler{}
}
// StartOAuthForRT starts the OAuth authorization-code flow (with PKCE) that is
// used to obtain a refresh token.
//
// It generates a code verifier, its S256 challenge and a random session ID,
// stores the verifier under that session ID, and redirects the client (302) to
// the authorization page. The state parameter carries the local callback URL,
// the challenge and the session ID as a JSON object.
//
// On any internal failure it answers 500 with a JSON error body and stores
// nothing.
func (h *OAuthHandler) StartOAuthForRT(c *gin.Context) {
	// NOTE(review): RFC 7636 asks for a verifier of 43-128 characters; 32 is
	// kept because that is what this flow has always sent — confirm before
	// changing.
	codeVerifier, err := generateCodeVerifier(32)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "生成PKCE参数失败",
		})
		return
	}
	codeChallenge := generateCodeChallenge(codeVerifier)

	sessionID, err := generateSessionID()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "生成会话ID失败",
		})
		return
	}

	// Build the callback URL from the incoming request: https only when this
	// process itself terminated TLS, and the host exactly as the client sent it.
	scheme := "http"
	if c.Request.TLS != nil {
		scheme = "https"
	}
	host := c.Request.Host
	callbackURL := fmt.Sprintf("%s://%s/api/oauth/callback-rt?session=%s",
		scheme, host, sessionID)

	// Encode the state before touching the store so that a failure here
	// cannot leave an orphaned session behind.
	stateJSON, err := json.Marshal(map[string]string{
		"redirectUri":   callbackURL,
		"codeChallenge": codeChallenge,
		"sessionId":     sessionID,
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "构建state参数失败",
		})
		return
	}

	// Remember the verifier; the callback handler looks it up by session ID.
	pkceStore.Lock()
	pkceStore.sessions[sessionID] = &PKCESession{
		CodeVerifier: codeVerifier,
		CreatedAt:    time.Now(),
	}
	pkceStore.Unlock()

	params := url.Values{
		"state":                 {string(stateJSON)},
		"response_type":         {"code"},
		"client_id":             {"5948a5c5-4b30-4465-a3f2-2136ea53ea0a"},
		"scope":                 {"openid profile email"},
		"redirect_uri":          {"https://auth.zencoder.ai/extension/auth-success"},
		"code_challenge":        {codeChallenge},
		"code_challenge_method": {"S256"},
	}
	authURL := fmt.Sprintf("https://fe.zencoder.ai/oauth/authorize?%s", params.Encode())

	c.Redirect(http.StatusFound, authURL)
}
// CallbackOAuthForRT handles the OAuth callback request.
//
// It reads the "code" and "session" query parameters, looks up the PKCE
// session, exchanges the authorization code for tokens and renders the result
// page. Every outcome, success or failure, is reported through
// renderCallbackPage.
//
// A session is single-use: it is removed from the store as soon as it is
// looked up, so concurrent or repeated callbacks cannot reuse the same code
// verifier and a failed exchange does not leave the session behind. Sessions
// older than the TTL are rejected here even if the periodic sweeper has not
// removed them yet.
func (h *OAuthHandler) CallbackOAuthForRT(c *gin.Context) {
	// Must match the age threshold used by cleanupExpiredSessions.
	const sessionTTL = 10 * time.Minute

	code := c.Query("code")
	sessionID := c.Query("session")
	if code == "" || sessionID == "" {
		h.renderCallbackPage(c, false, "", "", "缺少必要参数")
		return
	}

	// Look up and consume the session under one write lock so the two steps
	// are atomic. Deleting a missing key is a no-op.
	pkceStore.Lock()
	session, exists := pkceStore.sessions[sessionID]
	delete(pkceStore.sessions, sessionID)
	pkceStore.Unlock()

	if !exists || time.Since(session.CreatedAt) > sessionTTL {
		h.renderCallbackPage(c, false, "", "", "会话已过期,请重新获取")
		return
	}

	tokenResp, err := h.exchangeCodeForToken(code, session.CodeVerifier)
	if err != nil {
		h.renderCallbackPage(c, false, "", "", fmt.Sprintf("获取Token失败: %v", err))
		return
	}

	// Hand both tokens to the result page, which forwards them to the opener.
	h.renderCallbackPage(c, true, tokenResp.AccessToken, tokenResp.RefreshToken, "")
}
// exchangeCodeForToken trades an authorization code for tokens.
//
// It POSTs a JSON body containing the code, the fixed redirect URI, the PKCE
// code verifier and grant_type=authorization_code to the token endpoint, with
// a 30-second client timeout.
//
// It returns the decoded token response, or an error when the request cannot
// be built or sent, the endpoint answers with a status other than 200, the
// body is not valid JSON, or the response carries neither an access token nor
// a refresh token.
func (h *OAuthHandler) exchangeCodeForToken(code, codeVerifier string) (*OAuthTokenResponse, error) {
	const tokenURL = "https://auth.zencoder.ai/api/frontegg/oauth/token"

	body, err := json.Marshal(map[string]string{
		"code":          code,
		"redirect_uri":  "https://auth.zencoder.ai/extension/auth-success",
		"code_verifier": codeVerifier,
		"grant_type":    "authorization_code",
	})
	if err != nil {
		return nil, fmt.Errorf("encoding token request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, tokenURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building token request: %w", err)
	}
	// NOTE(review): these headers presumably make the request resemble one
	// sent by the Frontegg Next.js SDK from a browser — confirm they are
	// still required by the endpoint.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-frontegg-sdk", "@frontegg/nextjs@9.2.10")
	req.Header.Set("x-frontegg-framework", "next@15.3.8")
	req.Header.Set("Origin", "https://auth.zencoder.ai")
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("sending token request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed with status %d", resp.StatusCode)
	}

	var tokenResp OAuthTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("decoding token response: %w", err)
	}
	// A 200 answer without any token must not be reported as a success.
	if tokenResp.AccessToken == "" && tokenResp.RefreshToken == "" {
		return nil, fmt.Errorf("token response contained no tokens")
	}
	return &tokenResp, nil
}
// jsStringLiteral returns s encoded as a double-quoted string literal that is
// safe to place inside an inline <script> element. encoding/json escapes
// quotes, backslashes and control characters and, by default, also "<", ">",
// "&", U+2028 and U+2029, so the value can neither end the literal nor close
// the script element.
func jsStringLiteral(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		// Marshalling a plain string is not expected to fail; degrade to an
		// empty literal rather than emitting broken script.
		return `""`
	}
	return string(encoded)
}

// renderCallbackPage writes the HTML page shown in the OAuth popup window.
//
// When success is true the page shows a confirmation, posts accessToken and
// refreshToken to window.opener (restricted to the page's own origin) and
// closes itself after two seconds. Otherwise it shows errorMsg, offers a close
// button and posts the error to window.opener.
//
// The tokens come from an upstream response and errorMsg may contain arbitrary
// error text, so every value is escaped for the context it lands in: HTML text
// through html.EscapeString, script values through jsStringLiteral. The
// response is always sent with status 200 and Content-Type text/html.
func (h *OAuthHandler) renderCallbackPage(c *gin.Context, success bool, accessToken, refreshToken, errorMsg string) {
	page := `
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>OAuth认证</title>
<script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-50 dark:bg-gray-900 min-h-screen flex items-center justify-center">
<div class="max-w-md w-full mx-4">
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-8">
`
	if success {
		// The %s verbs receive complete, already-quoted JS string literals.
		page += fmt.Sprintf(`
<div class="text-center">
<div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-green-100">
<svg class="h-6 w-6 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
</svg>
</div>
<h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证成功!</h2>
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">正在返回并填充Token...</p>
</div>
<script>
// 发送消息给父窗口
if (window.opener) {
window.opener.postMessage({
type: 'oauth-rt-complete',
success: true,
accessToken: %s,
refreshToken: %s
}, window.location.origin);
// 2秒后关闭窗口
setTimeout(() => {
window.close();
}, 2000);
}
</script>
`, jsStringLiteral(accessToken), jsStringLiteral(refreshToken))
	} else {
		// The first %s is HTML text, the second a quoted JS string literal.
		page += fmt.Sprintf(`
<div class="text-center">
<div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-red-100">
<svg class="h-6 w-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path>
</svg>
</div>
<h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证失败</h2>
<p class="mt-2 text-sm text-gray-600 dark:text-gray-400">%s</p>
<button onclick="window.close()" class="mt-4 px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition-colors">
关闭窗口
</button>
</div>
<script>
// 发送错误消息给父窗口
if (window.opener) {
window.opener.postMessage({
type: 'oauth-rt-complete',
success: false,
error: %s
}, window.location.origin);
}
</script>
`, html.EscapeString(errorMsg), jsStringLiteral(errorMsg))
	}
	page += `
</div>
</div>
</body>
</html>
`
	c.Header("Content-Type", "text/html; charset=utf-8")
	c.String(http.StatusOK, page)
}
// OAuthTokenResponse is the JSON body decoded from the token endpoint's
// response. Field meanings follow the JSON keys named in the struct tags.
type OAuthTokenResponse struct {
// AccessToken is read from the "access_token" key.
AccessToken string `json:"access_token"`
// RefreshToken is read from the "refresh_token" key.
RefreshToken string `json:"refresh_token"`
// TokenType is read from the "token_type" key.
TokenType string `json:"token_type"`
// ExpiresIn is read from the "expires_in" key.
// NOTE(review): presumably a lifetime in seconds — confirm against the
// endpoint's documentation.
ExpiresIn int `json:"expires_in"`
}
// generateCodeVerifier returns a random PKCE code_verifier of the given length,
// drawn uniformly from the 66 unreserved characters A-Z, a-z, 0-9, "-", ".",
// "_" and "~", using crypto/rand.
//
// A length of 0 yields an empty string. A negative length, or a failure of the
// random source, yields an error. RFC 7636 asks for 43-128 characters; that
// range is not enforced here and is left to the caller.
func generateCodeVerifier(length int) (string, error) {
	const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
	// 256 is not a multiple of len(chars), so a plain "byte % len(chars)"
	// would favour the first 256%len(chars) characters. Bytes at or above
	// limit are discarded, which leaves every character equally likely.
	const limit = 256 - 256%len(chars)

	if length < 0 {
		return "", fmt.Errorf("code verifier length must not be negative, got %d", length)
	}

	result := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(result) < length {
		// Request only as many bytes as are still missing, so the appends
		// below can never exceed length.
		buf = buf[:length-len(result)]
		if _, err := rand.Read(buf); err != nil {
			return "", err
		}
		for _, b := range buf {
			if int(b) < limit {
				result = append(result, chars[int(b)%len(chars)])
			}
		}
	}
	return string(result), nil
}
// generateCodeChallenge derives the PKCE S256 code_challenge for codeVerifier:
// the SHA-256 digest of the verifier, encoded as unpadded URL-safe base64.
func generateCodeChallenge(codeVerifier string) string {
	digest := sha256.Sum256([]byte(codeVerifier))
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(digest)))
	enc.Encode(out, digest[:])
	return string(out)
}
// generateSessionID returns a fresh session ID: 16 bytes from crypto/rand,
// encoded as padded URL-safe base64. It returns an error when the random
// source cannot be read.
func generateSessionID() (string, error) {
	var raw [16]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	id := base64.URLEncoding.EncodeToString(raw[:])
	return id, nil
}
// cleanupExpiredSessions sweeps the PKCE session store for as long as the
// process lives. Every five minutes it takes the store's write lock and
// removes each session that was created more than ten minutes earlier.
// It never returns, so it is meant to run in its own goroutine.
func cleanupExpiredSessions() {
	const (
		sweepInterval = 5 * time.Minute
		maxSessionAge = 10 * time.Minute
	)

	sweeper := time.NewTicker(sweepInterval)
	defer sweeper.Stop()

	for {
		// Block until the next tick; the ticker channel is never closed.
		<-sweeper.C

		pkceStore.Lock()
		sweptAt := time.Now()
		for key, entry := range pkceStore.sessions {
			if sweptAt.Sub(entry.CreatedAt) <= maxSessionAge {
				continue
			}
			delete(pkceStore.sessions, key)
		}
		pkceStore.Unlock()
	}
}