|
|
""" |
|
|
认证API模块 |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
import secrets |
|
|
import socket |
|
|
import threading |
|
|
import time |
|
|
import uuid |
|
|
from datetime import timezone |
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer |
|
|
from typing import Any, Dict, List, Optional |
|
|
from urllib.parse import parse_qs, urlparse |
|
|
|
|
|
from config import get_config_value, get_antigravity_api_url, get_code_assist_endpoint |
|
|
from log import log |
|
|
|
|
|
from .google_oauth_api import ( |
|
|
Credentials, |
|
|
Flow, |
|
|
enable_required_apis, |
|
|
fetch_project_id, |
|
|
get_user_projects, |
|
|
select_default_project, |
|
|
) |
|
|
from .storage_adapter import get_storage_adapter |
|
|
from .utils import ( |
|
|
ANTIGRAVITY_CLIENT_ID, |
|
|
ANTIGRAVITY_CLIENT_SECRET, |
|
|
ANTIGRAVITY_SCOPES, |
|
|
ANTIGRAVITY_USER_AGENT, |
|
|
CALLBACK_HOST, |
|
|
CLIENT_ID, |
|
|
CLIENT_SECRET, |
|
|
SCOPES, |
|
|
GEMINICLI_USER_AGENT, |
|
|
TOKEN_URL, |
|
|
) |
|
|
|
|
|
|
|
|
async def get_callback_port():
    """Return the configured OAuth callback port as an ``int``.

    The value is looked up through ``get_config_value`` with the arguments
    ``"oauth_callback_port"``, ``"11451"`` and ``"OAUTH_CALLBACK_PORT"``
    (presumably config key, default and environment variable name — confirm
    against ``config.get_config_value``).
    """
    configured_port = await get_config_value(
        "oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT"
    )
    return int(configured_port)
|
|
|
|
|
|
|
|
def _prepare_credentials_data(credentials: Credentials, project_id: str, mode: str = "geminicli") -> Dict[str, Any]:
    """Build the credential dictionary that gets persisted for a given mode.

    Args:
        credentials: Object exposing ``access_token``, ``refresh_token`` and
            ``expires_at``.
        project_id: Project identifier stored next to the tokens.
        mode: ``"antigravity"`` selects the ANTIGRAVITY_* client constants;
            any other value selects the default client constants.

    Returns:
        A dict with client id/secret, tokens, scopes, token URI and project
        id. An ISO-8601 ``expiry`` entry is added when
        ``credentials.expires_at`` is truthy.
    """
    if mode == "antigravity":
        client_id, client_secret, scopes = (
            ANTIGRAVITY_CLIENT_ID,
            ANTIGRAVITY_CLIENT_SECRET,
            ANTIGRAVITY_SCOPES,
        )
    else:
        client_id, client_secret, scopes = CLIENT_ID, CLIENT_SECRET, SCOPES

    payload = {
        "client_id": client_id,
        "client_secret": client_secret,
        "token": credentials.access_token,
        "refresh_token": credentials.refresh_token,
        "scopes": scopes,
        "token_uri": TOKEN_URL,
        "project_id": project_id,
    }

    expires_at = credentials.expires_at
    if expires_at:
        # A naive datetime is stamped as UTC before serialising.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        payload["expiry"] = expires_at.isoformat()

    return payload
|
|
|
|
|
|
|
|
def _generate_random_project_id() -> str:
    """Return a random placeholder project id (used in antigravity mode).

    The result has the shape ``projects/random-<8 hex chars>/locations/global``
    where the hex part is the first eight characters of a fresh UUID4.
    """
    suffix = uuid.uuid4().hex[:8]
    return "".join(("projects/random-", suffix, "/locations/global"))
|
|
|
|
|
|
|
|
def _cleanup_auth_flow_server(state: str):
    """Stop the callback server of a flow and drop the flow from ``auth_flows``.

    Does nothing when ``state`` is not registered. Errors raised while
    scheduling the server shutdown are logged at debug level and do not
    prevent the flow entry from being removed.
    """
    if state not in auth_flows:
        return

    entry = auth_flows[state]
    try:
        server = entry.get("server")
        if server:
            async_shutdown_server(server, entry.get("callback_port"))
    except Exception as e:
        log.debug(f"关闭服务器时出错: {e}")

    del auth_flows[state]
|
|
|
|
|
|
|
|
class _OAuthLibPatcher:
    """Context manager that temporarily wraps oauthlib's token validation.

    While the context is active,
    ``oauthlib.oauth2.rfc6749.parameters.validate_token_parameters`` is
    replaced by a wrapper that swallows ``Warning`` exceptions raised by the
    original function. The original is put back on exit.
    """

    def __init__(self):
        # Imported here rather than at module top.
        from oauthlib.oauth2.rfc6749 import parameters as oauth_parameters

        self.module = oauth_parameters
        self.original_validate = None

    def __enter__(self):
        """Install the lenient wrapper and return this patcher."""
        self.original_validate = self.module.validate_token_parameters

        def lenient_validate(params):
            # A Warning raised by the original is suppressed; the wrapper
            # then returns None.
            try:
                return self.original_validate(params)
            except Warning:
                return None

        self.module.validate_token_parameters = lenient_validate
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the original validator if one was captured."""
        if not self.original_validate:
            return
        self.module.validate_token_parameters = self.original_validate
|
|
|
|
|
|
|
|
|
|
|
# In-flight OAuth flows, keyed by the OAuth ``state`` value.
auth_flows: Dict[str, Dict[str, Any]] = {}

