Spaces:
Sleeping
Sleeping
| """ | |
| 凭证管理器 - 完全基于统一存储中间层 | |
| """ | |
| import asyncio | |
| import time | |
| from datetime import datetime, timezone | |
| from typing import Dict, Any, List, Optional, Tuple | |
| from contextlib import asynccontextmanager | |
| from config import get_calls_per_rotation, is_mongodb_mode | |
| from log import log | |
| from .storage_adapter import get_storage_adapter | |
| from .google_oauth_api import fetch_user_email_from_file, Credentials | |
| from .task_manager import task_manager | |
| class CredentialManager: | |
| """ | |
| 统一凭证管理器 | |
| 所有存储操作通过storage_adapter进行 | |
| """ | |
| def __init__(self): | |
| # 核心状态 | |
| self._initialized = False | |
| self._storage_adapter = None | |
| # 凭证轮换相关 | |
| self._credential_files: List[str] = [] # 存储凭证文件名列表 | |
| self._current_credential_index = 0 | |
| self._call_count = 0 | |
| self._last_scan_time = 0 | |
| # 当前使用的凭证信息 | |
| self._current_credential_file: Optional[str] = None | |
| self._current_credential_data: Optional[Dict[str, Any]] = None | |
| self._current_credential_state: Dict[str, Any] = {} | |
| # 并发控制 | |
| self._state_lock = asyncio.Lock() | |
| self._operation_lock = asyncio.Lock() | |
| # 工作线程控制 | |
| self._shutdown_event = asyncio.Event() | |
| self._write_worker_running = False | |
| self._write_worker_task = None | |
| # 原子操作计数器 | |
| self._atomic_counter = 0 | |
| self._atomic_lock = asyncio.Lock() | |
| # Onboarding state | |
| self._onboarding_complete = False | |
| self._onboarding_checked = False | |
| async def initialize(self): | |
| """初始化凭证管理器""" | |
| async with self._state_lock: | |
| if self._initialized: | |
| return | |
| # 初始化统一存储适配器 | |
| self._storage_adapter = await get_storage_adapter() | |
| # 启动后台工作线程 | |
| await self._start_background_workers() | |
| # 发现并加载凭证 | |
| await self._discover_credentials() | |
| self._initialized = True | |
| storage_type = "MongoDB" if await is_mongodb_mode() else "File" | |
| log.debug(f"Credential manager initialized with {storage_type} storage backend") | |
| async def close(self): | |
| """清理资源""" | |
| log.debug("Closing credential manager...") | |
| # 设置关闭标志 | |
| self._shutdown_event.set() | |
| # 等待后台任务结束 | |
| if self._write_worker_task: | |
| try: | |
| await asyncio.wait_for(self._write_worker_task, timeout=5.0) | |
| except asyncio.TimeoutError: | |
| log.warning("Write worker task did not finish within timeout") | |
| if not self._write_worker_task.done(): | |
| self._write_worker_task.cancel() | |
| self._initialized = False | |
| log.debug("Credential manager closed") | |
| async def _start_background_workers(self): | |
| """启动后台工作线程""" | |
| if not self._write_worker_running: | |
| self._write_worker_running = True | |
| self._write_worker_task = task_manager.create_task( | |
| self._background_worker(), | |
| name="credential_background_worker" | |
| ) | |
| async def _background_worker(self): | |
| """后台工作线程,处理定期任务""" | |
| while not self._shutdown_event.is_set(): | |
| try: | |
| # 每60秒检查一次凭证更新 | |
| await asyncio.wait_for(self._shutdown_event.wait(), timeout=60.0) | |
| if self._shutdown_event.is_set(): | |
| break | |
| # 重新发现凭证(热更新) | |
| await self._discover_credentials() | |
| except asyncio.TimeoutError: | |
| # 超时是正常的,继续下一轮 | |
| continue | |
| except Exception as e: | |
| log.error(f"Background worker error: {e}") | |
| await asyncio.sleep(5) # 错误后等待5秒再继续 | |
| async def _discover_credentials(self): | |
| """发现和加载所有可用凭证""" | |
| try: | |
| # 从存储适配器获取所有凭证 | |
| all_credentials = await self._storage_adapter.list_credentials() | |
| # 过滤出可用的凭证(排除被禁用的)- 批量读取状态以提升性能 | |
| available_credentials = [] | |
| # 批量获取所有凭证状态,避免多次读取状态文件 | |
| if all_credentials: | |
| try: | |
| all_states = await self._storage_adapter.get_all_credential_states() | |
| for credential_name in all_credentials: | |
| normalized_name = credential_name | |
| # 标准化文件名以匹配状态数据中的键 | |
| if hasattr(self._storage_adapter._backend, '_normalize_filename'): | |
| normalized_name = self._storage_adapter._backend._normalize_filename(credential_name) | |
| state = all_states.get(normalized_name, {}) | |
| if not state.get("disabled", False): | |
| available_credentials.append(credential_name) | |
| except Exception as e: | |
| log.warning(f"Failed to batch load credential states, falling back to individual checks: {e}") | |
| # 如果批量读取失败,回退到逐个检查 | |
| for credential_name in all_credentials: | |
| try: | |
| state = await self._storage_adapter.get_credential_state(credential_name) | |
| if not state.get("disabled", False): | |
| available_credentials.append(credential_name) | |
| except Exception as e2: | |
| log.warning(f"Failed to check state for credential {credential_name}: {e2}") | |
| # 更新凭证列表 | |
| old_credentials = set(self._credential_files) | |
| new_credentials = set(available_credentials) | |
| if old_credentials != new_credentials: | |
| # 记录变化(只在非初始状态时记录) | |
| is_initial_load = len(old_credentials) == 0 | |
| added = new_credentials - old_credentials | |
| removed = old_credentials - new_credentials | |
| self._credential_files = available_credentials | |
| # 初始加载时只记录调试信息,运行时变化才记录INFO | |
| if not is_initial_load: | |
| if added: | |
| log.info(f"发现新的可用凭证: {list(added)}") | |
| if removed: | |
| log.info(f"移除不可用凭证: {list(removed)}") | |
| else: | |
| # 初始加载时只记录调试信息 | |
| if available_credentials: | |
| log.debug(f"初始加载发现 {len(available_credentials)} 个可用凭证") | |
| # 重置当前索引如果需要 | |
| if self._current_credential_index >= len(self._credential_files): | |
| self._current_credential_index = 0 | |
| if not self._credential_files: | |
| log.warning("No available credential files found") | |
| else: | |
| log.debug(f"Available credentials: {len(self._credential_files)} files") | |
| except Exception as e: | |
| log.error(f"Failed to discover credentials: {e}") | |
| async def _load_current_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]: | |
| """加载当前选中的凭证数据,包含token过期检测和自动刷新""" | |
| if not self._credential_files: | |
| return None | |
| try: | |
| current_file = self._credential_files[self._current_credential_index] | |
| # 从存储适配器加载凭证数据 | |
| credential_data = await self._storage_adapter.get_credential(current_file) | |
| if not credential_data: | |
| log.error(f"Failed to load credential data for: {current_file}") | |
| return None | |
| # 检查refresh_token | |
| if "refresh_token" not in credential_data or not credential_data["refresh_token"]: | |
| log.warning(f"No refresh token in {current_file}") | |
| return None | |
| # Auto-add 'type' field if missing but has required OAuth fields | |
| if 'type' not in credential_data and all(key in credential_data for key in ['client_id', 'refresh_token']): | |
| credential_data['type'] = 'authorized_user' | |
| log.debug(f"Auto-added 'type' field to credential from file {current_file}") | |
| # 兼容不同的token字段格式 | |
| if "access_token" in credential_data and "token" not in credential_data: | |
| credential_data["token"] = credential_data["access_token"] | |
| if "scope" in credential_data and "scopes" not in credential_data: | |
| credential_data["scopes"] = credential_data["scope"].split() | |
| # token过期检测和刷新 | |
| should_refresh = await self._should_refresh_token(credential_data) | |
| if should_refresh: | |
| log.debug(f"Token需要刷新 - 文件: {current_file}") | |
| refreshed_data = await self._refresh_token(credential_data, current_file) | |
| if refreshed_data: | |
| credential_data = refreshed_data | |
| log.debug(f"Token刷新成功: {current_file}") | |
| else: | |
| log.error(f"Token刷新失败: {current_file}") | |
| return None | |
| # 加载状态信息 | |
| state_data = await self._storage_adapter.get_credential_state(current_file) | |
| # 缓存当前凭证信息 | |
| self._current_credential_file = current_file | |
| self._current_credential_data = credential_data | |
| self._current_credential_state = state_data | |
| return current_file, credential_data | |
| except Exception as e: | |
| log.error(f"Error loading current credential: {e}") | |
| return None | |
| async def get_valid_credential(self) -> Optional[Tuple[str, Dict[str, Any]]]: | |
| """获取有效的凭证,自动处理轮换和失效凭证切换""" | |
| async with self._operation_lock: | |
| if not self._credential_files: | |
| await self._discover_credentials() | |
| if not self._credential_files: | |
| return None | |
| # 检查是否需要轮换 | |
| if await self._should_rotate(): | |
| await self._rotate_credential() | |
| # 尝试获取有效凭证,如果失败则自动切换 | |
| max_attempts = len(self._credential_files) # 最多尝试所有凭证 | |
| for attempt in range(max_attempts): | |
| try: | |
| # 加载当前凭证 | |
| result = await self._load_current_credential() | |
| if result: | |
| return result | |
| # 当前凭证加载失败,标记为失效并切换到下一个 | |
| current_file = self._credential_files[self._current_credential_index] if self._credential_files else None | |
| if current_file: | |
| log.warning(f"凭证失效,自动禁用并切换: {current_file}") | |
| await self.set_cred_disabled(current_file, True) | |
| # 重新发现可用凭证(排除刚禁用的) | |
| await self._discover_credentials() | |
| if not self._credential_files: | |
| log.error("没有可用的凭证") | |
| return None | |
| # 重置索引到第一个可用凭证 | |
| self._current_credential_index = 0 | |
| log.info(f"切换到下一个可用凭证 (索引: {self._current_credential_index})") | |
| else: | |
| log.error("无法获取当前凭证文件名") | |
| break | |
| except Exception as e: | |
| log.error(f"获取凭证时发生异常 (尝试 {attempt + 1}/{max_attempts}): {e}") | |
| if attempt < max_attempts - 1: | |
| # 切换到下一个凭证继续尝试 | |
| await self._rotate_credential() | |
| continue | |
| log.error(f"所有 {max_attempts} 个凭证都尝试失败") | |
| return None | |
| async def _should_rotate(self) -> bool: | |
| """检查是否需要轮换凭证""" | |
| if not self._credential_files or len(self._credential_files) <= 1: | |
| return False | |
| current_calls_per_rotation = await get_calls_per_rotation() | |
| return self._call_count >= current_calls_per_rotation | |
| async def _rotate_credential(self): | |
| """轮换到下一个凭证""" | |
| if len(self._credential_files) <= 1: | |
| return | |
| self._current_credential_index = (self._current_credential_index + 1) % len(self._credential_files) | |
| self._call_count = 0 | |
| log.info(f"Rotated to credential index {self._current_credential_index}") | |
| async def force_rotate_credential(self): | |
| """强制轮换到下一个凭证(用于429错误处理)""" | |
| async with self._operation_lock: | |
| if len(self._credential_files) <= 1: | |
| log.warning("Only one credential available, cannot rotate") | |
| return | |
| await self._rotate_credential() | |
| log.info("Forced credential rotation due to rate limit") | |
| def increment_call_count(self): | |
| """增加调用计数""" | |
| self._call_count += 1 | |
| async def update_credential_state(self, credential_name: str, state_updates: Dict[str, Any]): | |
| """更新凭证状态""" | |
| try: | |
| # 直接通过存储适配器更新状态 | |
| success = await self._storage_adapter.update_credential_state(credential_name, state_updates) | |
| # 如果是当前使用的凭证,更新缓存 | |
| if credential_name == self._current_credential_file: | |
| self._current_credential_state.update(state_updates) | |
| if success: | |
| log.debug(f"Updated credential state: {credential_name}") | |
| else: | |
| log.warning(f"Failed to update credential state: {credential_name}") | |
| return success | |
| except Exception as e: | |
| log.error(f"Error updating credential state {credential_name}: {e}") | |
| return False | |
| async def set_cred_disabled(self, credential_name: str, disabled: bool): | |
| """设置凭证的启用/禁用状态""" | |
| try: | |
| state_updates = {"disabled": disabled} | |
| success = await self.update_credential_state(credential_name, state_updates) | |
| if success: | |
| # 如果禁用了当前正在使用的凭证,需要重新发现可用凭证 | |
| if disabled and credential_name == self._current_credential_file: | |
| await self._discover_credentials() | |
| if self._credential_files: | |
| await self._rotate_credential() | |
| action = "disabled" if disabled else "enabled" | |
| log.info(f"Credential {action}: {credential_name}") | |
| return success | |
| except Exception as e: | |
| log.error(f"Error setting credential disabled state {credential_name}: {e}") | |
| return False | |
| async def get_creds_status(self) -> Dict[str, Dict[str, Any]]: | |
| """获取所有凭证的状态""" | |
| try: | |
| # 从存储适配器获取所有状态 | |
| all_states = await self._storage_adapter.get_all_credential_states() | |
| return all_states | |
| except Exception as e: | |
| log.error(f"Error getting credential statuses: {e}") | |
| return {} | |
| async def get_or_fetch_user_email(self, credential_name: str) -> Optional[str]: | |
| """获取或获取用户邮箱地址""" | |
| try: | |
| # 首先检查缓存的状态 | |
| state = await self._storage_adapter.get_credential_state(credential_name) | |
| cached_email = state.get("user_email") | |
| if cached_email: | |
| return cached_email | |
| # 如果没有缓存,从凭证数据获取 | |
| credential_data = await self._storage_adapter.get_credential(credential_name) | |
| if not credential_data: | |
| return None | |
| # 尝试获取邮箱 | |
| email = await fetch_user_email_from_file(credential_data) | |
| if email: | |
| # 缓存邮箱地址 | |
| await self.update_credential_state(credential_name, {"user_email": email}) | |
| return email | |
| return None | |
| except Exception as e: | |
| log.error(f"Error fetching user email for {credential_name}: {e}") | |
| return None | |
| async def record_api_call_result(self, credential_name: str, success: bool, error_code: Optional[int] = None): | |
| """记录API调用结果""" | |
| try: | |
| state_updates = {} | |
| if success: | |
| state_updates["last_success"] = time.time() | |
| # 清除错误码(如果之前有的话) | |
| state_updates["error_codes"] = [] | |
| elif error_code: | |
| # 记录错误码 | |
| current_state = await self._storage_adapter.get_credential_state(credential_name) | |
| error_codes = current_state.get("error_codes", []) | |
| if error_code not in error_codes: | |
| error_codes.append(error_code) | |
| # 限制错误码列表长度 | |
| if len(error_codes) > 10: | |
| error_codes = error_codes[-10:] | |
| state_updates["error_codes"] = error_codes | |
| if state_updates: | |
| await self.update_credential_state(credential_name, state_updates) | |
| except Exception as e: | |
| log.error(f"Error recording API call result for {credential_name}: {e}") | |
| # 原子操作支持 | |
| async def _atomic_operation(self, operation_name: str): | |
| """原子操作上下文管理器""" | |
| async with self._atomic_lock: | |
| self._atomic_counter += 1 | |
| operation_id = self._atomic_counter | |
| log.debug(f"开始原子操作[{operation_id}]: {operation_name}") | |
| try: | |
| yield operation_id | |
| log.debug(f"完成原子操作[{operation_id}]: {operation_name}") | |
| except Exception as e: | |
| log.error(f"原子操作[{operation_id}]失败: {operation_name} - {e}") | |
| raise | |
| async def _should_refresh_token(self, credential_data: Dict[str, Any]) -> bool: | |
| """检查token是否需要刷新""" | |
| try: | |
| # 如果没有access_token或过期时间,需要刷新 | |
| if not credential_data.get("access_token") and not credential_data.get("token"): | |
| log.debug("没有access_token,需要刷新") | |
| return True | |
| expiry_str = credential_data.get("expiry") | |
| if not expiry_str: | |
| log.debug("没有过期时间,需要刷新") | |
| return True | |
| # 解析过期时间 | |
| try: | |
| if isinstance(expiry_str, str): | |
| if "+" in expiry_str: | |
| file_expiry = datetime.fromisoformat(expiry_str) | |
| elif expiry_str.endswith("Z"): | |
| file_expiry = datetime.fromisoformat(expiry_str.replace('Z', '+00:00')) | |
| else: | |
| file_expiry = datetime.fromisoformat(expiry_str) | |
| else: | |
| log.debug("过期时间格式无效,需要刷新") | |
| return True | |
| # 确保时区信息 | |
| if file_expiry.tzinfo is None: | |
| file_expiry = file_expiry.replace(tzinfo=timezone.utc) | |
| # 检查是否还有至少5分钟有效期 | |
| now = datetime.now(timezone.utc) | |
| time_left = (file_expiry - now).total_seconds() | |
| log.debug(f"Token剩余时间: {int(time_left/60)}分钟") | |
| if time_left > 300: # 5分钟缓冲 | |
| return False | |
| else: | |
| log.debug(f"Token即将过期(剩余{int(time_left/60)}分钟),需要刷新") | |
| return True | |
| except Exception as e: | |
| log.warning(f"解析过期时间失败: {e},需要刷新") | |
| return True | |
| except Exception as e: | |
| log.error(f"检查token过期时出错: {e}") | |
| return True | |
| async def _refresh_token(self, credential_data: Dict[str, Any], filename: str) -> Optional[Dict[str, Any]]: | |
| """刷新token并更新存储""" | |
| try: | |
| # 创建Credentials对象 | |
| creds = Credentials.from_dict(credential_data) | |
| # 检查是否可以刷新 | |
| if not creds.refresh_token: | |
| log.error(f"没有refresh_token,无法刷新: {filename}") | |
| return None | |
| # 刷新token | |
| log.debug(f"正在刷新token: {filename}") | |
| await creds.refresh() | |
| # 更新凭证数据 | |
| if creds.access_token: | |
| credential_data["access_token"] = creds.access_token | |
| # 保持兼容性 | |
| credential_data["token"] = creds.access_token | |
| if creds.expires_at: | |
| credential_data["expiry"] = creds.expires_at.isoformat() | |
| # 保存到存储 | |
| await self._storage_adapter.store_credential(filename, credential_data) | |
| log.info(f"Token刷新成功并已保存: {filename}") | |
| return credential_data | |
| except Exception as e: | |
| error_msg = str(e) | |
| log.error(f"Token刷新失败 {filename}: {error_msg}") | |
| # 检查是否是凭证永久失效的错误 | |
| is_permanent_failure = self._is_permanent_refresh_failure(error_msg) | |
| if is_permanent_failure: | |
| log.warning(f"检测到凭证永久失效: {filename}") | |
| # 记录失效状态,但不在这里禁用凭证,让上层调用者处理 | |
| await self.record_api_call_result(filename, False, 400) | |
| return None | |
| def _is_permanent_refresh_failure(self, error_msg: str) -> bool: | |
| """判断是否是凭证永久失效的错误""" | |
| # 常见的永久失效错误模式 | |
| permanent_error_patterns = [ | |
| "400 Bad Request", | |
| "invalid_grant", | |
| "refresh_token_expired", | |
| "invalid_refresh_token", | |
| "unauthorized_client", | |
| "access_denied" | |
| ] | |
| error_msg_lower = error_msg.lower() | |
| for pattern in permanent_error_patterns: | |
| if pattern.lower() in error_msg_lower: | |
| return True | |
| return False | |
| # 兼容性方法 - 保持与现有代码的接口兼容 | |
| async def _update_token_in_file(self, file_path: str, new_token: str, expires_at=None): | |
| """更新凭证令牌(兼容性方法)""" | |
| try: | |
| credential_data = await self._storage_adapter.get_credential(file_path) | |
| if not credential_data: | |
| log.error(f"Credential not found for token update: {file_path}") | |
| return False | |
| # 更新令牌数据 | |
| credential_data["token"] = new_token | |
| if expires_at: | |
| credential_data["expiry"] = expires_at.isoformat() if hasattr(expires_at, 'isoformat') else expires_at | |
| # 保存更新后的凭证 | |
| success = await self._storage_adapter.store_credential(file_path, credential_data) | |
| if success: | |
| log.debug(f"Token updated for credential: {file_path}") | |
| else: | |
| log.error(f"Failed to update token for credential: {file_path}") | |
| return success | |
| except Exception as e: | |
| log.error(f"Error updating token for {file_path}: {e}") | |
| return False | |
| # 全局实例管理(保持兼容性) | |
| _credential_manager: Optional[CredentialManager] = None | |
| async def get_credential_manager() -> CredentialManager: | |
| """获取全局凭证管理器实例""" | |
| global _credential_manager | |
| if _credential_manager is None: | |
| _credential_manager = CredentialManager() | |
| await _credential_manager.initialize() | |
| return _credential_manager |