gcli2api / src /credential_manager.py
lightspeed's picture
Upload 22 files
5868187 verified
"""
凭证管理器 - 完全基于统一存储中间层
"""
import asyncio
import time
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional, Tuple
from contextlib import asynccontextmanager
from config import get_calls_per_rotation, is_mongodb_mode
from log import log
from .storage_adapter import get_storage_adapter
from .google_oauth_api import fetch_user_email_from_file, Credentials
from .task_manager import task_manager
class CredentialManager:
"""
统一凭证管理器
所有存储操作通过storage_adapter进行
"""
def __init__(self):
# 核心状态
self._initialized = False
self._storage_adapter = None
# 凭证轮换相关
self._credential_files: List[str] = [] # 存储凭证文件名列表
self._current_credential_index = 0
self._call_count = 0
self._last_scan_time = 0
# 当前使用的凭证信息
self._current_credential_file: Optional[str] = None
self._current_credential_data: Optional[Dict[str, Any]] = None
self._current_credential_state: Dict[str, Any] = {}
# 并发控制
self._state_lock = asyncio.Lock()
self._operation_lock = asyncio.Lock()
# 工作线程控制
self._shutdown_event = asyncio.Event()
self._write_worker_running = False
self._write_worker_task = None
# 原子操作计数器
self._atomic_counter = 0
self._atomic_lock = asyncio.Lock()
# Onboarding state
self._onboarding_complete = False
self._onboarding_checked = False
async def initialize(self):
"""初始化凭证管理器"""
async with self._state_lock:
if self._initialized:
return
# 初始化统一存储适配器
self._storage_adapter = await get_storage_adapter()
# 启动后台工作线程
await self._start_background_workers()
# 发现并加载凭证
await self._discover_credentials()
self._initialized = True
storage_type = "MongoDB" if await is_mongodb_mode() else "File"
log.debug(f"Credential manager initialized with {storage_type} storage backend")
async def close(self):
"""清理资源"""
log.debug("Closing credential manager...")
# 设置关闭标志
self._shutdown_event.set()
# 等待后台任务结束
if self._write_worker_task:
try:
await asyncio.wait_for(self._write_worker_task, timeout=5.0)
except asyncio.TimeoutError:
log.warning("Write worker task did not finish within timeout")
if not self._write_worker_task.done():
self._write_worker_task.cancel()
self._initialized = False
log.debug("Credential manager closed")
async def _start_background_workers(self):
"""启动后台工作线程"""
if not self._write_worker_running:
self._write_worker_running = True
self._write_worker_task = task_manager.create_task(
self._background_worker(),
name="credential_background_worker"
)
async def _background_worker(self):
"""后台工作线程,处理定期任务"""
while not self._shutdown_event.is_set():
try:
# 每60秒检查一次凭证更新
await asyncio.wait_for(self._shutdown_event.wait(), timeout=60.0)
if self._shutdown_event.is_set():
break
# 重新发现凭证(热更新)
await self._discover_credentials()
except asyncio.TimeoutError:
# 超时是正常的,继续下一轮
continue
except Exception as e:
log.error(f"Background worker error: {e}")
await asyncio.sleep(5) # 错误后等待5秒再继续
async def _discover_credentials(self):
"""发现和加载所有可用凭证"""
try:
# 从存储适配器获取所有凭证
all_credentials = await self._storage_adapter.list_credentials()
# 过滤出可用的凭证(排除被禁用的)- 批量读取状态以提升性能
available_credentials = []
# 批量获取所有凭证状态,避免多次读取状态文件
if all_credentials:
try:
all_states = await self._storage_adapter.get_all_credential_states()
for credential_name in all_credentials:
normalized_name = credential_name
# 标准化文件名以匹配状态数据中的键
if hasattr(self._storage_adapter._backend, '_normalize_filename'):
normalized_name = self._storage_adapter._backend._normalize_filename(credential_name)
state = all_states.get(normalized_name, {})
if not state.get("disabled", False):
available_credentials.append(credential_name)
except Exception as e:
log.warning(f"Failed to batch load credential states, falling back to individual checks: {e}")
# 如果批量读取失败,回退到逐个检查
for credential_name in all_credentials:
try:
state = await self._storage_adapter.get_credential_state(credential_name)
if not state.get("disabled", False):
available_credentials.append(credential_name)
except Exception as e2:
log.warning(f"Failed to check state for credential {credential_name}: {e2}")
# 更新凭证列表
old_credentials = set(self._credential_files)
new_credentials = set(available_credentials)
if old_credentials != new_credentials:
# 记录变化(只在非初始状态时记录)
is_initial_load = len(old_credentials) == 0
added = new_credentials - old_credentials
removed = old_credentials - new_credentials
self._credential_files = available_credentials
# 初始加载时只记录调试信息,运行时变化才记录INFO
if not is_initial_load:
if added:
log.info(f"发现新的可用凭证: {list(added)}")
if removed:
log.info(f"移除不可用凭证: {list(removed)}")
else:
# 初始加载时只记录调试信息
if available_credentials:
log.debug(f"初始加载发现 {len(available_credentials)} 个可用凭证")
# 重置当前索引如果需要
if self._current_credential_index >= len(self._credential_files):
self._current_credential_index = 0
if not self._credential_files:
log.warning("No available credential files found")
else:
log.debug(f"Available credentials: {len(self._credential_files)} files")
except Exception as e:
log.error(f"Failed to discover credentials: {e}")
async def _load_current_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
"""加载当前选中的凭证数据,包含token过期检测和自动刷新"""
if not self._credential_files:
return None
try:
current_file = self._credential_files[self._current_credential_index]
# 从存储适配器加载凭证数据
credential_data = await self._storage_adapter.get_credential(current_file)
if not credential_data:
log.error(f"Failed to load credential data for: {current_file}")
return None
# 检查refresh_token
if "refresh_token" not in credential_data or not credential_data["refresh_token"]:
log.warning(f"No refresh token in {current_file}")
return None
# Auto-add 'type' field if missing but has required OAuth fields
if 'type' not in credential_data and all(key in credential_data for key in ['client_id', 'refresh_token']):
credential_data['type'] = 'authorized_user'
log.debug(f"Auto-added 'type' field to credential from file {current_file}")
# 兼容不同的token字段格式
if "access_token" in credential_data and "token" not in credential_data:
credential_data["token"] = credential_data["access_token"]
if "scope" in credential_data and "scopes" not in credential_data:
credential_data["scopes"] = credential_data["scope"].split()
# token过期检测和刷新
should_refresh = await self._should_refresh_token(credential_data)
if should_refresh:
log.debug(f"Token需要刷新 - 文件: {current_file}")
refreshed_data = await self._refresh_token(credential_data, current_file)
if refreshed_data:
credential_data = refreshed_data
log.debug(f"Token刷新成功: {current_file}")
else:
log.error(f"Token刷新失败: {current_file}")
return None
# 加载状态信息
state_data = await self._storage_adapter.get_credential_state(current_file)
# 缓存当前凭证信息
self._current_credential_file = current_file
self._current_credential_data = credential_data
self._current_credential_state = state_data
return current_file, credential_data
except Exception as e:
log.error(f"Error loading current credential: {e}")
return None
async def get_valid_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]:
"""获取有效的凭证,自动处理轮换和失效凭证切换"""
async with self._operation_lock:
if not self._credential_files:
await self._discover_credentials()
if not self._credential_files:
return None
# 检查是否需要轮换
if await self._should_rotate():
await self._rotate_credential()
# 尝试获取有效凭证,如果失败则自动切换
max_attempts = len(self._credential_files) # 最多尝试所有凭证
for attempt in range(max_attempts):
try:
# 加载当前凭证
result = await self._load_current_credential()
if result:
return result
# 当前凭证加载失败,标记为失效并切换到下一个
current_file = self._credential_files[self._current_credential_index] if self._credential_files else None
if current_file:
log.warning(f"凭证失效,自动禁用并切换: {current_file}")
await self.set_cred_disabled(current_file, True)
# 重新发现可用凭证(排除刚禁用的)
await self._discover_credentials()
if not self._credential_files:
log.error("没有可用的凭证")
return None
# 重置索引到第一个可用凭证
self._current_credential_index = 0
log.info(f"切换到下一个可用凭证 (索引: {self._current_credential_index})")
else:
log.error("无法获取当前凭证文件名")
break
except Exception as e:
log.error(f"获取凭证时发生异常 (尝试 {attempt + 1}/{max_attempts}): {e}")
if attempt < max_attempts - 1:
# 切换到下一个凭证继续尝试
await self._rotate_credential()
continue
log.error(f"所有 {max_attempts} 个凭证都尝试失败")
return None
async def _should_rotate(self) -> bool:
"""检查是否需要轮换凭证"""
if not self._credential_files or len(self._credential_files) <= 1:
return False
current_calls_per_rotation = await get_calls_per_rotation()
return self._call_count >= current_calls_per_rotation
async def _rotate_credential(self):
"""轮换到下一个凭证"""
if len(self._credential_files) <= 1:
return
self._current_credential_index = (self._current_credential_index + 1) % len(self._credential_files)
self._call_count = 0
log.info(f"Rotated to credential index {self._current_credential_index}")
async def force_rotate_credential(self):
"""强制轮换到下一个凭证(用于429错误处理)"""
async with self._operation_lock:
if len(self._credential_files) <= 1:
log.warning("Only one credential available, cannot rotate")
return
await self._rotate_credential()
log.info("Forced credential rotation due to rate limit")
def increment_call_count(self):
"""增加调用计数"""
self._call_count += 1
async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any]):
"""更新凭证状态"""
try:
# 直接通过存储适配器更新状态
success = await self._storage_adapter.update_credential_state(credential_name, state_updates)
# 如果是当前使用的凭证,更新缓存
if credential_name == self._current_credential_file:
self._current_credential_state.update(state_updates)
if success:
log.debug(f"Updated credential state: {credential_name}")
else:
log.warning(f"Failed to update credential state: {credential_name}")
return success
except Exception as e:
log.error(f"Error updating credential state {credential_name}: {e}")
return False
async def set_cred_disabled(self, credential_name: str, disabled: bool):
"""设置凭证的启用/禁用状态"""
try:
state_updates = {"disabled": disabled}
success = await self.update_credential_state(credential_name, state_updates)
if success:
# 如果禁用了当前正在使用的凭证,需要重新发现可用凭证
if disabled and credential_name == self._current_credential_file:
await self._discover_credentials()
if self._credential_files:
await self._rotate_credential()
action = "disabled" if disabled else "enabled"
log.info(f"Credential {action}: {credential_name}")
return success
except Exception as e:
log.error(f"Error setting credential disabled state {credential_name}: {e}")
return False
async def get_creds_status(self) -> Dict[str, Dict[str, Any]]:
"""获取所有凭证的状态"""
try:
# 从存储适配器获取所有状态
all_states = await self._storage_adapter.get_all_credential_states()
return all_states
except Exception as e:
log.error(f"Error getting credential statuses: {e}")
return {}
async def get_or_fetch_user_email(self, credential_name: str) -> Optional[str]:
"""获取或获取用户邮箱地址"""
try:
# 首先检查缓存的状态
state = await self._storage_adapter.get_credential_state(credential_name)
cached_email = state.get("user_email")
if cached_email:
return cached_email
# 如果没有缓存,从凭证数据获取
credential_data = await self._storage_adapter.get_credential(credential_name)
if not credential_data:
return None
# 尝试获取邮箱
email = await fetch_user_email_from_file(credential_data)
if email:
# 缓存邮箱地址
await self.update_credential_state(credential_name, {"user_email": email})
return email
return None
except Exception as e:
log.error(f"Error fetching user email for {credential_name}: {e}")
return None
async def record_api_call_result(self, credential_name: str, success: bool, error_code: Optional[int] = None):
"""记录API调用结果"""
try:
state_updates = {}
if success:
state_updates["last_success"] = time.time()
# 清除错误码(如果之前有的话)
state_updates["error_codes"] = []
elif error_code:
# 记录错误码
current_state = await self._storage_adapter.get_credential_state(credential_name)
error_codes = current_state.get("error_codes", [])
if error_code not in error_codes:
error_codes.append(error_code)
# 限制错误码列表长度
if len(error_codes) > 10:
error_codes = error_codes[-10:]
state_updates["error_codes"] = error_codes
if state_updates:
await self.update_credential_state(credential_name, state_updates)
except Exception as e:
log.error(f"Error recording API call result for {credential_name}: {e}")
# 原子操作支持
@asynccontextmanager
async def _atomic_operation(self, operation_name: str):
"""原子操作上下文管理器"""
async with self._atomic_lock:
self._atomic_counter += 1
operation_id = self._atomic_counter
log.debug(f"开始原子操作[{operation_id}]: {operation_name}")
try:
yield operation_id
log.debug(f"完成原子操作[{operation_id}]: {operation_name}")
except Exception as e:
log.error(f"原子操作[{operation_id}]失败: {operation_name} - {e}")
raise
async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool:
"""检查token是否需要刷新"""
try:
# 如果没有access_token或过期时间,需要刷新
if not credential_data.get("access_token") and not credential_data.get("token"):
log.debug("没有access_token,需要刷新")
return True
expiry_str = credential_data.get("expiry")
if not expiry_str:
log.debug("没有过期时间,需要刷新")
return True
# 解析过期时间
try:
if isinstance(expiry_str, str):
if "+" in expiry_str:
file_expiry = datetime.fromisoformat(expiry_str)
elif expiry_str.endswith("Z"):
file_expiry = datetime.fromisoformat(expiry_str.replace('Z', '+00:00'))
else:
file_expiry = datetime.fromisoformat(expiry_str)
else:
log.debug("过期时间格式无效,需要刷新")
return True
# 确保时区信息
if file_expiry.tzinfo is None:
file_expiry = file_expiry.replace(tzinfo=timezone.utc)
# 检查是否还有至少5分钟有效期
now = datetime.now(timezone.utc)
time_left = (file_expiry - now).total_seconds()
log.debug(f"Token剩余时间: {int(time_left/60)}分钟")
if time_left > 300: # 5分钟缓冲
return False
else:
log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新")
return True
except Exception as e:
log.warning(f"解析过期时间失败: {e},需要刷新")
return True
except Exception as e:
log.error(f"检查token过期时出错: {e}")
return True
async def _refresh_token(self, credential_data: Dict[str, Any], filename: str) -> Optional[Dict[str, Any]]:
"""刷新token并更新存储"""
try:
# 创建Credentials对象
creds = Credentials.from_dict(credential_data)
# 检查是否可以刷新
if not creds.refresh_token:
log.error(f"没有refresh_token,无法刷新: {filename}")
return None
# 刷新token
log.debug(f"正在刷新token: {filename}")
await creds.refresh()
# 更新凭证数据
if creds.access_token:
credential_data["access_token"] = creds.access_token
# 保持兼容性
credential_data["token"] = creds.access_token
if creds.expires_at:
credential_data["expiry"] = creds.expires_at.isoformat()
# 保存到存储
await self._storage_adapter.store_credential(filename, credential_data)
log.info(f"Token刷新成功并已保存: {filename}")
return credential_data
except Exception as e:
error_msg = str(e)
log.error(f"Token刷新失败 {filename}: {error_msg}")
# 检查是否是凭证永久失效的错误
is_permanent_failure = self._is_permanent_refresh_failure(error_msg)
if is_permanent_failure:
log.warning(f"检测到凭证永久失效: {filename}")
# 记录失效状态,但不在这里禁用凭证,让上层调用者处理
await self.record_api_call_result(filename, False, 400)
return None
def _is_permanent_refresh_failure(self, error_msg: str) -> bool:
"""判断是否是凭证永久失效的错误"""
# 常见的永久失效错误模式
permanent_error_patterns = [
"400 Bad Request",
"invalid_grant",
"refresh_token_expired",
"invalid_refresh_token",
"unauthorized_client",
"access_denied"
]
error_msg_lower = error_msg.lower()
for pattern in permanent_error_patterns:
if pattern.lower() in error_msg_lower:
return True
return False
# 兼容性方法 - 保持与现有代码的接口兼容
async def _update_token_in_file(self, file_path: str, new_token: str, expires_at=None):
"""更新凭证令牌(兼容性方法)"""
try:
credential_data = await self._storage_adapter.get_credential(file_path)
if not credential_data:
log.error(f"Credential not found for token update: {file_path}")
return False
# 更新令牌数据
credential_data["token"] = new_token
if expires_at:
credential_data["expiry"] = expires_at.isoformat() if hasattr(expires_at, 'isoformat') else expires_at
# 保存更新后的凭证
success = await self._storage_adapter.store_credential(file_path, credential_data)
if success:
log.debug(f"Token updated for credential: {file_path}")
else:
log.error(f"Failed to update token for credential: {file_path}")
return success
except Exception as e:
log.error(f"Error updating token for {file_path}: {e}")
return False
# 全局实例管理(保持兼容性)
_credential_manager: Optional[CredentialManager] = None
async def get_credential_manager() -> CredentialManager:
"""获取全局凭证管理器实例"""
global _credential_manager
if _credential_manager is None:
_credential_manager = CredentialManager()
await _credential_manager.initialize()
return _credential_manager