# Maximum number of flows tracked at once; create_auth_url evicts the oldest
# flow once this many are registered.
MAX_AUTH_FLOWS: int = 20
|
|
|
|
|
|
|
|
def cleanup_auth_flows_for_memory():
    """Shrink ``auth_flows`` to free memory and return the remaining count.

    Expired flows are removed first via ``cleanup_expired_flows``. If more
    than ten flows remain, only the ten newest (by ``created_at``) are kept;
    the callback server of every dropped flow is shut down asynchronously.

    The registry is pruned in place instead of being rebound to a new dict,
    so references to ``auth_flows`` held elsewhere keep pointing at the live
    registry.

    Returns:
        The number of flows left in ``auth_flows``.
    """
    cleanup_expired_flows()

    keep_limit = 10
    if len(auth_flows) > keep_limit:
        newest_first = sorted(
            auth_flows.items(), key=lambda item: item[1].get("created_at", 0), reverse=True
        )

        for state, flow_data in newest_first[keep_limit:]:
            try:
                server = flow_data.get("server")
                if server:
                    async_shutdown_server(server, flow_data.get("callback_port"))
            except Exception as e:
                # Best effort: a failed shutdown must not stop the pruning,
                # but it should leave a trace.
                log.debug(f"关闭服务器时出错: {e}")
            flow_data.clear()
            auth_flows.pop(state, None)

        log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程")

    return len(auth_flows)
|
|
|
|
|
|
|
|
async def find_available_port(start_port: Optional[int] = None) -> int:
    """Find a TCP port that can currently be bound on all interfaces.

    Up to 100 consecutive ports starting at ``start_port`` are probed; if none
    can be bound, the operating system is asked for an ephemeral port.

    Args:
        start_port: First port to try. When ``None`` the configured callback
            port from ``get_callback_port`` is used.

    Returns:
        A port number that was bindable at probe time. The probe socket is
        closed before returning, so another process may still grab the port.

    Raises:
        RuntimeError: If not even a system-assigned port could be bound.
    """
    if start_port is None:
        start_port = await get_callback_port()

    # bind() raises OverflowError (not OSError) for ports outside 0-65535, and
    # port 0 means "any port", so keep the scan window inside 1..65535.
    first_port = max(start_port, 1)
    last_port_exclusive = min(start_port + 100, 65536)

    for port in range(first_port, last_port_exclusive):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(("0.0.0.0", port))
                log.info(f"找到可用端口: {port}")
                return port
        except OSError:
            continue

    # Nothing free in the preferred window: let the OS pick a port.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("0.0.0.0", 0))
            port = s.getsockname()[1]
            log.info(f"系统分配可用端口: {port}")
            return port
    except OSError as e:
        log.error(f"无法找到可用端口: {e}")
        raise RuntimeError("无法找到可用端口") from e
|
|
|
|
|
|
|
|
def create_callback_server(port: int) -> HTTPServer:
    """Create an ``HTTPServer`` for OAuth callbacks listening on ``port``.

    The server binds to all interfaces, uses ``AuthCallbackHandler``, has
    ``SO_REUSEADDR`` set on its socket and a 1 second ``timeout``.

    Raises:
        OSError: Re-raised (after logging) when the server cannot be set up.
    """
    bind_address = ("0.0.0.0", port)
    try:
        httpd = HTTPServer(bind_address, AuthCallbackHandler)
        httpd.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        httpd.timeout = 1.0
        log.info(f"创建OAuth回调服务器,监听端口: {port}")
    except OSError as e:
        log.error(f"创建端口{port}的服务器失败: {e}")
        raise
    return httpd
