kiroproxy / kiro_proxy /core /refresh_manager.py
KiroProxy User
chore: repo cleanup and maintenance
0edbd7b
"""Token 刷新管理模块
提供 Token 批量刷新的管理功能,包括:
- 刷新进度跟踪
- 并发控制
- 重试机制配置
- 全局锁防止重复刷新
- Token 过期检测和自动刷新
- 指数退避重试策略
"""
import time
import asyncio
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, Any, List, Tuple, Callable, TYPE_CHECKING
from threading import Lock
if TYPE_CHECKING:
from .account import Account
@dataclass
class RefreshProgress:
"""刷新进度信息
用于跟踪批量 Token 刷新操作的进度状态。
Attributes:
total: 需要刷新的账号总数
completed: 已完成处理的账号数(包括成功和失败)
success: 刷新成功的账号数
failed: 刷新失败的账号数
current_account: 当前正在处理的账号ID
status: 刷新状态 - running(进行中), completed(已完成), error(出错)
started_at: 刷新开始时间戳
message: 状态消息,用于显示当前操作或错误信息
"""
total: int = 0
completed: int = 0
success: int = 0
failed: int = 0
current_account: Optional[str] = None
status: str = "running" # running, completed, error
started_at: float = field(default_factory=time.time)
message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式
Returns:
包含所有进度信息的字典
"""
return asdict(self)
@property
def progress_percent(self) -> float:
"""计算完成百分比
Returns:
完成百分比(0-100)
"""
if self.total == 0:
return 0.0
return round((self.completed / self.total) * 100, 2)
@property
def elapsed_seconds(self) -> float:
"""计算已用时间(秒)
Returns:
从开始到现在的秒数
"""
return time.time() - self.started_at
def is_running(self) -> bool:
"""检查是否正在运行
Returns:
True 表示正在运行
"""
return self.status == "running"
def is_completed(self) -> bool:
"""检查是否已完成
Returns:
True 表示已完成(成功或出错)
"""
return self.status in ("completed", "error")
@dataclass
class RefreshConfig:
"""刷新配置
控制 Token 刷新行为的配置参数。
Attributes:
max_retries: 单个账号刷新失败时的最大重试次数
retry_base_delay: 重试基础延迟时间(秒),实际延迟会指数增长
concurrency: 并发刷新的账号数量
token_refresh_before_expiry: Token 过期前多少秒开始刷新(默认5分钟)
auto_refresh_interval: 自动刷新检查间隔(秒)
"""
max_retries: int = 3
retry_base_delay: float = 1.0
concurrency: int = 3
token_refresh_before_expiry: int = 300 # 5分钟
auto_refresh_interval: int = 60 # 1分钟
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式
Returns:
包含所有配置项的字典
"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'RefreshConfig':
"""从字典创建配置实例
Args:
data: 配置字典
Returns:
RefreshConfig 实例
"""
return cls(
max_retries=data.get("max_retries", 3),
retry_base_delay=data.get("retry_base_delay", 1.0),
concurrency=data.get("concurrency", 3),
token_refresh_before_expiry=data.get("token_refresh_before_expiry", 300),
auto_refresh_interval=data.get("auto_refresh_interval", 60)
)
def validate(self) -> bool:
"""验证配置有效性
Returns:
True 表示配置有效
Raises:
ValueError: 配置值无效时抛出
"""
if self.max_retries < 0:
raise ValueError("max_retries 不能为负数")
if self.retry_base_delay <= 0:
raise ValueError("retry_base_delay 必须大于0")
if self.concurrency < 1:
raise ValueError("concurrency 必须至少为1")
if self.token_refresh_before_expiry < 0:
raise ValueError("token_refresh_before_expiry 不能为负数")
if self.auto_refresh_interval < 1:
raise ValueError("auto_refresh_interval 必须至少为1秒")
return True
class RefreshManager:
"""Token 刷新管理器
管理 Token 批量刷新操作,提供:
- 全局锁机制防止重复刷新
- 进度跟踪
- 配置管理
- 自动 Token 刷新定时器
使用示例:
manager = get_refresh_manager()
if not manager.is_refreshing():
# 开始刷新操作
pass
"""
def __init__(self, config: Optional[RefreshConfig] = None):
"""初始化刷新管理器
Args:
config: 刷新配置,None 则使用默认配置
"""
# 配置
self._config = config or RefreshConfig()
# 线程锁(用于同步访问状态)
self._lock = Lock()
# 异步锁(用于防止并发刷新操作)
self._async_lock = asyncio.Lock()
# 刷新状态
self._is_refreshing: bool = False
self._progress: Optional[RefreshProgress] = None
# 上次刷新完成时间
self._last_refresh_time: Optional[float] = None
# 自动刷新定时器
self._auto_refresh_task: Optional[asyncio.Task] = None
self._auto_refresh_running: bool = False
# 获取账号列表的回调函数
self._accounts_getter: Optional[Callable] = None
@property
def config(self) -> RefreshConfig:
"""获取当前配置
Returns:
当前的刷新配置
"""
with self._lock:
return self._config
def is_refreshing(self) -> bool:
"""检查是否正在刷新
Returns:
True 表示正在进行刷新操作
"""
with self._lock:
return self._is_refreshing
def get_progress(self) -> Optional[RefreshProgress]:
"""获取当前刷新进度
Returns:
当前进度信息,如果没有进行中的刷新则返回 None
"""
with self._lock:
return self._progress
def get_progress_dict(self) -> Optional[Dict[str, Any]]:
"""获取当前刷新进度(字典格式)
Returns:
进度信息字典,如果没有进行中的刷新则返回 None
"""
with self._lock:
if self._progress is None:
return None
return self._progress.to_dict()
def update_config(self, **kwargs) -> None:
"""更新配置参数
支持的参数:
max_retries: 最大重试次数
retry_base_delay: 重试基础延迟
concurrency: 并发数
token_refresh_before_expiry: Token 过期前刷新时间
auto_refresh_interval: 自动刷新检查间隔
Args:
**kwargs: 要更新的配置项
Raises:
ValueError: 配置值无效时抛出
"""
with self._lock:
# 创建新配置
new_config = RefreshConfig(
max_retries=kwargs.get("max_retries", self._config.max_retries),
retry_base_delay=kwargs.get("retry_base_delay", self._config.retry_base_delay),
concurrency=kwargs.get("concurrency", self._config.concurrency),
token_refresh_before_expiry=kwargs.get(
"token_refresh_before_expiry",
self._config.token_refresh_before_expiry
),
auto_refresh_interval=kwargs.get(
"auto_refresh_interval",
self._config.auto_refresh_interval
)
)
# 验证配置
new_config.validate()
# 应用新配置
self._config = new_config
def _start_refresh(self, total: int, message: Optional[str] = None) -> RefreshProgress:
"""开始刷新操作(内部方法)
Args:
total: 需要刷新的账号总数
message: 初始状态消息
Returns:
新创建的进度对象
"""
with self._lock:
self._is_refreshing = True
self._progress = RefreshProgress(
total=total,
completed=0,
success=0,
failed=0,
current_account=None,
status="running",
started_at=time.time(),
message=message or "开始刷新"
)
return self._progress
def _update_progress(
self,
current_account: Optional[str] = None,
success: bool = False,
failed: bool = False,
message: Optional[str] = None
) -> None:
"""更新刷新进度(内部方法)
Args:
current_account: 当前处理的账号ID
success: 是否成功完成一个账号
failed: 是否失败一个账号
message: 状态消息
"""
with self._lock:
if self._progress is None:
return
if current_account is not None:
self._progress.current_account = current_account
if success:
self._progress.success += 1
self._progress.completed += 1
elif failed:
self._progress.failed += 1
self._progress.completed += 1
if message is not None:
self._progress.message = message
def _finish_refresh(self, status: str = "completed", message: Optional[str] = None) -> None:
"""完成刷新操作(内部方法)
Args:
status: 最终状态 - completed 或 error
message: 最终状态消息
"""
with self._lock:
self._is_refreshing = False
self._last_refresh_time = time.time()
if self._progress is not None:
self._progress.status = status
self._progress.current_account = None
if message is not None:
self._progress.message = message
elif status == "completed":
self._progress.message = (
f"刷新完成: 成功 {self._progress.success}, "
f"失败 {self._progress.failed}"
)
def get_last_refresh_time(self) -> Optional[float]:
"""获取上次刷新完成时间
Returns:
上次刷新完成的时间戳,如果从未刷新则返回 None
"""
with self._lock:
return self._last_refresh_time
def get_status(self) -> Dict[str, Any]:
"""获取管理器状态
Returns:
包含管理器状态信息的字典
"""
with self._lock:
return {
"is_refreshing": self._is_refreshing,
"progress": self._progress.to_dict() if self._progress else None,
"last_refresh_time": self._last_refresh_time,
"config": self._config.to_dict()
}
async def acquire_refresh_lock(self) -> bool:
"""尝试获取刷新锁
用于在开始刷新操作前获取异步锁,防止并发刷新。
Returns:
True 表示成功获取锁,False 表示已有刷新在进行
"""
if self._async_lock.locked():
return False
await self._async_lock.acquire()
return True
def release_refresh_lock(self) -> None:
"""释放刷新锁
在刷新操作完成后调用,释放异步锁。
"""
if self._async_lock.locked():
self._async_lock.release()
def should_refresh_token(self, account: 'Account') -> bool:
"""判断是否需要刷新 Token
检查账号的 Token 是否即将过期(过期前5分钟)或已过期。
Args:
account: 账号对象
Returns:
True 表示需要刷新 Token
"""
creds = account.get_credentials()
if creds is None:
return True # 无法获取凭证,需要刷新
# 检查是否已过期或即将过期
minutes_before = self._config.token_refresh_before_expiry // 60
return creds.is_expired() or creds.is_expiring_soon(minutes=minutes_before)
async def refresh_token_if_needed(self, account: 'Account') -> Tuple[bool, str]:
"""如果需要则刷新 Token
检查账号 Token 状态,如果即将过期或已过期则刷新。
Args:
account: 账号对象
Returns:
(success, message) 元组
- success: True 表示 Token 有效(无需刷新或刷新成功)
- message: 状态消息
"""
if not self.should_refresh_token(account):
return True, "Token 有效,无需刷新"
print(f"[RefreshManager] 账号 {account.id} Token 即将过期,开始刷新...")
success, result = await account.refresh_token()
if success:
print(f"[RefreshManager] 账号 {account.id} Token 刷新成功")
return True, "Token 刷新成功"
else:
print(f"[RefreshManager] 账号 {account.id} Token 刷新失败: {result}")
return False, f"Token 刷新失败: {result}"
async def refresh_account_with_token(
self,
account: 'Account',
get_quota_func: Optional[Callable] = None
) -> Tuple[bool, str]:
"""刷新单个账号(先刷新 Token,再获取额度)
Args:
account: 账号对象
get_quota_func: 获取额度的异步函数,接受 account 参数
Returns:
(success, message) 元组
"""
# 1. 先刷新 Token(如果需要)
token_success, token_msg = await self.refresh_token_if_needed(account)
if not token_success:
return False, token_msg
# 2. 获取额度(如果提供了获取函数)
if get_quota_func:
try:
quota_success, quota_result = await get_quota_func(account)
if quota_success:
return True, "刷新成功"
else:
error_msg = quota_result.get("error", "Unknown error") if isinstance(quota_result, dict) else str(quota_result)
return False, f"获取额度失败: {error_msg}"
except Exception as e:
return False, f"获取额度异常: {str(e)}"
return True, token_msg
async def retry_with_backoff(
self,
func: Callable,
*args,
max_retries: Optional[int] = None,
**kwargs
) -> Tuple[bool, Any]:
"""带指数退避的重试
执行异步函数,失败时使用指数退避策略重试。
Args:
func: 要执行的异步函数
*args: 传递给函数的位置参数
max_retries: 最大重试次数,None 则使用配置值
**kwargs: 传递给函数的关键字参数
Returns:
(success, result) 元组
- success: True 表示执行成功
- result: 成功时为函数返回值,失败时为错误信息
"""
retries = max_retries if max_retries is not None else self._config.max_retries
base_delay = self._config.retry_base_delay
last_error = None
for attempt in range(retries + 1):
try:
result = await func(*args, **kwargs)
# 检查返回值格式
if isinstance(result, tuple) and len(result) == 2:
success, data = result
if success:
return True, data
else:
last_error = data
# 检查是否是 429 错误
if self._is_rate_limit_error(data):
delay = self._get_rate_limit_delay(attempt, base_delay)
else:
delay = base_delay * (2 ** attempt)
else:
# 函数返回非元组,视为成功
return True, result
except Exception as e:
last_error = str(e)
delay = base_delay * (2 ** attempt)
# 如果还有重试机会,等待后重试
if attempt < retries:
print(f"[RefreshManager] 第 {attempt + 1} 次尝试失败,{delay:.1f}秒后重试...")
await asyncio.sleep(delay)
return False, last_error
def _is_rate_limit_error(self, error: Any) -> bool:
"""检查是否是限流错误(429)
Args:
error: 错误信息
Returns:
True 表示是限流错误
"""
if isinstance(error, str):
return "429" in error or "rate limit" in error.lower() or "请求过于频繁" in error
return False
def _get_rate_limit_delay(self, attempt: int, base_delay: float) -> float:
"""获取限流错误的等待时间
429 错误使用更长的等待时间。
Args:
attempt: 当前尝试次数(从0开始)
base_delay: 基础延迟
Returns:
等待时间(秒)
"""
# 429 错误使用 3 倍的基础延迟
return base_delay * 3 * (2 ** attempt)
async def refresh_all_with_token(
self,
accounts: List['Account'],
get_quota_func: Optional[Callable] = None,
skip_disabled: bool = True,
skip_error: bool = True
) -> RefreshProgress:
"""刷新所有账号(先刷新 Token,再获取额度)
使用全局锁防止并发刷新,支持进度跟踪。
Args:
accounts: 账号列表
get_quota_func: 获取额度的异步函数
skip_disabled: 是否跳过已禁用的账号
skip_error: 是否跳过已处于错误状态的账号
Returns:
刷新进度信息
"""
# 尝试获取锁
if not await self.acquire_refresh_lock():
# 已有刷新在进行
progress = self.get_progress()
if progress:
return progress
# 返回一个错误状态的进度
return RefreshProgress(
total=0,
status="error",
message="刷新操作正在进行中"
)
try:
# 过滤账号
accounts_to_refresh = []
for acc in accounts:
if skip_disabled and not acc.enabled:
continue
if skip_error and acc.status.value in ("unhealthy", "suspended"):
continue
accounts_to_refresh.append(acc)
total = len(accounts_to_refresh)
# 开始刷新
self._start_refresh(total, f"开始刷新 {total} 个账号")
if total == 0:
self._finish_refresh("completed", "没有需要刷新的账号")
return self.get_progress()
# 使用信号量控制并发
semaphore = asyncio.Semaphore(self._config.concurrency)
async def refresh_one(account: 'Account'):
async with semaphore:
self._update_progress(
current_account=account.id,
message=f"正在刷新: {account.name}"
)
# 使用重试机制刷新
success, result = await self.retry_with_backoff(
self.refresh_account_with_token,
account,
get_quota_func
)
if success:
self._update_progress(success=True)
else:
self._update_progress(failed=True)
return success, result
# 并发执行
tasks = [refresh_one(acc) for acc in accounts_to_refresh]
await asyncio.gather(*tasks, return_exceptions=True)
# 完成
self._finish_refresh("completed")
return self.get_progress()
except Exception as e:
self._finish_refresh("error", f"刷新异常: {str(e)}")
return self.get_progress()
finally:
self.release_refresh_lock()
def _is_auth_error(self, error: Any) -> bool:
"""检查是否是认证错误(401)
Args:
error: 错误信息
Returns:
True 表示是认证错误
"""
if isinstance(error, str):
return "401" in error or "unauthorized" in error.lower() or "凭证已过期" in error or "需要重新登录" in error
return False
async def execute_with_auth_retry(
self,
account: 'Account',
func: Callable,
*args,
**kwargs
) -> Tuple[bool, Any]:
"""执行操作,遇到 401 错误时自动刷新 Token 并重试
Args:
account: 账号对象
func: 要执行的异步函数
*args: 传递给函数的位置参数
**kwargs: 传递给函数的关键字参数
Returns:
(success, result) 元组
"""
try:
result = await func(*args, **kwargs)
# 检查返回值
if isinstance(result, tuple) and len(result) == 2:
success, data = result
if success:
return True, data
# 检查是否是 401 错误
if self._is_auth_error(data):
print(f"[RefreshManager] 账号 {account.id} 遇到 401 错误,尝试刷新 Token...")
# 刷新 Token
refresh_success, refresh_msg = await account.refresh_token()
if refresh_success:
print(f"[RefreshManager] Token 刷新成功,重试请求...")
# 重试原请求
retry_result = await func(*args, **kwargs)
if isinstance(retry_result, tuple) and len(retry_result) == 2:
return retry_result
return True, retry_result
else:
return False, f"Token 刷新失败: {refresh_msg}"
return False, data
return True, result
except Exception as e:
error_str = str(e)
# 检查异常是否是 401 错误
if self._is_auth_error(error_str):
print(f"[RefreshManager] 账号 {account.id} 遇到 401 异常,尝试刷新 Token...")
refresh_success, refresh_msg = await account.refresh_token()
if refresh_success:
print(f"[RefreshManager] Token 刷新成功,重试请求...")
try:
retry_result = await func(*args, **kwargs)
if isinstance(retry_result, tuple) and len(retry_result) == 2:
return retry_result
return True, retry_result
except Exception as retry_e:
return False, f"重试失败: {str(retry_e)}"
else:
return False, f"Token 刷新失败: {refresh_msg}"
return False, error_str
def set_accounts_getter(self, getter: Callable) -> None:
"""设置获取账号列表的回调函数
Args:
getter: 返回账号列表的可调用对象
"""
self._accounts_getter = getter
def _get_accounts(self) -> List['Account']:
"""获取账号列表"""
if self._accounts_getter:
return self._accounts_getter()
return []
async def start_auto_refresh(self) -> None:
"""启动自动 Token 刷新定时器
定期检查所有账号的 Token 状态,自动刷新即将过期的 Token。
启动前会清除已存在的定时器,防止重复启动。
"""
# 先停止已存在的定时器
await self.stop_auto_refresh()
self._auto_refresh_running = True
self._auto_refresh_task = asyncio.create_task(self._auto_refresh_loop())
print(f"[RefreshManager] 自动 Token 刷新定时器已启动,检查间隔: {self._config.auto_refresh_interval}秒")
async def stop_auto_refresh(self) -> None:
"""停止自动 Token 刷新定时器"""
self._auto_refresh_running = False
if self._auto_refresh_task:
self._auto_refresh_task.cancel()
try:
await self._auto_refresh_task
except asyncio.CancelledError:
pass
self._auto_refresh_task = None
print("[RefreshManager] 自动 Token 刷新定时器已停止")
def is_auto_refresh_running(self) -> bool:
"""检查自动刷新定时器是否在运行
Returns:
True 表示定时器正在运行
"""
return self._auto_refresh_running and self._auto_refresh_task is not None
async def _auto_refresh_loop(self) -> None:
"""自动刷新循环
定期检查所有账号的 Token 状态,刷新即将过期的 Token。
跳过已禁用或错误状态的账号,单个失败不影响其他账号。
"""
while self._auto_refresh_running:
try:
await asyncio.sleep(self._config.auto_refresh_interval)
if not self._auto_refresh_running:
break
accounts = self._get_accounts()
if not accounts:
continue
# 检查需要刷新的账号
accounts_to_refresh = []
for account in accounts:
# 跳过已禁用的账号
if not account.enabled:
continue
# 跳过错误状态的账号
if hasattr(account, 'status') and account.status.value in ("unhealthy", "suspended", "disabled"):
continue
# 检查是否需要刷新 Token
if self.should_refresh_token(account):
accounts_to_refresh.append(account)
if accounts_to_refresh:
print(f"[RefreshManager] 发现 {len(accounts_to_refresh)} 个账号需要刷新 Token")
# 逐个刷新,单个失败不影响其他
for account in accounts_to_refresh:
try:
success, message = await self.refresh_token_if_needed(account)
if not success:
print(f"[RefreshManager] 账号 {account.id} 自动刷新失败: {message}")
except Exception as e:
print(f"[RefreshManager] 账号 {account.id} 自动刷新异常: {e}")
# 继续处理其他账号
except asyncio.CancelledError:
break
except Exception as e:
print(f"[RefreshManager] 自动刷新循环异常: {e}")
# 继续运行,不因异常停止
def get_auto_refresh_status(self) -> Dict[str, Any]:
"""获取自动刷新状态
Returns:
包含自动刷新状态信息的字典
"""
return {
"running": self.is_auto_refresh_running(),
"interval": self._config.auto_refresh_interval,
"token_refresh_before_expiry": self._config.token_refresh_before_expiry
}
# 全局刷新管理器实例
_refresh_manager: Optional[RefreshManager] = None
_manager_lock = Lock()
def get_refresh_manager() -> RefreshManager:
"""获取全局刷新管理器实例
使用单例模式,确保全局只有一个刷新管理器实例。
Returns:
全局 RefreshManager 实例
"""
global _refresh_manager
if _refresh_manager is None:
with _manager_lock:
# 双重检查锁定
if _refresh_manager is None:
_refresh_manager = RefreshManager()
return _refresh_manager
def reset_refresh_manager() -> None:
"""重置全局刷新管理器
主要用于测试场景,重置全局实例。
"""
global _refresh_manager
with _manager_lock:
_refresh_manager = None