File size: 9,631 Bytes
1de7911
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
package handler

import (
	"bytes"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)

// PKCESession holds the PKCE state kept for one in-flight OAuth authorization.
type PKCESession struct {
	// CodeVerifier is the PKCE code_verifier generated when the flow was
	// started; it is sent to the token endpoint when the code is exchanged.
	CodeVerifier string
	// CreatedAt is the time the session was stored; the background cleanup
	// uses it to drop stale sessions.
	CreatedAt    time.Time
}

// PKCESessionStore is an in-memory store of PKCE sessions keyed by session ID.
// The embedded RWMutex must be held while reading or writing sessions.
type PKCESessionStore struct {
	sync.RWMutex
	// sessions maps a session ID to its PKCE state.
	sessions map[string]*PKCESession
}

// pkceStore is the package-wide PKCE session store shared by the OAuth
// handlers and the background cleanup loop.
var pkceStore = &PKCESessionStore{
	sessions: make(map[string]*PKCESession),
}

// OAuthHandler groups the HTTP handlers for the OAuth flow. It carries no
// state of its own; session state lives in the package-level pkceStore.
type OAuthHandler struct {
}

// pkceCleanupOnce guards the start of the background session sweeper so that
// constructing several handlers does not spawn several never-ending goroutines.
var pkceCleanupOnce sync.Once

// NewOAuthHandler creates an OAuthHandler and makes sure the background loop
// that removes expired PKCE sessions is running. The loop is started at most
// once per process, no matter how many handlers are created.
func NewOAuthHandler() *OAuthHandler {
	pkceCleanupOnce.Do(func() {
		go cleanupExpiredSessions()
	})

	return &OAuthHandler{}
}

// StartOAuthForRT starts the OAuth authorization-code flow with PKCE that is
// used to obtain a refresh token.
//
// It generates a code verifier/challenge pair and a session ID, remembers the
// verifier under that session ID, and redirects the client (302) to the
// provider's authorize endpoint. The callback URL embedded in the state
// parameter is derived from the incoming request's TLS state and Host header.
// On any internal failure it responds with 500 and a JSON error body, and no
// session is stored.
func (h *OAuthHandler) StartOAuthForRT(c *gin.Context) {
	// Generate the PKCE parameters.
	// NOTE(review): RFC 7636 section 4.1 specifies 43-128 characters for
	// code_verifier; 32 is kept unchanged here - confirm the provider's
	// requirements before changing it.
	codeVerifier, err := generateCodeVerifier(32)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "生成PKCE参数失败",
		})
		return
	}
	codeChallenge := generateCodeChallenge(codeVerifier)

	sessionID, err := generateSessionID()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "生成会话ID失败",
		})
		return
	}

	// Build the URL the provider should send the user back to.
	scheme := "http"
	if c.Request.TLS != nil {
		scheme = "https"
	}
	callbackURL := fmt.Sprintf("%s://%s/api/oauth/callback-rt?session=%s",
		scheme, c.Request.Host, sessionID)

	// Build the state parameter.
	stateJSON, err := json.Marshal(map[string]string{
		"redirectUri":   callbackURL,
		"codeChallenge": codeChallenge,
		"sessionId":     sessionID,
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "构建state参数失败",
		})
		return
	}

	// Store the session only after every fallible step has succeeded, so a
	// failed request never leaves an orphaned entry behind.
	pkceStore.Lock()
	pkceStore.sessions[sessionID] = &PKCESession{
		CodeVerifier: codeVerifier,
		CreatedAt:    time.Now(),
	}
	pkceStore.Unlock()

	// Build the authorization URL.
	params := url.Values{
		"state":                 {string(stateJSON)},
		"response_type":         {"code"},
		"client_id":             {"5948a5c5-4b30-4465-a3f2-2136ea53ea0a"},
		"scope":                 {"openid profile email"},
		"redirect_uri":          {"https://auth.zencoder.ai/extension/auth-success"},
		"code_challenge":        {codeChallenge},
		"code_challenge_method": {"S256"},
	}
	authURL := "https://fe.zencoder.ai/oauth/authorize?" + params.Encode()

	// Send the user to the authorization page.
	c.Redirect(http.StatusFound, authURL)
}