|
|
|
|
|
|
|
|
class AuthCallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler that receives the OAuth redirect and records the code."""

    def _send_html(self, status, body):
        """Send ``body`` (bytes) as a ``text/html`` response with ``status``."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Store the authorization code on the matching flow, or answer 400.

        A request counts as successful only when it carries both ``code`` and
        ``state`` and the state is a key of ``auth_flows``.
        """
        params = parse_qs(urlparse(self.path).query)
        code = params.get("code", [None])[0]
        state = params.get("state", [None])[0]

        log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}")

        if not (code and state and state in auth_flows):
            self._send_html(400, b"<h1>Authentication failed.</h1><p>Please try again.</p>")
            return

        flow_entry = auth_flows[state]
        flow_entry["code"] = code
        flow_entry["completed"] = True

        log.info(f"OAuth回调成功处理: state={state}")

        self._send_html(
            200,
            b"<h1>OAuth authentication successful!</h1><p>You can close this window. Please return to the original page and click 'Get Credentials' button.</p>",
        )

    def log_message(self, format, *args):
        """Suppress the default per-request logging of the base handler."""
        return None
|
|
|
|
|
|
|
|
async def create_auth_url(
    project_id: Optional[str] = None, user_session: Optional[str] = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Start a new OAuth flow and return the URL the user must visit.

    A free port is located, a dedicated callback server is started on it in a
    daemon thread, and the flow is registered in ``auth_flows`` under a fresh
    ``state`` value. When ``MAX_AUTH_FLOWS`` flows already exist, the oldest
    one is evicted first.

    Args:
        project_id: Project to bind the flow to; ``None`` marks the flow for
            automatic project detection.
        user_session: Optional session identifier; when given it is used as a
            prefix of the generated state.
        mode: ``"antigravity"`` selects the ANTIGRAVITY_* client constants,
            anything else the default client constants.

    Returns:
        On success a dict with ``auth_url``, ``state``, ``callback_port``,
        ``success=True``, ``auto_project_detection`` and
        ``detected_project_id``. On failure ``{"success": False, "error": ...}``.
    """
    callback_server = None
    callback_port = None
    flow_registered = False

    try:
        callback_port = await find_available_port()
        callback_url = f"http://{CALLBACK_HOST}:{callback_port}"

        server_thread = None
        try:
            callback_server = create_callback_server(callback_port)
            server_thread = threading.Thread(
                target=callback_server.serve_forever,
                daemon=True,
                name=f"OAuth-Server-{callback_port}",
            )
            server_thread.start()
            log.info(f"OAuth回调服务器已启动,端口: {callback_port}")
        except Exception as e:
            log.error(f"启动回调服务器失败: {e}")
            if callback_server is not None:
                # Release the listening socket so the port is not leaked.
                try:
                    if server_thread is not None and server_thread.is_alive():
                        async_shutdown_server(callback_server, callback_port)
                    else:
                        # serve_forever never ran, so shutdown() would block;
                        # closing the socket is enough.
                        callback_server.server_close()
                except Exception as close_error:
                    log.debug(f"关闭服务器时出错: {close_error}")
                callback_server = None
            return {
                "success": False,
                "error": f"无法启动OAuth回调服务器,端口{callback_port}: {str(e)}",
            }

        if mode == "antigravity":
            client_id = ANTIGRAVITY_CLIENT_ID
            client_secret = ANTIGRAVITY_CLIENT_SECRET
            scopes = ANTIGRAVITY_SCOPES
        else:
            client_id = CLIENT_ID
            client_secret = CLIENT_SECRET
            scopes = SCOPES

        flow = Flow(
            client_id=client_id,
            client_secret=client_secret,
            scopes=scopes,
            redirect_uri=callback_url,
        )

        if user_session:
            state = f"{user_session}_{str(uuid.uuid4())}"
        else:
            state = str(uuid.uuid4())

        auth_url = flow.get_auth_url(state=state)

        # Enforce the registry cap by evicting the oldest flow.
        if len(auth_flows) >= MAX_AUTH_FLOWS:
            oldest_state = min(auth_flows.keys(), key=lambda k: auth_flows[k].get("created_at", 0))
            try:
                old_flow = auth_flows[oldest_state]
                if old_flow.get("server"):
                    server = old_flow["server"]
                    port = old_flow.get("callback_port")
                    async_shutdown_server(server, port)
            except Exception as e:
                log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}")

            del auth_flows[oldest_state]
            log.debug(f"Removed oldest auth flow: {oldest_state}")

        auth_flows[state] = {
            "flow": flow,
            "project_id": project_id,
            "user_session": user_session,
            "callback_port": callback_port,
            "callback_url": callback_url,
            "server": callback_server,
            "server_thread": server_thread,
            "code": None,
            "completed": False,
            "created_at": time.time(),
            "auto_project_detection": project_id is None,
            "mode": mode,
        }
        # From here on the registry owns the server and will shut it down.
        flow_registered = True

        cleanup_expired_flows()

        log.info(f"OAuth流程已创建: state={state}, project_id={project_id}")
        log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}")
        log.info(f"为此认证流程分配的端口: {callback_port}")

        return {
            "auth_url": auth_url,
            "state": state,
            "callback_port": callback_port,
            "success": True,
            "auto_project_detection": project_id is None,
            "detected_project_id": project_id,
        }

    except Exception as e:
        log.error(f"创建认证URL失败: {e}")
        # A server that was started but never registered would otherwise keep
        # its thread and port alive forever.
        if callback_server is not None and not flow_registered:
            try:
                async_shutdown_server(callback_server, callback_port)
            except Exception as shutdown_error:
                log.debug(f"关闭服务器时出错: {shutdown_error}")
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
|
|
def wait_for_callback_sync(state: str, timeout: int = 300) -> Optional[str]:
    """Block until the flow identified by ``state`` has received its code.

    The flow entry is polled every 0.5 seconds; after each sleep the entry is
    re-read from ``auth_flows`` when the state is still registered.

    Args:
        state: Key of the flow in ``auth_flows``.
        timeout: Maximum time to wait, in seconds.

    Returns:
        The authorization code, or ``None`` when the state is unknown or the
        timeout elapsed.
    """
    if state not in auth_flows:
        log.error(f"未找到状态为 {state} 的认证流程")
        return None

    flow_data = auth_flows[state]
    callback_port = flow_data["callback_port"]

    log.info(f"等待OAuth回调完成,端口: {callback_port}")

    poll_interval = 0.5
    started_at = time.time()
    while True:
        if time.time() - started_at >= timeout:
            break

        if flow_data.get("code"):
            log.info("OAuth回调成功完成")
            return flow_data["code"]

        time.sleep(poll_interval)

        # Pick up a replaced entry, if any, before the next check.
        if state in auth_flows:
            flow_data = auth_flows[state]

    log.warning(f"等待OAuth回调超时 ({timeout}秒)")
    return None
