Spaces:
Paused
Paused
import base64
import hmac
import logging
import os
import time
import uuid
from functools import wraps

import jwt
import requests
from flask import Flask, request, jsonify
from flask_cors import CORS
app = Flask(__name__)

# --- Logging (configured first so configuration problems below can be reported) ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- CORS ---
# ALLOWED_ORIGINS is a comma-separated list. Whitespace around entries and empty
# entries are ignored, so "https://a.example, https://b.example" works as intended.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv('ALLOWED_ORIGINS', 'https://cybercity.top').split(',')
    if origin.strip()
]
CORS(app, origins=_allowed_origins)

# --- Coze OAuth configuration (environment variables) ---
# App/client ID: used for the JWT iss/sub claims and as the expected Basic-auth user.
CLIENT_ID = os.getenv('COZE_CLIENT_ID', '1243934778935')
# Key ID sent in the JWT "kid" header.
KID = os.getenv('COZE_KID', 'tlrohMMZyKMrrpP3GtxF_3_cerDhVIMINs0LOW91m7w')

# RS256 signing key. Fail fast with an explicit message. Otherwise an unset
# variable would surface as an opaque AttributeError from None.replace().
_private_key_env = os.getenv('COZE_PRIVATE_KEY')
if not _private_key_env:
    raise RuntimeError(
        "COZE_PRIVATE_KEY environment variable is not set; it must contain "
        "the RS256 private key used to sign Coze JWTs"
    )
# Single-line env values often encode newlines as a literal "\n"; restore them.
PRIVATE_KEY = _private_key_env.replace('\\n', '\n')

# Secret that callers must present (together with CLIENT_ID) via HTTP Basic auth.
# There is deliberately no placeholder default: a guessable built-in secret would
# let anyone obtain Coze access tokens. An empty value is treated as unset, and
# with None no presented string can compare equal, so auth fails closed.
CLIENT_SECRET = os.getenv('COZE_CLIENT_SECRET') or None
if CLIENT_SECRET is None:
    logger.warning("COZE_CLIENT_SECRET is not set; all token requests will be rejected")

# Process-local cache of the most recently signed JWT.
# 'exp' is the Unix time (seconds) after which the cached token is considered expired.
jwt_cache = {'token': None, 'exp': 0}
def validate_basic_auth(auth_header):
    """Check an HTTP Basic ``Authorization`` header against the configured client credentials.

    The header must have the form ``Basic <base64("client_id:client_secret")>``
    (RFC 7617). Both halves are compared with :func:`hmac.compare_digest`, so
    response timing does not reveal how much of a guess was correct.

    Args:
        auth_header: Raw value of the ``Authorization`` request header, or None.

    Returns:
        True only if the decoded client ID and secret match ``CLIENT_ID`` and
        ``CLIENT_SECRET``. False for a missing or non-Basic header, undecodable
        credentials, a missing ':' separator, or an unset/empty configured ID or secret.
    """
    if not auth_header or not auth_header.startswith('Basic '):
        return False
    # Fail closed when the server has no usable credentials configured.
    # An empty secret would otherwise be matched by "client_id:".
    if not CLIENT_ID or not CLIENT_SECRET:
        return False
    try:
        credentials = base64.b64decode(auth_header[6:]).decode('utf-8')
        client_id, client_secret = credentials.split(':', 1)
    except ValueError as e:
        # binascii.Error (bad base64), UnicodeDecodeError and a missing ':'
        # are all ValueError subclasses. This is bad client input, not a server fault.
        logger.warning("Basic auth validation failed: %s", e)
        return False
    # Evaluate both comparisons (no short-circuit), so timing does not reveal
    # which part of the credentials was wrong.
    id_ok = hmac.compare_digest(client_id.encode('utf-8'), CLIENT_ID.encode('utf-8'))
    secret_ok = hmac.compare_digest(
        client_secret.encode('utf-8'), CLIENT_SECRET.encode('utf-8')
    )
    return id_ok and secret_ok