// CallbackOAuthForRT handles the OAuth callback.
//
// It reads the "code" and "session" query parameters, looks up the PKCE
// session, exchanges the authorization code for tokens and renders the
// callback page with either the tokens or an error message. A session older
// than the TTL is rejected (and removed) even if the periodic cleanup has not
// reached it yet. The session is deleted after a successful exchange.
func (h *OAuthHandler) CallbackOAuthForRT(c *gin.Context) {
	// Same lifetime that cleanupExpiredSessions applies when sweeping.
	const sessionTTL = 10 * time.Minute

	code := c.Query("code")
	sessionID := c.Query("session")

	// Validate the parameters.
	if code == "" || sessionID == "" {
		h.renderCallbackPage(c, false, "", "", "缺少必要参数")
		return
	}

	// Look up the session.
	pkceStore.RLock()
	session, exists := pkceStore.sessions[sessionID]
	pkceStore.RUnlock()

	// The sweeper only runs periodically, so enforce the TTL here as well;
	// otherwise a stale session would stay usable until the next sweep.
	if exists && time.Since(session.CreatedAt) > sessionTTL {
		pkceStore.Lock()
		delete(pkceStore.sessions, sessionID)
		pkceStore.Unlock()
		exists = false
	}

	if !exists {
		h.renderCallbackPage(c, false, "", "", "会话已过期,请重新获取")
		return
	}

	// Exchange the authorization code for tokens.
	tokenResp, err := h.exchangeCodeForToken(code, session.CodeVerifier)
	if err != nil {
		h.renderCallbackPage(c, false, "", "", fmt.Sprintf("获取Token失败: %v", err))
		return
	}

	// The session has served its purpose; remove it.
	pkceStore.Lock()
	delete(pkceStore.sessions, sessionID)
	pkceStore.Unlock()

	// Render the success page, handing over the access and refresh tokens.
	h.renderCallbackPage(c, true, tokenResp.AccessToken, tokenResp.RefreshToken, "")
}

// oauthTokenClient is shared by all token exchanges so that connections can be
// reused instead of building a fresh client for every request.
var oauthTokenClient = &http.Client{Timeout: 30 * time.Second}

// exchangeCodeForToken exchanges an authorization code for tokens.
//
// It POSTs a JSON body containing the code, the fixed redirect URI, the PKCE
// code verifier and grant_type=authorization_code to the token endpoint and
// decodes the JSON response. It returns an error if the request cannot be
// built or sent, if the endpoint answers with a status other than 200, if the
// response cannot be decoded, or if the response carries neither an access
// token nor a refresh token.
func (h *OAuthHandler) exchangeCodeForToken(code, codeVerifier string) (*OAuthTokenResponse, error) {
	const tokenURL = "https://auth.zencoder.ai/api/frontegg/oauth/token"

	payload := map[string]string{
		"code":          code,
		"redirect_uri":  "https://auth.zencoder.ai/extension/auth-success",
		"code_verifier": codeVerifier,
		"grant_type":    "authorization_code",
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("encoding token request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, tokenURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building token request: %w", err)
	}

	// Request headers sent to the token endpoint.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-frontegg-sdk", "@frontegg/nextjs@9.2.10")
	req.Header.Set("x-frontegg-framework", "next@15.3.8")
	req.Header.Set("Origin", "https://auth.zencoder.ai")
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")

	resp, err := oauthTokenClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("sending token request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed with status %d", resp.StatusCode)
	}

	var tokenResp OAuthTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("decoding token response: %w", err)
	}

	// A 200 response without any token is useless to the caller; report it
	// instead of letting an empty result be presented as success.
	if tokenResp.AccessToken == "" && tokenResp.RefreshToken == "" {
		return nil, fmt.Errorf("token response contained no tokens")
	}

	return &tokenResp, nil
}

// escapeHTMLText escapes the characters that are significant in HTML text and
// attribute values (&, <, >, " and ') so that s can be placed in markup as
// plain text. It works byte-wise, so all other bytes pass through unchanged.
func escapeHTMLText(s string) string {
	var buf bytes.Buffer
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '&':
			buf.WriteString("&amp;")
		case '<':
			buf.WriteString("&lt;")
		case '>':
			buf.WriteString("&gt;")
		case '"':
			buf.WriteString("&#34;")
		case '\'':
			buf.WriteString("&#39;")
		default:
			buf.WriteByte(s[i])
		}
	}
	return buf.String()
}

// jsStringLiteral returns s as a double-quoted JavaScript string literal that
// is safe to embed inside a <script> element. json.Marshal escapes quotes,
// backslashes and control characters, and by default also <, >, &, U+2028 and
// U+2029, so the value cannot terminate the string or the script block.
func jsStringLiteral(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		// Marshalling a plain string is not expected to fail; fall back to
		// an empty literal rather than emitting broken JavaScript.
		return `""`
	}
	return string(encoded)
}