|
|
|
|
|
|
|
|
async def complete_auth_flow(
    project_id: Optional[str] = None, user_session: Optional[str] = None
) -> Dict[str, Any]:
    """Finish an OAuth flow, resolve the project id and store the credentials.

    Flow lookup: a flow registered for ``project_id`` is preferred (a flow of
    the same ``user_session`` wins); otherwise a flow created for automatic
    project detection is used. If the flow has no authorization code yet, the
    function waits for the browser callback.

    For auto-detection flows without an explicit ``project_id`` the project is
    resolved from the user's project list after the token exchange, so the
    up-front "missing project id" check only applies to flows that are not
    marked for auto-detection.

    Args:
        project_id: Project to store the credentials for; ``None`` to use the
            flow's project or auto-detection.
        user_session: Optional session identifier used to pick the flow.

    Returns:
        ``{"success": True, "credentials", "file_path",
        "auto_detected_project"}`` on success, otherwise a dict with
        ``success=False`` and ``error``, optionally with
        ``requires_manual_project_id`` or ``requires_project_selection`` /
        ``available_projects``.
    """
    try:
        state = None
        flow_data = None

        # Pass 1: flows created for exactly this project id.
        if project_id:
            for s, data in auth_flows.items():
                if data["project_id"] == project_id:
                    if user_session and data.get("user_session") == user_session:
                        state = s
                        flow_data = data
                        break
                    elif not state:
                        state = s
                        flow_data = data

        # Pass 2: flows waiting for automatic project detection.
        if not state:
            for s, data in auth_flows.items():
                if data.get("auto_project_detection", False):
                    if user_session and data.get("user_session") == user_session:
                        state = s
                        flow_data = data
                        break
                    elif not state:
                        state = s
                        flow_data = data

        if not state or not flow_data:
            return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"}

        # Only demand a project id up front when the flow is not going to
        # detect one after the token exchange.
        if not project_id and not flow_data.get("auto_project_detection", False):
            project_id = flow_data.get("project_id")
            if not project_id:
                return {
                    "success": False,
                    "error": "缺少项目ID,请指定项目ID",
                    "requires_manual_project_id": True,
                }

        flow = flow_data["flow"]

        if not flow_data.get("code"):
            log.info(f"等待用户完成OAuth授权 (state: {state})")
            # wait_for_callback_sync sleeps in a loop; run it in a worker
            # thread so the event loop stays responsive meanwhile.
            loop = asyncio.get_running_loop()
            auth_code = await loop.run_in_executor(None, wait_for_callback_sync, state)

            if not auth_code:
                return {
                    "success": False,
                    "error": "未接收到授权回调,请确保完成了浏览器中的OAuth认证",
                }

            # Written through flow_data so a flow evicted from the registry
            # during the wait cannot raise KeyError here.
            flow_data["code"] = auth_code
            flow_data["completed"] = True
        else:
            auth_code = flow_data["code"]

        with _OAuthLibPatcher():
            try:
                credentials = await flow.exchange_code(auth_code)

                if flow_data.get("auto_project_detection", False) and not project_id:
                    log.info("尝试通过API获取用户项目列表...")
                    # NOTE(review): this logs a prefix of the access token.
                    log.info(f"使用的token: {credentials.access_token[:20]}...")
                    log.info(f"Token过期时间: {credentials.expires_at}")
                    user_projects = await get_user_projects(credentials)

                    if user_projects:
                        if len(user_projects) == 1:
                            project_id = user_projects[0].get("projectId")
                            if project_id:
                                flow_data["project_id"] = project_id
                                log.info(f"自动选择唯一项目: {project_id}")
                        else:
                            project_id = await select_default_project(user_projects)
                            if project_id:
                                flow_data["project_id"] = project_id
                                log.info(f"自动选择默认项目: {project_id}")
                            else:
                                # No default could be chosen: let the caller pick.
                                return {
                                    "success": False,
                                    "error": "请从以下项目中选择一个",
                                    "requires_project_selection": True,
                                    "available_projects": [
                                        {
                                            "project_id": p.get("projectId"),
                                            "name": p.get("displayName") or p.get("projectId"),
                                            "projectNumber": p.get("projectNumber"),
                                        }
                                        for p in user_projects
                                    ],
                                }
                    else:
                        return {
                            "success": False,
                            "error": "无法获取您的项目列表,请手动指定项目ID",
                            "requires_manual_project_id": True,
                        }

                if not project_id:
                    return {
                        "success": False,
                        "error": "缺少项目ID,请指定项目ID",
                        "requires_manual_project_id": True,
                    }

                saved_filename = await save_credentials(credentials, project_id)

                creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli")

                _cleanup_auth_flow_server(state)

                log.info("OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": flow_data.get("auto_project_detection", False),
                }

            except Exception as e:
                log.error(f"获取凭证失败: {e}")
                return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