def generate_jwt():
    """Sign a fresh RS256 JWT that identifies this app to the Coze API.

    Claims:
      * ``iss``, ``sub``, ``connector_id``, ``user_id``: all ``CLIENT_ID``.
      * ``aud``: ``"https://api.coze.cn"``.
      * ``iat``: the current Unix time; ``exp``: one hour later.
      * ``jti``: a new random UUID hex string on every call.

    The JOSE header carries ``alg=RS256``, ``typ=JWT`` and ``kid=KID``.

    Returns:
        The encoded token exactly as produced by ``jwt.encode``.

    Raises:
        jwt.PyJWTError: signing failed; the error is logged and re-raised.
    """
    # NOTE(review): PyJWT 1.x returns bytes from jwt.encode. That would render as
    # "Bearer b'...'" in get_access_token(); confirm PyJWT >= 2 is pinned.
    lifetime_seconds = 3600
    issued_at = int(time.time())

    claims = {
        "iss": CLIENT_ID,
        "sub": CLIENT_ID,  # the sub claim must be present (original author's note)
        # NOTE(review): Coze's JWT docs may expect the bare host "api.coze.cn"
        # rather than a full URL here; confirm against the Coze OAuth reference.
        "aud": "https://api.coze.cn",
        "iat": issued_at,
        "exp": issued_at + lifetime_seconds,
        "jti": uuid.uuid4().hex,
        "connector_id": CLIENT_ID,  # client_id is used everywhere an app identity is needed
        "user_id": CLIENT_ID,
    }
    jose_header = {"alg": "RS256", "typ": "JWT", "kid": KID}

    try:
        signed = jwt.encode(claims, PRIVATE_KEY, algorithm="RS256", headers=jose_header)
    except jwt.PyJWTError as exc:
        logger.error(f"JWT generation failed: {exc}")
        raise
    return signed
def get_access_token(jwt_token):
    """Exchange a signed JWT for a Coze OAuth access token.

    Makes a single POST (no retries) to Coze's OAuth token endpoint. It uses the
    ``urn:ietf:params:oauth:grant-type:jwt-bearer`` grant, sends the JWT as a
    Bearer token, and applies a 10-second timeout.

    Args:
        jwt_token: Signed JWT, e.g. from generate_jwt().

    Returns:
        The decoded JSON object from Coze on success. On any failure (network
        error, non-2xx status, non-JSON or non-object body) it returns
        ``{"error": "coze_api_error", "error_description": <short reason>}``.
        The reason is an HTTP status or exception class name. The upstream
        response body is only written to the server log, never returned.
    """
    url = "https://api.coze.cn/api/permission/oauth2/token"
    data = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        # Presumably the requested token lifetime in seconds (24 h minus 1 s);
        # confirm the allowed maximum in the Coze docs.
        "duration_seconds": 86399,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {jwt_token}",
    }
    try:
        response = requests.post(url, json=data, headers=headers, timeout=10)
        response.raise_for_status()
        body = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers an invalid JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        upstream = getattr(e, 'response', None)
        if upstream is not None:
            # Keep Coze's error body (truncated) in the log for diagnosis.
            logger.error(
                "Access token request failed: %s (body: %.500s)", e, upstream.text
            )
            reason = f"HTTP {upstream.status_code}"
        else:
            logger.error("Access token request failed: %s", e)
            reason = type(e).__name__
        return {"error": "coze_api_error", "error_description": reason}

    if not isinstance(body, dict):
        logger.error(
            "Access token response was a JSON %s, expected an object",
            type(body).__name__,
        )
        return {"error": "coze_api_error", "error_description": "unexpected_response"}
    return body
# Health-check endpoint.
# NOTE(review): no @app.route decorator is visible in this view; confirm this
# handler is registered with the app.
def health_check():
    """Report liveness as JSON, together with the current Unix time in whole seconds."""
    now_seconds = int(time.time())
    payload = {"status": "healthy", "timestamp": now_seconds}
    status_code = 200
    return jsonify(payload), status_code