// renderCallbackPage writes the HTML page shown at the end of the OAuth flow.
//
// When success is true the page posts an 'oauth-rt-complete' message carrying
// accessToken and refreshToken to window.opener and closes itself after two
// seconds. Otherwise it shows errorMsg and posts a failure message carrying
// the same text. Every dynamic value is escaped for the context it is placed
// in (HTML text or JavaScript string), because the error text and the tokens
// come from outside this program. The response status is always 200.
func (h *OAuthHandler) renderCallbackPage(c *gin.Context, success bool, accessToken, refreshToken, errorMsg string) {
	page := `
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>OAuth认证</title>
    <script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-50 dark:bg-gray-900 min-h-screen flex items-center justify-center">
    <div class="max-w-md w-full mx-4">
        <div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-8">
`

	if success {
		// The %s verbs receive complete, already-quoted JS string literals.
		page += fmt.Sprintf(`
	           <div class="text-center">
	               <div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-green-100">
	                   <svg class="h-6 w-6 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
	                       <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7"></path>
	                   </svg>
	               </div>
	               <h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证成功!</h2>
	               <p class="mt-2 text-sm text-gray-600 dark:text-gray-400">正在返回并填充Token...</p>
	           </div>
	           <script>
	               // 发送消息给父窗口
	               if (window.opener) {
	                   window.opener.postMessage({
	                       type: 'oauth-rt-complete',
	                       success: true,
	                       accessToken: %s,
	                       refreshToken: %s
	                   }, window.location.origin);
	                   
	                   // 2秒后关闭窗口
	                   setTimeout(() => {
	                       window.close();
	                   }, 2000);
	               }
	           </script>
`, jsStringLiteral(accessToken), jsStringLiteral(refreshToken))
	} else {
		// First %s is HTML text, second %s is a quoted JS string literal.
		page += fmt.Sprintf(`
            <div class="text-center">
                <div class="mx-auto flex items-center justify-center h-12 w-12 rounded-full bg-red-100">
                    <svg class="h-6 w-6 text-red-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                        <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path>
                    </svg>
                </div>
                <h2 class="mt-4 text-xl font-semibold text-gray-900 dark:text-white">认证失败</h2>
                <p class="mt-2 text-sm text-gray-600 dark:text-gray-400">%s</p>
                <button onclick="window.close()" class="mt-4 px-4 py-2 bg-gray-600 text-white rounded-lg hover:bg-gray-700 transition-colors">
                    关闭窗口
                </button>
            </div>
            <script>
                // 发送错误消息给父窗口
                if (window.opener) {
                    window.opener.postMessage({
                        type: 'oauth-rt-complete',
                        success: false,
                        error: %s
                    }, window.location.origin);
                }
            </script>
`, escapeHTMLText(errorMsg), jsStringLiteral(errorMsg))
	}

	page += `
        </div>
    </div>
</body>
</html>
`

	// Write the page as raw bytes so it is never interpreted as a format
	// string, whatever characters the escaped values contain.
	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(page))
}

// OAuthTokenResponse is the JSON body decoded from the token endpoint's
// response when an authorization code is exchanged.
type OAuthTokenResponse struct {
	// AccessToken is read from the "access_token" field.
	AccessToken  string `json:"access_token"`
	// RefreshToken is read from the "refresh_token" field.
	RefreshToken string `json:"refresh_token"`
	// TokenType is read from the "token_type" field.
	TokenType    string `json:"token_type"`
	// ExpiresIn is read from the "expires_in" field
	// (presumably a lifetime in seconds, per OAuth 2.0 convention - TODO confirm).
	ExpiresIn    int    `json:"expires_in"`
}

// generateCodeVerifier returns a random PKCE code_verifier of exactly length
// characters, drawn uniformly from the unreserved URL characters
// [A-Za-z0-9-._~] using crypto/rand.
//
// A length of 0 yields the empty string. A negative length, or a failure of
// the random source, yields an error.
func generateCodeVerifier(length int) (string, error) {
	const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
	// Random bytes at or above this bound are discarded. 256 is not a
	// multiple of len(chars), so a plain modulo would make the first few
	// characters of the alphabet more likely than the rest.
	const limit = 256 - 256%len(chars)

	if length < 0 {
		return "", fmt.Errorf("invalid code verifier length %d", length)
	}

	result := make([]byte, 0, length)
	randomBytes := make([]byte, length)

	// Rejection sampling: keep drawing batches until enough bytes have
	// fallen inside the unbiased range.
	for len(result) < length {
		if _, err := rand.Read(randomBytes); err != nil {
			return "", err
		}
		for _, b := range randomBytes {
			if int(b) >= limit {
				continue
			}
			result = append(result, chars[int(b)%len(chars)])
			if len(result) == length {
				break
			}
		}
	}

	return string(result), nil
}

// generateCodeChallenge derives the PKCE code_challenge for codeVerifier: the
// SHA-256 digest of the verifier, encoded as unpadded URL-safe base64.
func generateCodeChallenge(codeVerifier string) string {
	hasher := sha256.New()
	// Writing to a hash.Hash never returns an error.
	hasher.Write([]byte(codeVerifier))
	digest := hasher.Sum(nil)

	return base64.RawURLEncoding.EncodeToString(digest)
}

// generateSessionID returns a fresh session ID: 16 bytes from crypto/rand
// encoded as padded URL-safe base64. It returns an error if the random source
// fails.
func generateSessionID() (string, error) {
	var raw [16]byte

	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}

	id := base64.URLEncoding.EncodeToString(raw[:])
	return id, nil
}

// cleanupExpiredSessions removes stale PKCE sessions from pkceStore. Every
// five minutes it takes the store's write lock and deletes each session that
// was created more than ten minutes earlier. It never returns, so it is meant
// to be run in its own goroutine.
func cleanupExpiredSessions() {
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()

	// sweep performs one pass over the store while holding the write lock.
	sweep := func() {
		pkceStore.Lock()
		defer pkceStore.Unlock()

		now := time.Now()
		for key, entry := range pkceStore.sessions {
			age := now.Sub(entry.CreatedAt)
			if age <= 10*time.Minute {
				continue
			}
			delete(pkceStore.sessions, key)
		}
	}

	for {
		<-tick.C
		sweep()
	}
}