|
|
async def asyncio_complete_auth_flow(
    project_id: Optional[str] = None, user_session: str = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Finish an OAuth flow without blocking the event loop.

    The flow is chosen by ``project_id`` (same-session flows win) or, failing
    that, among auto-detection flows: the newest one of ``user_session`` that
    already has a code, else the newest pending one. The function then waits
    up to 60 seconds for the authorization code, exchanges it, resolves the
    project id and stores the credentials.

    Args:
        project_id: Explicit project id, or ``None``.
        user_session: Session identifier used to select the flow.
        mode: Used only when the flow entry has no truthy ``mode`` of its own.

    Returns:
        A result dict with ``success`` and either the stored credentials or
        an ``error`` (optionally with ``requires_manual_project_id`` or
        ``requires_project_selection`` / ``available_projects``).
    """
    try:
        log.info(
            f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}"
        )

        state = None
        flow_data = None

        log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}")

        # Step 1: look for a flow registered for the requested project.
        if project_id:
            log.info(f"尝试匹配指定的项目ID: {project_id}")
            for candidate_state, candidate in auth_flows.items():
                if candidate["project_id"] != project_id:
                    continue
                if user_session and candidate.get("user_session") == user_session:
                    state, flow_data = candidate_state, candidate
                    log.info(f"找到匹配的用户会话: {candidate_state}")
                    break
                if not state:
                    state, flow_data = candidate_state, candidate
                    log.info(f"找到匹配的项目ID: {candidate_state}")

        # Step 2: fall back to auto-detection flows, newest first.
        if not state:
            log.info("没有找到指定项目的流程,查找自动检测流程")

            auto_flows = [
                (s, d) for s, d in auth_flows.items() if d.get("auto_project_detection", False)
            ]

            def created_at_of(item):
                return item[1].get("created_at", 0)

            completed = [
                (s, d)
                for s, d in auto_flows
                if user_session and d.get("user_session") == user_session and d.get("code")
            ]

            if completed:
                state, flow_data = max(completed, key=created_at_of)
                log.info(f"找到已完成的最新认证流程: {state}")
            else:
                pending = [
                    (s, d)
                    for s, d in auto_flows
                    if not user_session or d.get("user_session") == user_session
                ]
                if pending:
                    state, flow_data = max(pending, key=created_at_of)
                    log.info(f"找到最新的待完成认证流程: {state}")

        if not state or not flow_data:
            log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}")
            log.debug(f"当前所有flow_data: {list(auth_flows.keys())}")
            return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"}

        log.info(f"找到认证流程: state={state}")
        log.info(
            f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}"
        )
        log.info(f"传入的project_id参数: {project_id}")

        log.info(
            f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}"
        )
        if project_id:
            log.info(f"使用提供的项目ID: {project_id}")
        elif flow_data.get("auto_project_detection", False):
            # The project is resolved after the token exchange.
            log.info("跳过自动检测项目ID,进入等待阶段")
        else:
            log.info("进入project_id检查分支")
            project_id = flow_data.get("project_id")
            if not project_id:
                log.error("缺少项目ID,返回错误")
                return {
                    "success": False,
                    "error": "缺少项目ID,请指定项目ID",
                    "requires_manual_project_id": True,
                }

        # Step 3: wait for the callback handler to deliver the code.
        log.info("开始检查OAuth授权码...")
        log.info(f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}")
        log.info(f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}")
        max_wait_time = 60
        wait_interval = 1
        waited = 0

        while waited < max_wait_time:
            if flow_data.get("code"):
                log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)")
                break

            # Progress message every five seconds.
            if waited > 0 and waited % 5 == 0:
                log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)")
                log.debug(f"当前state: {state}, flow_data keys: {list(flow_data.keys())}")

            await asyncio.sleep(wait_interval)
            waited += wait_interval

            # Re-read the entry in case it was replaced in the registry.
            if state in auth_flows:
                flow_data = auth_flows[state]

        if not flow_data.get("code"):
            log.error(f"等待OAuth回调超时,等待了{waited}秒")
            return {
                "success": False,
                "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面",
            }

        flow = flow_data["flow"]
        auth_code = flow_data["code"]

        log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}")

        # Step 4: exchange the code and persist the credentials.
        with _OAuthLibPatcher():
            try:
                log.info("调用flow.exchange_code...")
                credentials = await flow.exchange_code(auth_code)
                log.info(
                    f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}..."
                )

                log.info(
                    f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}"
                )

                # The mode stored on the flow wins over the argument.
                cred_mode = flow_data.get("mode") or mode
                if cred_mode == "antigravity":
                    log.info("Antigravity模式:从API获取project_id...")

                    antigravity_url = await get_antigravity_api_url()
                    project_id = await fetch_project_id(
                        credentials.access_token,
                        ANTIGRAVITY_USER_AGENT,
                        antigravity_url
                    )
                    if project_id:
                        log.info(f"成功从API获取project_id: {project_id}")
                    else:
                        log.warning("无法从API获取project_id,回退到随机生成")
                        project_id = _generate_random_project_id()
                        log.info(f"生成的随机project_id: {project_id}")

                    saved_filename = await save_credentials(credentials, project_id, mode="antigravity")

                    creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity")

                    _cleanup_auth_flow_server(state)

                    log.info("Antigravity OAuth认证成功,凭证已保存")
                    return {
                        "success": True,
                        "credentials": creds_data,
                        "file_path": saved_filename,
                        "auto_detected_project": False,
                        "mode": "antigravity",
                    }

                if flow_data.get("auto_project_detection", False) and not project_id:
                    log.info("标准模式:从API获取project_id...")

                    code_assist_url = await get_code_assist_endpoint()
                    project_id = await fetch_project_id(
                        credentials.access_token,
                        GEMINICLI_USER_AGENT,
                        code_assist_url
                    )
                    if project_id:
                        flow_data["project_id"] = project_id
                        log.info(f"成功从API获取project_id: {project_id}")

                        log.info("正在自动启用必需的API服务...")
                        await enable_required_apis(credentials, project_id)
                    else:
                        log.warning("无法从API获取project_id,回退到项目列表获取方式")

                        user_projects = await get_user_projects(credentials)

                        if not user_projects:
                            return {
                                "success": False,
                                "error": "无法获取您的项目列表,请手动指定项目ID",
                                "requires_manual_project_id": True,
                            }

                        if len(user_projects) == 1:
                            project_id = user_projects[0].get("projectId")
                            if project_id:
                                flow_data["project_id"] = project_id
                                log.info(f"自动选择唯一项目: {project_id}")

                                log.info("正在自动启用必需的API服务...")
                                await enable_required_apis(credentials, project_id)
                        else:
                            project_id = await select_default_project(user_projects)
                            if not project_id:
                                # No default available: hand the list back.
                                return {
                                    "success": False,
                                    "error": "请从以下项目中选择一个",
                                    "requires_project_selection": True,
                                    "available_projects": [
                                        {
                                            "project_id": p.get("projectId"),
                                            "name": p.get("displayName") or p.get("projectId"),
                                            "projectNumber": p.get("projectNumber"),
                                        }
                                        for p in user_projects
                                    ],
                                }

                            flow_data["project_id"] = project_id
                            log.info(f"自动选择默认项目: {project_id}")

                            log.info("正在自动启用必需的API服务...")
                            await enable_required_apis(credentials, project_id)
                elif project_id:
                    log.info("正在为已提供的项目ID自动启用必需的API服务...")
                    await enable_required_apis(credentials, project_id)

                if not project_id:
                    return {
                        "success": False,
                        "error": "缺少项目ID,请指定项目ID",
                        "requires_manual_project_id": True,
                    }

                saved_filename = await save_credentials(credentials, project_id)

                creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli")

                _cleanup_auth_flow_server(state)

                log.info("OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": flow_data.get("auto_project_detection", False),
                }

            except Exception as e:
                log.error(f"获取凭证失败: {e}")
                return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"异步完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
|
|
async def complete_auth_flow_from_callback_url(
    callback_url: str, project_id: Optional[str] = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Finish an OAuth flow from a pasted callback URL.

    ``state`` and ``code`` are read from the URL's query string, the flow is
    looked up in ``auth_flows`` by state, the code is exchanged, and the
    credentials are stored. No local callback request is needed.

    Args:
        callback_url: Full redirect URL containing ``state`` and ``code``.
        project_id: Explicit project id; when omitted (standard mode) the id
            is fetched from the API, falling back to the first project of the
            user's project list.
        mode: Used only when the flow entry has no truthy ``mode`` of its own.

    Returns:
        A result dict with ``success`` and either the stored credentials or
        an ``error`` (optionally with ``requires_manual_project_id``).
    """
    try:
        parsed_url = urlparse(callback_url)
        # The query string carries the authorization code, so only the
        # location part of the URL is logged.
        log.info(
            f"开始从回调URL完成认证: {parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
        )

        query_params = parse_qs(parsed_url.query)

        if "state" not in query_params or "code" not in query_params:
            return {"success": False, "error": "回调URL缺少必要参数 (state 或 code)"}

        state = query_params["state"][0]
        code = query_params["code"][0]

        log.info(f"从URL解析到: state={state}, code=xxx...")

        if state not in auth_flows:
            return {
                "success": False,
                "error": f"未找到对应的认证流程,请先启动认证 (state: {state})",
            }

        flow_data = auth_flows[state]
        flow = flow_data["flow"]

        redirect_uri = flow.redirect_uri
        log.info(f"使用redirect_uri: {redirect_uri}")

        try:
            credentials = await flow.exchange_code(code)
            log.info("成功获取访问令牌")

            # The mode stored on the flow wins over the argument.
            cred_mode = flow_data.get("mode") or mode
            if cred_mode == "antigravity":
                log.info("Antigravity模式(从回调URL):从API获取project_id...")

                antigravity_url = await get_antigravity_api_url()
                project_id = await fetch_project_id(
                    credentials.access_token,
                    ANTIGRAVITY_USER_AGENT,
                    antigravity_url
                )
                if project_id:
                    log.info(f"成功从API获取project_id: {project_id}")
                else:
                    log.warning("无法从API获取project_id,回退到随机生成")
                    project_id = _generate_random_project_id()
                    log.info(f"生成的随机project_id: {project_id}")

                saved_filename = await save_credentials(credentials, project_id, mode="antigravity")

                creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity")

                _cleanup_auth_flow_server(state)

                log.info("从回调URL完成Antigravity OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": False,
                    "mode": "antigravity",
                }

            detected_project_id = None
            auto_detected = False

            if not project_id:
                try:
                    log.info("标准模式:从API获取project_id...")
                    code_assist_url = await get_code_assist_endpoint()
                    detected_project_id = await fetch_project_id(
                        credentials.access_token,
                        GEMINICLI_USER_AGENT,
                        code_assist_url
                    )
                    if detected_project_id:
                        auto_detected = True
                        log.info(f"成功从API获取project_id: {detected_project_id}")
                    else:
                        log.warning("无法从API获取project_id,回退到项目列表获取方式")

                        projects = await get_user_projects(credentials)
                        if projects:
                            if len(projects) == 1:
                                detected_project_id = projects[0]["projectId"]
                                auto_detected = True
                                log.info(f"自动检测到唯一项目ID: {detected_project_id}")
                            else:
                                # Several projects: the first one is used.
                                detected_project_id = projects[0]["projectId"]
                                auto_detected = True
                                log.info(
                                    f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}"
                                )
                                log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
                        else:
                            return {
                                "success": False,
                                "error": "未检测到可访问的项目,请检查权限或手动指定项目ID",
                                "requires_manual_project_id": True,
                            }
                except Exception as e:
                    log.warning(f"自动检测项目ID失败: {e}")
                    return {
                        "success": False,
                        "error": f"自动检测项目ID失败: {str(e)},请手动指定项目ID",
                        "requires_manual_project_id": True,
                    }
            else:
                detected_project_id = project_id

            # Enabling the APIs is best effort; a failure is only logged.
            if detected_project_id:
                try:
                    log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...")
                    await enable_required_apis(credentials, detected_project_id)
                except Exception as e:
                    log.warning(f"启用API服务失败: {e}")

            saved_filename = await save_credentials(credentials, detected_project_id)

            creds_data = _prepare_credentials_data(credentials, detected_project_id, mode="geminicli")

            _cleanup_auth_flow_server(state)

            log.info("从回调URL完成OAuth认证成功,凭证已保存")
            return {
                "success": True,
                "credentials": creds_data,
                "file_path": saved_filename,
                "auto_detected_project": auto_detected,
            }

        except Exception as e:
            log.error(f"从回调URL获取凭证失败: {e}")
            return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"从回调URL完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