# Token endpoint: authenticate the caller, then exchange a (cached) JWT for a
# Coze access token.
# NOTE(review): no @app.route decorator is visible in this view; confirm this
# view function is registered with the app.
def get_coze_token():
    """Issue a Coze access token to a caller authenticated with HTTP Basic auth.

    Steps:
      1. Reject with 401 ``invalid_client`` unless the Basic credentials match.
      2. Reuse the cached JWT while more than 5 minutes of its lifetime remain.
         Otherwise sign a new one and cache it.
      3. Exchange the JWT for an access token via get_access_token().

    Returns:
        A Flask response, with one of these outcomes:
        * 200: ``access_token``, ``expires_in`` and ``token_type``.
        * 401: ``invalid_client``.
        * 500: ``jwt_generation_failed``.
        * 502: ``coze_oauth_error``, when Coze fails or its reply lacks the
          expected fields.
    """
    if not validate_basic_auth(request.headers.get('Authorization')):
        return jsonify({"error": "invalid_client"}), 401

    now = time.time()
    # NOTE(review): the same signed JWT (same jti) is sent on every token request
    # until it is refreshed; confirm Coze accepts JWT reuse.
    if jwt_cache['token'] and jwt_cache['exp'] > now + 300:
        signed_jwt = jwt_cache['token']
    else:
        try:
            signed_jwt = generate_jwt()
        except Exception:
            # Request boundary: record the cause, then map it to a 500 response.
            logger.exception("JWT generation failed")
            return jsonify({"error": "jwt_generation_failed"}), 500
        # 3600 s mirrors the exp lifetime set in generate_jwt(); keep them in sync.
        jwt_cache.update({'token': signed_jwt, 'exp': now + 3600})

    token_response = get_access_token(signed_jwt)
    if not isinstance(token_response, dict):
        logger.error(
            "Coze token exchange returned %s, expected a dict",
            type(token_response).__name__,
        )
        return _coze_oauth_error("unexpected_response")
    if 'error' in token_response:
        return _coze_oauth_error(token_response.get('error_description'))

    access_token = token_response.get('access_token')
    expires_in = token_response.get('expires_in')
    if access_token is None or expires_in is None:
        # A malformed upstream reply is an upstream (502) problem. Indexing
        # directly would raise KeyError and produce an opaque 500.
        logger.error(
            "Coze token response is missing access_token or expires_in (keys: %s)",
            sorted(token_response),
        )
        return _coze_oauth_error("incomplete_token_response")

    return jsonify({
        "access_token": access_token,
        "expires_in": expires_in,
        "token_type": "Bearer",
    })


def _coze_oauth_error(details):
    """Build the 502 response used when Coze does not yield a usable token."""
    return jsonify({"error": "coze_oauth_error", "details": details}), 502
# Error handlers that answer with JSON bodies.
# NOTE(review): no @app.errorhandler decorator is visible in this view; confirm
# these handlers are registered with the app.
def not_found(error):
    """Return a JSON 404 body. ``error`` is accepted for the handler signature but unused."""
    body = {"error": "endpoint_not_found"}
    status_code = 404
    return jsonify(body), status_code
def internal_error(error):
    """Return a generic JSON 500 body. ``error`` is accepted but not exposed to the client."""
    body = {"error": "internal_server_error"}
    status_code = 500
    return jsonify(body), status_code
if __name__ == '__main__':
    # Local/dev entry point. The port comes from PORT (default 7860).
    # DEBUG=true (case-insensitive) turns on Flask debug mode. Never enable it
    # on a publicly reachable host, because the interactive debugger can run
    # arbitrary code.
    listen_port = int(os.getenv('PORT', 7860))
    debug_enabled = os.getenv('DEBUG', 'false').lower() == 'true'
    app.run(host='0.0.0.0', port=listen_port, debug=debug_enabled)