| | package handler |
| |
|
import (
	"bytes"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)
| |
|
| | |
// PKCESession is the server-side half of a pending PKCE authorization:
// the secret code verifier plus the moment it was issued.
type PKCESession struct {
	// CodeVerifier is the secret later sent to the token endpoint; only its
	// S256 challenge is sent to the authorize endpoint.
	CodeVerifier string
	// CreatedAt is when the session was created; session expiry is measured
	// from this instant.
	CreatedAt time.Time
}
| |
|
| | |
| | type PKCESessionStore struct { |
| | sync.RWMutex |
| | sessions map[string]*PKCESession |
| | } |
| |
|
| | |
| | var pkceStore = &PKCESessionStore{ |
| | sessions: make(map[string]*PKCESession), |
| | } |
| |
|
| | |
// OAuthHandler serves the HTTP endpoints of the PKCE OAuth flow. It holds no
// state of its own; pending sessions are kept in the package-level pkceStore.
type OAuthHandler struct{}
| |
|
| | |
| | func NewOAuthHandler() *OAuthHandler { |
| | |
| | go cleanupExpiredSessions() |
| | |
| | return &OAuthHandler{} |
| | } |
| |
|
| | |
| | func (h *OAuthHandler) StartOAuthForRT(c *gin.Context) { |
| | |
| | codeVerifier, err := generateCodeVerifier(32) |
| | if err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "error": "生成PKCE参数失败", |
| | }) |
| | return |
| | } |
| | |
| | |
| | codeChallenge := generateCodeChallenge(codeVerifier) |
| | |
| | |
| | sessionID, err := generateSessionID() |
| | if err != nil { |
| | c.JSON(http.StatusInternalServerError, gin.H{ |
| | "error": "生成会话ID失败", |
| | }) |
| | return |
| | } |
| | |
| | |
| | pkceStore.Lock() |
| | pkceStore.sessions[sessionID] = &PKCESession{ |
| | CodeVerifier: codeVerifier, |
| | CreatedAt: time.Now(), |
| | } |
| | pkceStore.Unlock() |
| | |
| | |
| | scheme := "http" |
| | if c.Request.TLS != nil { |
| | scheme = "https" |
| | } |
| | host := c.Request.Host |
| | |
| | callbackURL := fmt.Sprintf("%s://%s/api/oauth/callback-rt?session=%s", |
| | scheme, host, sessionID) |
| | |
| | |
| | state := map[string]string{ |
| | "redirectUri": callbackURL, |
| | "codeChallenge": codeChallenge, |
| | "sessionId": sessionID, |
| | } |
| | stateJSON, _ := json.Marshal(state) |
| | |
| | |
| | params := url.Values{ |
| | "state": {string(stateJSON)}, |
| | "response_type": {"code"}, |
| | "client_id": {"5948a5c5-4b30-4465-a3f2-2136ea53ea0a"}, |
| | "scope": {"openid profile email"}, |
| | "redirect_uri": {"https://auth.zencoder.ai/extension/auth-success"}, |
| | "code_challenge": {codeChallenge}, |
| | "code_challenge_method": {"S256"}, |
| | } |
| | |
| | authURL := fmt.Sprintf("https://fe.zencoder.ai/oauth/authorize?%s", params.Encode()) |
| | |
| | |
| | c.Redirect(http.StatusFound, authURL) |
| | } |
| |
|
| | |
| | func (h *OAuthHandler) CallbackOAuthForRT(c *gin.Context) { |
| | code := c.Query("code") |
| | sessionID := c.Query("session") |
| | |
| | |
| | if code == "" || sessionID == "" { |
| | h.renderCallbackPage(c, false, "", "", "缺少必要参数") |
| | return |
| | } |
| | |
| | |
| | pkceStore.RLock() |
| | session, exists := pkceStore.sessions[sessionID] |
| | pkceStore.RUnlock() |
| | |
| | if !exists { |
| | h.renderCallbackPage(c, false, "", "", "会话已过期,请重新获取") |
| | return |
| | } |
| | |
| | |
| | tokenResp, err := h.exchangeCodeForToken(code, session.CodeVerifier) |
| | if err != nil { |
| | h.renderCallbackPage(c, false, "", "", fmt.Sprintf("获取Token失败: %v", err)) |
| | return |
| | } |
| | |
| | |
| | pkceStore.Lock() |
| | delete(pkceStore.sessions, sessionID) |
| | pkceStore.Unlock() |
| | |
| | |
| | h.renderCallbackPage(c, true, tokenResp.AccessToken, tokenResp.RefreshToken, "") |
| | } |
| |
|
| | |
| | func (h *OAuthHandler) exchangeCodeForToken(code, codeVerifier string) (*OAuthTokenResponse, error) { |
| | tokenURL := "https://auth.zencoder.ai/api/frontegg/oauth/token" |
| | |
| | payload := map[string]string{ |
| | "code": code, |
| | "redirect_uri": "https://auth.zencoder.ai/extension/auth-success", |
| | "code_verifier": codeVerifier, |
| | "grant_type": "authorization_code", |
| | } |
| | |
| | body, _ := json.Marshal(payload) |
| | |
| | req, err := http.NewRequest("POST", tokenURL, bytes.NewReader(body)) |
| | if err != nil { |
| | return nil, err |
| | } |
| | |
| | |
| | req.Header.Set("Content-Type", "application/json") |
| | req.Header.Set("x-frontegg-sdk", "@frontegg/nextjs@9.2.10") |
| | req.Header.Set("x-frontegg-framework", "next@15.3.8") |
| | req.Header.Set("Origin", "https://auth.zencoder.ai") |
| | req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") |
| | |
| | client := &http.Client{Timeout: 30 * time.Second} |
| | resp, err := client.Do(req) |
| | if err != nil { |
| | return nil, err |
| | } |
| | defer resp.Body.Close() |
| | |
| | if resp.StatusCode != http.StatusOK { |
| | return nil, fmt.Errorf("token exchange failed with status %d", resp.StatusCode) |
| | } |
| | |
| | var tokenResp OAuthTokenResponse |
| | if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { |
| | return nil, err |
| | } |
| | |
| | return &tokenResp, nil |
| | } |
| |
|
| | |
| | func (h *OAuthHandler) renderCallbackPage(c *gin.Context, success bool, accessToken, refreshToken, errorMsg string) { |
| | html := ` |
| | <!DOCTYPE html> |
| | <html lang="zh-CN"> |
| | <head> |
| | <meta charset="UTF-8"> |
| | <meta name="viewport" content="width=device-width, initial-scale=1.0"> |
| | <title>OAuth认证</title> |
| | <script src="https://cdn.tailwindcss.com"></script> |
| | </head> |
| | <body class="bg-gray-50 dark:bg-gray-900 min-h-screen flex items-center justify-center"> |
| | <div class="max-w-md w-full mx-4"> |
| | <div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-8"> |
| | ` |
| | |
| | if success { |
| | html += fmt.Sprintf(` |
| | <div class="text-center"> |
| | <div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-green-100"> |
| | <svg class="h-6 w-6 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24"> |
| | <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path> |
| | </svg> |
| | </div> |
| | <h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证成功!</h2> |
| | <p class="mt-2 text-sm text-gray-600 dark:text-gray-400">正在返回并填充Token...</p> |
| | </div> |
| | <script> |
| | // 发送消息给父窗口 |
| | if (window.opener) { |
| | window.opener.postMessage({ |
| | type: 'oauth-rt-complete', |
| | success: true, |
| | accessToken: '%s', |
| | refreshToken: '%s' |
| | }, window.location.origin); |
| | |
| | // 2秒后关闭窗口 |
| | setTimeout(() => { |
| | window.close(); |
| | }, 2000); |
| | } |
| | </script> |
| | `, accessToken, refreshToken) |
| | } else { |
| | html += fmt.Sprintf(` |
| | <div class="text-center"> |
| | <div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-red-100"> |
| | <svg class="h-6 w-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24"> |
| | <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path> |
| | </svg> |
| | </div> |
| | <h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证失败</h2> |
| | <p class="mt-2 text-sm text-gray-600 dark:text-gray-400">%s</p> |
| | <button onclick="window.close()" class="mt-4 px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition-colors"> |
| | 关闭窗口 |
| | </button> |
| | </div> |
| | <script> |
| | // 发送错误消息给父窗口 |
| | if (window.opener) { |
| | window.opener.postMessage({ |
| | type: 'oauth-rt-complete', |
| | success: false, |
| | error: '%s' |
| | }, window.location.origin); |
| | } |
| | </script> |
| | `, errorMsg, errorMsg) |
| | } |
| | |
| | html += ` |
| | </div> |
| | </div> |
| | </body> |
| | </html> |
| | ` |
| | |
| | c.Header("Content-Type", "text/html; charset=utf-8") |
| | c.String(http.StatusOK, html) |
| | } |
| |
|
| | |
// OAuthTokenResponse is the subset of the token endpoint's JSON reply that
// this package decodes; any other fields in the reply are ignored.
type OAuthTokenResponse struct {
	// AccessToken is the "access_token" field.
	AccessToken string `json:"access_token"`
	// RefreshToken is the "refresh_token" field.
	RefreshToken string `json:"refresh_token"`
	// TokenType is the "token_type" field (e.g. "Bearer" — not verified here).
	TokenType string `json:"token_type"`
	// ExpiresIn is the "expires_in" field; RFC 6749 defines it in seconds —
	// confirm for this provider.
	ExpiresIn int `json:"expires_in"`
}
| |
|
| | |
// generateCodeVerifier returns a random PKCE code verifier drawn uniformly from
// the RFC 7636 unreserved character set [A-Za-z0-9-._~].
//
// length is the requested number of characters. It is clamped to the
// 43..128 range that RFC 7636 §4.1 mandates, so a too-short request (or a
// negative one, which previously panicked) still yields a valid verifier.
// Random bytes come from crypto/rand. Bytes that would skew the distribution
// are discarded (rejection sampling), so every character is equally likely.
// The only error returned is one from crypto/rand.
func generateCodeVerifier(length int) (string, error) {
	const (
		charset   = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
		minLength = 43  // RFC 7636 §4.1 lower bound
		maxLength = 128 // RFC 7636 §4.1 upper bound
		// Largest multiple of len(charset) not exceeding 256. A plain
		// byte % len(charset) would favour the first 256%66 characters.
		acceptBelow = 256 - 256%len(charset)
	)

	if length < minLength {
		length = minLength
	} else if length > maxLength {
		length = maxLength
	}

	verifier := make([]byte, 0, length)
	random := make([]byte, length)
	for len(verifier) < length {
		if _, err := rand.Read(random); err != nil {
			return "", err
		}
		for _, b := range random {
			if int(b) >= acceptBelow {
				continue // rejected to keep the distribution uniform
			}
			verifier = append(verifier, charset[int(b)%len(charset)])
			if len(verifier) == length {
				break
			}
		}
	}

	return string(verifier), nil
}
| |
|
| | |
// generateCodeChallenge derives the PKCE "S256" code challenge for a verifier
// (RFC 7636 §4.2): unpadded base64url of the SHA-256 digest of the verifier.
func generateCodeChallenge(codeVerifier string) string {
	digest := sha256.New()
	digest.Write([]byte(codeVerifier)) // hash.Hash.Write never returns an error
	sum := digest.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(sum)
}
| |
|
| | |
// generateSessionID returns a random session identifier: 16 bytes from
// crypto/rand encoded with padded URL-safe base64, so always 24 characters.
func generateSessionID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	id := base64.URLEncoding.EncodeToString(raw[:])
	return id, nil
}
| |
|
| | |
| | func cleanupExpiredSessions() { |
| | ticker := time.NewTicker(5 * time.Minute) |
| | defer ticker.Stop() |
| | |
| | for range ticker.C { |
| | pkceStore.Lock() |
| | now := time.Now() |
| | for id, session := range pkceStore.sessions { |
| | if now.Sub(session.CreatedAt) > 10*time.Minute { |
| | delete(pkceStore.sessions, id) |
| | } |
| | } |
| | pkceStore.Unlock() |
| | } |
| | } |