|
|
async def save_credentials(creds: Credentials, project_id: str, mode: str = "geminicli") -> str:
    """Persist credentials through the storage adapter and return the filename.

    The filename is ``<project_id>-<unix timestamp>.json``, prefixed with
    ``ag_`` in antigravity mode. After a successful store, a default state
    record is written as well; a failure of that second step is only logged.

    Raises:
        Exception: When the storage adapter reports that storing failed.
    """
    timestamp = int(time.time())
    prefix = "ag_" if mode == "antigravity" else ""
    filename = f"{prefix}{project_id}-{timestamp}.json"

    creds_data = _prepare_credentials_data(creds, project_id, mode)

    storage_adapter = await get_storage_adapter()
    stored = await storage_adapter.store_credential(filename, creds_data, mode=mode)

    if not stored:
        raise Exception(f"保存凭证失败: {filename}")

    try:
        initial_state = {
            "error_codes": [],
            "disabled": False,
            "last_success": time.time(),
            "user_email": None,
        }
        await storage_adapter.update_credential_state(filename, initial_state, mode=mode)
        log.info(f"凭证和状态已保存到: {filename} (mode={mode})")
    except Exception as e:
        log.warning(f"创建默认状态记录失败 {filename}: {e}")

    return filename
|
|
|
|
|
|
|
|
def async_shutdown_server(server, port):
    """Shut down an OAuth callback server in the background.

    A daemon supervisor thread starts a second daemon thread that calls
    ``server.shutdown()`` and ``server.server_close()``, then waits up to five
    seconds for it and logs whether the shutdown finished in time. The caller
    returns immediately after the supervisor thread has been started.

    Args:
        server: Object providing ``shutdown()`` and ``server_close()``.
        port: Port number, used for log messages only.
    """

    def _stop_server(done_event):
        try:
            server.shutdown()
            server.server_close()
            done_event.set()
            log.info(f"已关闭端口 {port} 的OAuth回调服务器")
        except Exception as e:
            # Signal completion even on failure so the supervisor stops waiting.
            done_event.set()
            log.debug(f"关闭服务器时出错: {e}")

    def _supervise_shutdown():
        try:
            done_event = threading.Event()

            worker = threading.Thread(target=_stop_server, args=(done_event,), daemon=True)
            worker.start()

            finished_in_time = done_event.wait(timeout=5)
            if finished_in_time:
                log.debug(f"端口 {port} 服务器关闭完成")
            else:
                log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程")
        except Exception as e:
            log.debug(f"异步关闭服务器时出错: {e}")

    supervisor = threading.Thread(target=_supervise_shutdown, daemon=True)
    supervisor.start()
    log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器")
