|
|
""" |
|
|
Google OAuth2 认证模块 |
|
|
""" |
|
|
|
|
|
import time |
|
|
import asyncio |
|
|
from datetime import datetime, timedelta, timezone |
|
|
from typing import Any, Dict, List, Optional |
|
|
from urllib.parse import urlencode |
|
|
|
|
|
import jwt |
|
|
|
|
|
from config import ( |
|
|
get_googleapis_proxy_url, |
|
|
get_oauth_proxy_url, |
|
|
get_resource_manager_api_url, |
|
|
get_service_usage_api_url, |
|
|
) |
|
|
from log import log |
|
|
|
|
|
from .httpx_client import get_async, post_async |
|
|
|
|
|
|
|
|
class TokenError(Exception): |
|
|
"""Token相关错误""" |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
class Credentials: |
|
|
"""凭证类""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
access_token: str, |
|
|
refresh_token: str = None, |
|
|
client_id: str = None, |
|
|
client_secret: str = None, |
|
|
expires_at: datetime = None, |
|
|
project_id: str = None, |
|
|
): |
|
|
self.access_token = access_token |
|
|
self.refresh_token = refresh_token |
|
|
self.client_id = client_id |
|
|
self.client_secret = client_secret |
|
|
self.expires_at = expires_at |
|
|
self.project_id = project_id |
|
|
|
|
|
|
|
|
self.oauth_base_url = None |
|
|
self.token_endpoint = None |
|
|
|
|
|
def is_expired(self) -> bool: |
|
|
"""检查token是否过期""" |
|
|
if not self.expires_at: |
|
|
return True |
|
|
|
|
|
|
|
|
buffer = timedelta(minutes=3) |
|
|
return (self.expires_at - buffer) <= datetime.now(timezone.utc) |
|
|
|
|
|
async def refresh_if_needed(self) -> bool: |
|
|
"""如果需要则刷新token""" |
|
|
if not self.is_expired(): |
|
|
return False |
|
|
|
|
|
if not self.refresh_token: |
|
|
raise TokenError("需要刷新令牌但未提供") |
|
|
|
|
|
await self.refresh() |
|
|
return True |
|
|
|
|
|
async def refresh(self): |
|
|
"""刷新访问令牌""" |
|
|
if not self.refresh_token: |
|
|
raise TokenError("无刷新令牌") |
|
|
|
|
|
data = { |
|
|
"client_id": self.client_id, |
|
|
"client_secret": self.client_secret, |
|
|
"refresh_token": self.refresh_token, |
|
|
"grant_type": "refresh_token", |
|
|
} |
|
|
|
|
|
try: |
|
|
oauth_base_url = await get_oauth_proxy_url() |
|
|
token_url = f"{oauth_base_url.rstrip('/')}/token" |
|
|
response = await post_async( |
|
|
token_url, |
|
|
data=data, |
|
|
headers={"Content-Type": "application/x-www-form-urlencoded"}, |
|
|
) |
|
|
response.raise_for_status() |
|
|
|
|
|
token_data = response.json() |
|
|
self.access_token = token_data["access_token"] |
|
|
|
|
|
if "expires_in" in token_data: |
|
|
expires_in = int(token_data["expires_in"]) |
|
|
current_utc = datetime.now(timezone.utc) |
|
|
self.expires_at = current_utc + timedelta(seconds=expires_in) |
|
|
log.debug( |
|
|
f"Token刷新: 当前UTC时间={current_utc.isoformat()}, " |
|
|
f"有效期={expires_in}秒, " |
|
|
f"过期时间={self.expires_at.isoformat()}" |
|
|
) |
|
|
|
|
|
if "refresh_token" in token_data: |
|
|
self.refresh_token = token_data["refresh_token"] |
|
|
|
|
|
log.debug(f"Token刷新成功,过期时间: {self.expires_at}") |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = str(e) |
|
|
status_code = None |
|
|
if hasattr(e, 'response') and hasattr(e.response, 'status_code'): |
|
|
status_code = e.response.status_code |
|
|
error_msg = f"Token刷新失败 (HTTP {status_code}): {error_msg}" |
|
|
else: |
|
|
error_msg = f"Token刷新失败: {error_msg}" |
|
|
|
|
|
log.error(error_msg) |
|
|
token_error = TokenError(error_msg) |
|
|
token_error.status_code = status_code |
|
|
raise token_error |
|
|
|
|
|
@classmethod |
|
|
def from_dict(cls, data: Dict[str, Any]) -> "Credentials": |
|
|
"""从字典创建凭证""" |
|
|
|
|
|
expires_at = None |
|
|
if "expiry" in data and data["expiry"]: |
|
|
try: |
|
|
expiry_str = data["expiry"] |
|
|
if isinstance(expiry_str, str): |
|
|
if expiry_str.endswith("Z"): |
|
|
expires_at = datetime.fromisoformat(expiry_str.replace("Z", "+00:00")) |
|
|
elif "+" in expiry_str: |
|
|
expires_at = datetime.fromisoformat(expiry_str) |
|
|
else: |
|
|
expires_at = datetime.fromisoformat(expiry_str).replace(tzinfo=timezone.utc) |
|
|
except ValueError: |
|
|
log.warning(f"无法解析过期时间: {expiry_str}") |
|
|
|
|
|
return cls( |
|
|
access_token=data.get("token") or data.get("access_token", ""), |
|
|
refresh_token=data.get("refresh_token"), |
|
|
client_id=data.get("client_id"), |
|
|
client_secret=data.get("client_secret"), |
|
|
expires_at=expires_at, |
|
|
project_id=data.get("project_id"), |
|
|
) |
|
|
|
|
|
def to_dict(self) -> Dict[str, Any]: |
|
|
"""转为字典""" |
|
|
result = { |
|
|
"access_token": self.access_token, |
|
|
"refresh_token": self.refresh_token, |
|
|
"client_id": self.client_id, |
|
|
"client_secret": self.client_secret, |
|
|
"project_id": self.project_id, |
|
|
} |
|
|
|
|
|
if self.expires_at: |
|
|
result["expiry"] = self.expires_at.isoformat() |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
class Flow: |
|
|
"""OAuth流程类""" |
|
|
|
|
|
def __init__( |
|
|
self, client_id: str, client_secret: str, scopes: List[str], redirect_uri: str = None |
|
|
): |
|
|
self.client_id = client_id |
|
|
self.client_secret = client_secret |
|
|
self.scopes = scopes |
|
|
self.redirect_uri = redirect_uri |
|
|
|
|
|
|
|
|
self.oauth_base_url = None |
|
|
self.token_endpoint = None |
|
|
self.auth_endpoint = "https://accounts.google.com/o/oauth2/auth" |
|
|
|
|
|
self.credentials: Optional[Credentials] = None |
|
|
|
|
|
def get_auth_url(self, state: str = None, **kwargs) -> str: |
|
|
"""生成授权URL""" |
|
|
params = { |
|
|
"client_id": self.client_id, |
|
|
"redirect_uri": self.redirect_uri, |
|
|
"scope": " ".join(self.scopes), |
|
|
"response_type": "code", |
|
|
"access_type": "offline", |
|
|
"prompt": "consent", |
|
|
"include_granted_scopes": "true", |
|
|
} |
|
|
|
|
|
if state: |
|
|
params["state"] = state |
|
|
|
|
|
params.update(kwargs) |
|
|
return f"{self.auth_endpoint}?{urlencode(params)}" |
|
|
|
|
|
async def exchange_code(self, code: str) -> Credentials: |
|
|
"""用授权码换取token""" |
|
|
data = { |
|
|
"client_id": self.client_id, |
|
|
"client_secret": self.client_secret, |
|
|
"redirect_uri": self.redirect_uri, |
|
|
"code": code, |
|
|
"grant_type": "authorization_code", |
|
|
} |
|
|
|
|
|
try: |
|
|
oauth_base_url = await get_oauth_proxy_url() |
|
|
token_url = f"{oauth_base_url.rstrip('/')}/token" |
|
|
response = await post_async( |
|
|
token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} |
|
|
) |
|
|
response.raise_for_status() |
|
|
|
|
|
token_data = response.json() |
|
|
|
|
|
|
|
|
expires_at = None |
|
|
if "expires_in" in token_data: |
|
|
expires_in = int(token_data["expires_in"]) |
|
|
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) |
|
|
|
|
|
|
|
|
self.credentials = Credentials( |
|
|
access_token=token_data["access_token"], |
|
|
refresh_token=token_data.get("refresh_token"), |
|
|
client_id=self.client_id, |
|
|
client_secret=self.client_secret, |
|
|
expires_at=expires_at, |
|
|
) |
|
|
|
|
|
return self.credentials |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"获取token失败: {str(e)}" |
|
|
log.error(error_msg) |
|
|
raise TokenError(error_msg) |
|
|
|
|
|
|
|
|
class ServiceAccount: |
|
|
"""Service Account类""" |
|
|
|
|
|
def __init__( |
|
|
self, email: str, private_key: str, project_id: str = None, scopes: List[str] = None |
|
|
): |
|
|
self.email = email |
|
|
self.private_key = private_key |
|
|
self.project_id = project_id |
|
|
self.scopes = scopes or [] |
|
|
|
|
|
|
|
|
self.oauth_base_url = None |
|
|
self.token_endpoint = None |
|
|
|
|
|
self.access_token: Optional[str] = None |
|
|
self.expires_at: Optional[datetime] = None |
|
|
|
|
|
def is_expired(self) -> bool: |
|
|
"""检查token是否过期""" |
|
|
if not self.expires_at: |
|
|
return True |
|
|
|
|
|
buffer = timedelta(minutes=3) |
|
|
return (self.expires_at - buffer) <= datetime.now(timezone.utc) |
|
|
|
|
|
def create_jwt(self) -> str: |
|
|
"""创建JWT令牌""" |
|
|
now = int(time.time()) |
|
|
|
|
|
payload = { |
|
|
"iss": self.email, |
|
|
"scope": " ".join(self.scopes) if self.scopes else "", |
|
|
"aud": self.token_endpoint, |
|
|
"exp": now + 3600, |
|
|
"iat": now, |
|
|
} |
|
|
|
|
|
return jwt.encode(payload, self.private_key, algorithm="RS256") |
|
|
|
|
|
async def get_access_token(self) -> str: |
|
|
"""获取访问令牌""" |
|
|
if not self.is_expired() and self.access_token: |
|
|
return self.access_token |
|
|
|
|
|
assertion = self.create_jwt() |
|
|
|
|
|
data = {"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": assertion} |
|
|
|
|
|
try: |
|
|
oauth_base_url = await get_oauth_proxy_url() |
|
|
token_url = f"{oauth_base_url.rstrip('/')}/token" |
|
|
response = await post_async( |
|
|
token_url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"} |
|
|
) |
|
|
response.raise_for_status() |
|
|
|
|
|
token_data = response.json() |
|
|
self.access_token = token_data["access_token"] |
|
|
|
|
|
if "expires_in" in token_data: |
|
|
expires_in = int(token_data["expires_in"]) |
|
|
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) |
|
|
|
|
|
return self.access_token |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Service Account获取token失败: {str(e)}" |
|
|
log.error(error_msg) |
|
|
raise TokenError(error_msg) |
|
|
|
|
|
@classmethod |
|
|
def from_dict(cls, data: Dict[str, Any], scopes: List[str] = None) -> "ServiceAccount": |
|
|
"""从字典创建Service Account凭证""" |
|
|
return cls( |
|
|
email=data["client_email"], |
|
|
private_key=data["private_key"], |
|
|
project_id=data.get("project_id"), |
|
|
scopes=scopes, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
async def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]: |
|
|
"""获取用户信息""" |
|
|
await credentials.refresh_if_needed() |
|
|
|
|
|
try: |
|
|
googleapis_base_url = await get_googleapis_proxy_url() |
|
|
userinfo_url = f"{googleapis_base_url.rstrip('/')}/oauth2/v2/userinfo" |
|
|
response = await get_async( |
|
|
userinfo_url, headers={"Authorization": f"Bearer {credentials.access_token}"} |
|
|
) |
|
|
response.raise_for_status() |
|
|
return response.json() |
|
|
except Exception as e: |
|
|
log.error(f"获取用户信息失败: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
async def get_user_email(credentials: Credentials) -> Optional[str]: |
|
|
"""获取用户邮箱地址""" |
|
|
try: |
|
|
|
|
|
await credentials.refresh_if_needed() |
|
|
|
|
|
|
|
|
user_info = await get_user_info(credentials) |
|
|
if user_info: |
|
|
email = user_info.get("email") |
|
|
if email: |
|
|
log.info(f"成功获取邮箱地址: {email}") |
|
|
return email |
|
|
else: |
|
|
log.warning(f"userinfo响应中没有邮箱信息: {user_info}") |
|
|
return None |
|
|
else: |
|
|
log.warning("获取用户信息失败") |
|
|
return None |
|
|
|
|
|
except Exception as e: |
|
|
log.error(f"获取用户邮箱失败: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
async def fetch_user_email_from_file(cred_data: Dict[str, Any]) -> Optional[str]: |
|
|
"""从凭证数据获取用户邮箱地址(支持统一存储)""" |
|
|
try: |
|
|
|
|
|
credentials = Credentials.from_dict(cred_data) |
|
|
if not credentials or not credentials.access_token: |
|
|
log.warning("无法从凭证数据创建凭证对象或获取访问令牌") |
|
|
return None |
|
|
|
|
|
|
|
|
return await get_user_email(credentials) |
|
|
|
|
|
except Exception as e: |
|
|
log.error(f"从凭证数据获取用户邮箱失败: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
async def validate_token(token: str) -> Optional[Dict[str, Any]]: |
|
|
"""验证访问令牌""" |
|
|
try: |
|
|
oauth_base_url = await get_oauth_proxy_url() |
|
|
tokeninfo_url = f"{oauth_base_url.rstrip('/')}/tokeninfo?access_token={token}" |
|
|
|
|
|
response = await get_async(tokeninfo_url) |
|
|
response.raise_for_status() |
|
|
return response.json() |
|
|
except Exception as e: |
|
|
log.error(f"验证令牌失败: {e}") |
|
|
return None |
|
|
|
|
|
|
|
|
async def enable_required_apis(credentials: Credentials, project_id: str) -> bool: |
|
|
"""自动启用必需的API服务""" |
|
|
try: |
|
|
|
|
|
if credentials.is_expired() and credentials.refresh_token: |
|
|
await credentials.refresh() |
|
|
|
|
|
headers = { |
|
|
"Authorization": f"Bearer {credentials.access_token}", |
|
|
"Content-Type": "application/json", |
|
|
"User-Agent": "geminicli-oauth/1.0", |
|
|
} |
|
|
|
|
|
|
|
|
required_services = [ |
|
|
"geminicloudassist.googleapis.com", |
|
|
"cloudaicompanion.googleapis.com", |
|
|
] |
|
|
|
|
|
for service in required_services: |
|
|
log.info(f"正在检查并启用服务: {service}") |
|
|
|
|
|
|
|
|
service_usage_base_url = await get_service_usage_api_url() |
|
|
check_url = ( |
|
|
f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}" |
|
|
) |
|
|
try: |
|
|
check_response = await get_async(check_url, headers=headers) |
|
|
if check_response.status_code == 200: |
|
|
service_data = check_response.json() |
|
|
if service_data.get("state") == "ENABLED": |
|
|
log.info(f"服务 {service} 已启用") |
|
|
continue |
|
|
except Exception as e: |
|
|
log.debug(f"检查服务状态失败,将尝试启用: {e}") |
|
|
|
|
|
|
|
|
enable_url = f"{service_usage_base_url.rstrip('/')}/v1/projects/{project_id}/services/{service}:enable" |
|
|
try: |
|
|
enable_response = await post_async(enable_url, headers=headers, json={}) |
|
|
|
|
|
if enable_response.status_code in [200, 201]: |
|
|
log.info(f"✅ 成功启用服务: {service}") |
|
|
elif enable_response.status_code == 400: |
|
|
error_data = enable_response.json() |
|
|
if "already enabled" in error_data.get("error", {}).get("message", "").lower(): |
|
|
log.info(f"✅ 服务 {service} 已经启用") |
|
|
else: |
|
|
log.warning(f"⚠️ 启用服务 {service} 时出现警告: {error_data}") |
|
|
else: |
|
|
log.warning( |
|
|
f"⚠️ 启用服务 {service} 失败: {enable_response.status_code} - {enable_response.text}" |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
log.warning(f"⚠️ 启用服务 {service} 时发生异常: {e}") |
|
|
|
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
log.error(f"启用API服务时发生错误: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
async def get_user_projects(credentials: Credentials) -> List[Dict[str, Any]]: |
|
|
"""获取用户可访问的Google Cloud项目列表""" |
|
|
try: |
|
|
|
|
|
if credentials.is_expired() and credentials.refresh_token: |
|
|
await credentials.refresh() |
|
|
|
|
|
headers = { |
|
|
"Authorization": f"Bearer {credentials.access_token}", |
|
|
"User-Agent": "geminicli-oauth/1.0", |
|
|
} |
|
|
|
|
|
|
|
|
resource_manager_base_url = await get_resource_manager_api_url() |
|
|
url = f"{resource_manager_base_url.rstrip('/')}/v1/projects" |
|
|
log.info(f"正在调用API: {url}") |
|
|
response = await get_async(url, headers=headers) |
|
|
|
|
|
log.info(f"API响应状态码: {response.status_code}") |
|
|
if response.status_code != 200: |
|
|
log.error(f"API响应内容: {response.text}") |
|
|
|
|
|
if response.status_code == 200: |
|
|
data = response.json() |
|
|
projects = data.get("projects", []) |
|
|
|
|
|
active_projects = [ |
|
|
project for project in projects if project.get("lifecycleState") == "ACTIVE" |
|
|
] |
|
|
log.info(f"获取到 {len(active_projects)} 个活跃项目") |
|
|
return active_projects |
|
|
else: |
|
|
log.warning(f"获取项目列表失败: {response.status_code} - {response.text}") |
|
|
return [] |
|
|
|
|
|
except Exception as e: |
|
|
log.error(f"获取用户项目列表失败: {e}") |
|
|
return [] |
|
|
|
|
|
|
|
|
async def select_default_project(projects: List[Dict[str, Any]]) -> Optional[str]: |
|
|
"""从项目列表中选择默认项目""" |
|
|
if not projects: |
|
|
return None |
|
|
|
|
|
|
|
|
for project in projects: |
|
|
display_name = project.get("displayName", "").lower() |
|
|
|
|
|
project_id = project.get("projectId", "") |
|
|
if "default" in display_name or "default" in project_id.lower(): |
|
|
log.info(f"选择默认项目: {project_id} ({project.get('displayName', project_id)})") |
|
|
return project_id |
|
|
|
|
|
|
|
|
first_project = projects[0] |
|
|
|
|
|
project_id = first_project.get("projectId", "") |
|
|
log.info( |
|
|
f"选择第一个项目作为默认: {project_id} ({first_project.get('displayName', project_id)})" |
|
|
) |
|
|
return project_id |
|
|
|
|
|
|
|
|
async def fetch_project_id( |
|
|
access_token: str, |
|
|
user_agent: str, |
|
|
api_base_url: str |
|
|
) -> Optional[str]: |
|
|
""" |
|
|
从 API 获取 project_id,如果 loadCodeAssist 失败则回退到 onboardUser |
|
|
|
|
|
Args: |
|
|
access_token: Google OAuth access token |
|
|
user_agent: User-Agent header |
|
|
api_base_url: API base URL (e.g., antigravity or code assist endpoint) |
|
|
|
|
|
Returns: |
|
|
project_id 字符串,如果获取失败返回 None |
|
|
""" |
|
|
headers = { |
|
|
'User-Agent': user_agent, |
|
|
'Authorization': f'Bearer {access_token}', |
|
|
'Content-Type': 'application/json', |
|
|
'Accept-Encoding': 'gzip' |
|
|
} |
|
|
|
|
|
|
|
|
try: |
|
|
project_id = await _try_load_code_assist(api_base_url, headers) |
|
|
if project_id: |
|
|
return project_id |
|
|
|
|
|
log.warning("[fetch_project_id] loadCodeAssist did not return project_id, falling back to onboardUser") |
|
|
|
|
|
except Exception as e: |
|
|
log.warning(f"[fetch_project_id] loadCodeAssist failed: {type(e).__name__}: {e}") |
|
|
log.warning("[fetch_project_id] Falling back to onboardUser") |
|
|
|
|
|
|
|
|
try: |
|
|
project_id = await _try_onboard_user(api_base_url, headers) |
|
|
if project_id: |
|
|
return project_id |
|
|
|
|
|
log.error("[fetch_project_id] Failed to get project_id from both loadCodeAssist and onboardUser") |
|
|
return None |
|
|
|
|
|
except Exception as e: |
|
|
log.error(f"[fetch_project_id] onboardUser failed: {type(e).__name__}: {e}") |
|
|
import traceback |
|
|
log.debug(f"[fetch_project_id] Traceback: {traceback.format_exc()}") |
|
|
return None |
|
|
|
|
|
|
|
|
async def _try_load_code_assist( |
|
|
api_base_url: str, |
|
|
headers: dict |
|
|
) -> Optional[str]: |
|
|
""" |
|
|
尝试通过 loadCodeAssist 获取 project_id |
|
|
|
|
|
Returns: |
|
|
project_id 或 None |
|
|
""" |
|
|
request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" |
|
|
request_body = { |
|
|
"metadata": { |
|
|
"ideType": "ANTIGRAVITY", |
|
|
"platform": "PLATFORM_UNSPECIFIED", |
|
|
"pluginType": "GEMINI" |
|
|
} |
|
|
} |
|
|
|
|
|
log.debug(f"[loadCodeAssist] Fetching project_id from: {request_url}") |
|
|
log.debug(f"[loadCodeAssist] Request body: {request_body}") |
|
|
|
|
|
response = await post_async( |
|
|
request_url, |
|
|
json=request_body, |
|
|
headers=headers, |
|
|
timeout=30.0, |
|
|
) |
|
|
|
|
|
log.debug(f"[loadCodeAssist] Response status: {response.status_code}") |
|
|
|
|
|
if response.status_code == 200: |
|
|
response_text = response.text |
|
|
log.debug(f"[loadCodeAssist] Response body: {response_text}") |
|
|
|
|
|
data = response.json() |
|
|
log.debug(f"[loadCodeAssist] Response JSON keys: {list(data.keys())}") |
|
|
|
|
|
|
|
|
current_tier = data.get("currentTier") |
|
|
if current_tier: |
|
|
log.info("[loadCodeAssist] User is already activated") |
|
|
|
|
|
|
|
|
project_id = data.get("cloudaicompanionProject") |
|
|
if project_id: |
|
|
log.info(f"[loadCodeAssist] Successfully fetched project_id: {project_id}") |
|
|
return project_id |
|
|
|
|
|
log.warning("[loadCodeAssist] No project_id in response") |
|
|
return None |
|
|
else: |
|
|
log.info("[loadCodeAssist] User not activated yet (no currentTier)") |
|
|
return None |
|
|
else: |
|
|
log.warning(f"[loadCodeAssist] Failed: HTTP {response.status_code}") |
|
|
log.warning(f"[loadCodeAssist] Response body: {response.text[:500]}") |
|
|
raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") |
|
|
|
|
|
|
|
|
async def _try_onboard_user( |
|
|
api_base_url: str, |
|
|
headers: dict |
|
|
) -> Optional[str]: |
|
|
""" |
|
|
尝试通过 onboardUser 获取 project_id(长时间运行操作,需要轮询) |
|
|
|
|
|
Returns: |
|
|
project_id 或 None |
|
|
""" |
|
|
request_url = f"{api_base_url.rstrip('/')}/v1internal:onboardUser" |
|
|
|
|
|
|
|
|
tier_id = await _get_onboard_tier(api_base_url, headers) |
|
|
if not tier_id: |
|
|
log.error("[onboardUser] Failed to determine user tier") |
|
|
return None |
|
|
|
|
|
log.info(f"[onboardUser] User tier: {tier_id}") |
|
|
|
|
|
|
|
|
|
|
|
request_body = { |
|
|
"tierId": tier_id, |
|
|
"metadata": { |
|
|
"ideType": "ANTIGRAVITY", |
|
|
"platform": "PLATFORM_UNSPECIFIED", |
|
|
"pluginType": "GEMINI" |
|
|
} |
|
|
} |
|
|
|
|
|
log.debug(f"[onboardUser] Request URL: {request_url}") |
|
|
log.debug(f"[onboardUser] Request body: {request_body}") |
|
|
|
|
|
|
|
|
|
|
|
max_attempts = 5 |
|
|
attempt = 0 |
|
|
|
|
|
while attempt < max_attempts: |
|
|
attempt += 1 |
|
|
log.debug(f"[onboardUser] Polling attempt {attempt}/{max_attempts}") |
|
|
|
|
|
response = await post_async( |
|
|
request_url, |
|
|
json=request_body, |
|
|
headers=headers, |
|
|
timeout=30.0, |
|
|
) |
|
|
|
|
|
log.debug(f"[onboardUser] Response status: {response.status_code}") |
|
|
|
|
|
if response.status_code == 200: |
|
|
data = response.json() |
|
|
log.debug(f"[onboardUser] Response data: {data}") |
|
|
|
|
|
|
|
|
if data.get("done"): |
|
|
log.info("[onboardUser] Operation completed") |
|
|
|
|
|
|
|
|
response_data = data.get("response", {}) |
|
|
project_obj = response_data.get("cloudaicompanionProject", {}) |
|
|
|
|
|
if isinstance(project_obj, dict): |
|
|
project_id = project_obj.get("id") |
|
|
elif isinstance(project_obj, str): |
|
|
project_id = project_obj |
|
|
else: |
|
|
project_id = None |
|
|
|
|
|
if project_id: |
|
|
log.info(f"[onboardUser] Successfully fetched project_id: {project_id}") |
|
|
return project_id |
|
|
else: |
|
|
log.warning("[onboardUser] Operation completed but no project_id in response") |
|
|
return None |
|
|
else: |
|
|
log.debug("[onboardUser] Operation still in progress, waiting 2 seconds...") |
|
|
await asyncio.sleep(2) |
|
|
else: |
|
|
log.warning(f"[onboardUser] Failed: HTTP {response.status_code}") |
|
|
log.warning(f"[onboardUser] Response body: {response.text[:500]}") |
|
|
raise Exception(f"HTTP {response.status_code}: {response.text[:200]}") |
|
|
|
|
|
log.error("[onboardUser] Timeout: Operation did not complete within 10 seconds") |
|
|
return None |
|
|
|
|
|
|
|
|
async def _get_onboard_tier( |
|
|
api_base_url: str, |
|
|
headers: dict |
|
|
) -> Optional[str]: |
|
|
""" |
|
|
从 loadCodeAssist 响应中获取用户应该注册的 tier |
|
|
|
|
|
Returns: |
|
|
tier_id (如 "FREE", "STANDARD", "LEGACY") 或 None |
|
|
""" |
|
|
request_url = f"{api_base_url.rstrip('/')}/v1internal:loadCodeAssist" |
|
|
request_body = { |
|
|
"metadata": { |
|
|
"ideType": "ANTIGRAVITY", |
|
|
"platform": "PLATFORM_UNSPECIFIED", |
|
|
"pluginType": "GEMINI" |
|
|
} |
|
|
} |
|
|
|
|
|
log.debug(f"[_get_onboard_tier] Fetching tier info from: {request_url}") |
|
|
|
|
|
response = await post_async( |
|
|
request_url, |
|
|
json=request_body, |
|
|
headers=headers, |
|
|
timeout=30.0, |
|
|
) |
|
|
|
|
|
if response.status_code == 200: |
|
|
data = response.json() |
|
|
log.debug(f"[_get_onboard_tier] Response data: {data}") |
|
|
|
|
|
|
|
|
allowed_tiers = data.get("allowedTiers", []) |
|
|
for tier in allowed_tiers: |
|
|
if tier.get("isDefault"): |
|
|
tier_id = tier.get("id") |
|
|
log.info(f"[_get_onboard_tier] Found default tier: {tier_id}") |
|
|
return tier_id |
|
|
|
|
|
|
|
|
log.warning("[_get_onboard_tier] No default tier found, using LEGACY") |
|
|
return "LEGACY" |
|
|
else: |
|
|
log.error(f"[_get_onboard_tier] Failed to fetch tier info: HTTP {response.status_code}") |
|
|
return None |
|
|
|
|
|
|
|
|
|