|
|
|
|
|
|
|
|
def cleanup_expired_flows():
    """Remove flows older than 600 seconds from ``auth_flows``.

    For every expired flow the callback server (if any) is shut down
    asynchronously, the entry's dict is cleared and the entry is deleted.
    When more than 20 flows remain afterwards, a garbage collection run is
    triggered.
    """
    expiry_seconds = 600
    now = time.time()

    # Collect first: the registry must not change size while being iterated.
    expired_states = [
        flow_state
        for flow_state, entry in auth_flows.items()
        if now - entry["created_at"] > expiry_seconds
    ]

    removed = 0
    for flow_state in expired_states:
        entry = auth_flows.get(flow_state)
        if not entry:
            continue

        try:
            server = entry.get("server")
            if server:
                async_shutdown_server(server, entry.get("callback_port"))
        except Exception as e:
            log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}")

        entry.clear()
        del auth_flows[flow_state]
        removed += 1

    if removed > 0:
        log.info(f"清理了 {removed} 个过期的认证流程")

    if len(auth_flows) > 20:
        import gc

        gc.collect()
        log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}")
|
|
|
|
|
|
|
|
def get_auth_status(project_id: str) -> Dict[str, Any]:
    """Report the status of the first flow registered for ``project_id``.

    Returns:
        ``{"status": "completed" | "pending", "state", "created_at"}`` for
        the first matching flow in ``auth_flows``, or
        ``{"status": "not_found"}`` when no flow matches.
    """
    first_match = next(
        (
            (flow_state, entry)
            for flow_state, entry in auth_flows.items()
            if entry["project_id"] == project_id
        ),
        None,
    )
    if first_match is None:
        return {"status": "not_found"}

    flow_state, entry = first_match
    status = "completed" if entry["completed"] else "pending"
    return {
        "status": status,
        "state": flow_state,
        "created_at": entry["created_at"],
    }
|
|
|
|
|
|
|
|
|
|
|
# Issued panel tokens mapped to the time.time() value at which they were created.
auth_tokens: Dict[str, float] = {}

# Lifetime of a panel token in seconds.
TOKEN_EXPIRY: int = 3600
|
|
|
|
|
|
|
|
async def verify_password(password: str) -> bool:
    """Check a panel-login password against the configured panel password.

    The comparison is done with ``secrets.compare_digest`` on the UTF-8 bytes
    of both values so that its duration does not reveal how many leading
    characters matched. Non-string values on either side never match.

    Args:
        password: Password supplied by the client.

    Returns:
        ``True`` when the password equals the configured one, else ``False``.
    """
    from config import get_panel_password

    correct_password = await get_panel_password()

    if not isinstance(password, str) or not isinstance(correct_password, str):
        return False

    # compare_digest only accepts ASCII str or bytes, hence the encoding.
    return secrets.compare_digest(
        password.encode("utf-8"), correct_password.encode("utf-8")
    )
|
|
|
|
|
|
|
|
def generate_auth_token() -> str:
    """Create, register and return a new panel token.

    Expired tokens are purged first. The new token comes from
    ``secrets.token_urlsafe(32)`` and is stored in ``auth_tokens`` together
    with its creation time.
    """
    cleanup_expired_tokens()

    new_token = secrets.token_urlsafe(32)
    issued_at = time.time()
    auth_tokens[new_token] = issued_at
    return new_token
|
|
|
|
|
|
|
|
def verify_auth_token(token: str) -> bool:
    """Return whether ``token`` is registered and not older than TOKEN_EXPIRY.

    A token found to be expired is removed from ``auth_tokens`` as a side
    effect. Empty or unknown tokens yield ``False``.
    """
    if not token:
        return False

    try:
        issued_at = auth_tokens[token]
    except KeyError:
        return False

    if time.time() - issued_at <= TOKEN_EXPIRY:
        return True

    # Expired: forget the token so it is not checked again.
    del auth_tokens[token]
    return False
|
|
|
|
|
|
|
|
def cleanup_expired_tokens():
    """Delete every token in ``auth_tokens`` older than TOKEN_EXPIRY seconds.

    A debug message with the number of removed tokens is logged when at least
    one token was removed.
    """
    now = time.time()
    # Collect first so the dict is not modified while being iterated.
    stale_tokens = [
        candidate
        for candidate, issued_at in auth_tokens.items()
        if now - issued_at > TOKEN_EXPIRY
    ]

    for candidate in stale_tokens:
        auth_tokens.pop(candidate)

    if stale_tokens:
        log.debug(f"清理了 {len(stale_tokens)} 个过期的认证令牌")
|
|
|
|
|
|
|
|
def invalidate_auth_token(token: str):
    """Remove ``token`` from ``auth_tokens``; unknown tokens are ignored."""
    auth_tokens.pop(token, None)
|
|
|
|
|
|
|
|
|
|
|
def validate_credential_content(content: str) -> Dict[str, Any]:
    """Validate the JSON text of an uploaded credential file.

    The content must be a JSON object containing ``client_id``,
    ``client_secret``, ``refresh_token`` and ``token_uri``. A missing
    ``project_id`` only produces a warning in the log.

    Args:
        content: Raw file content.

    Returns:
        ``{"valid": True, "data": <dict>}`` on success, otherwise
        ``{"valid": False, "error": <message>}``.
    """
    try:
        creds_data = json.loads(content)

        # Lists and strings also support "in", so without this check a JSON
        # array or string that merely mentions the field names would pass.
        if not isinstance(creds_data, dict):
            return {"valid": False, "error": "凭证内容必须是JSON对象"}

        required_fields = ["client_id", "client_secret", "refresh_token", "token_uri"]
        missing_fields = [field for field in required_fields if field not in creds_data]

        if missing_fields:
            return {"valid": False, "error": f'缺少必要字段: {", ".join(missing_fields)}'}

        if "project_id" not in creds_data:
            log.warning("认证文件缺少project_id字段")

        return {"valid": True, "data": creds_data}

    except json.JSONDecodeError as e:
        return {"valid": False, "error": f"JSON格式错误: {str(e)}"}
    except Exception as e:
        return {"valid": False, "error": f"文件验证失败: {str(e)}"}
|
|
|
|
|
|
|
|
async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]:
    """Validate an uploaded credential file and store it via the storage adapter.

    The stored name is ``<base name of the upload>-<unix timestamp>.json``.
    Directory components of the uploaded filename (``/`` or ``\\`` separated)
    are dropped so a client-chosen name cannot point outside the storage
    location.

    Args:
        content: Raw JSON text of the credential file.
        original_filename: Filename as supplied by the uploader.

    Returns:
        ``{"success": True, "file_path", "project_id"}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    try:
        validation = validate_credential_content(content)
        if not validation["valid"]:
            return {"success": False, "error": validation["error"]}

        creds_data = validation["data"]

        project_id = creds_data.get("project_id", "unknown")
        timestamp = int(time.time())

        import os

        # The filename is untrusted input: keep only its final path component.
        leaf_name = os.path.basename(original_filename.replace("\\", "/"))
        base_name = os.path.splitext(leaf_name)[0] or "unknown"
        filename = f"{base_name}-{timestamp}.json"

        storage_adapter = await get_storage_adapter()
        success = await storage_adapter.store_credential(filename, creds_data)

        if success:
            log.info(f"凭证文件已上传保存: {filename}")
            return {"success": True, "file_path": filename, "project_id": project_id}
        else:
            return {"success": False, "error": "保存到存储系统失败"}

    except Exception as e:
        log.error(f"保存上传文件失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
|
|
async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str, Any]:
    """Store several uploaded credential files one after another.

    Args:
        files_data: Items with optional ``filename`` (default
            ``"unknown.json"``) and ``content`` (default ``""``) entries.

    Returns:
        A dict with ``uploaded_count`` (number of successful saves),
        ``total_count`` and ``results`` (one result dict per file, each
        extended with its ``filename``).
    """
    results = []

    for entry in files_data:
        upload_name = entry.get("filename", "unknown.json")
        outcome = await save_uploaded_credential(entry.get("content", ""), upload_name)
        outcome["filename"] = upload_name
        results.append(outcome)

    uploaded = sum(1 for outcome in results if outcome["success"])

    return {"uploaded_count": uploaded, "total_count": len(files_data), "results": results}
|
|
|