test-w / token_manager.py
letterm's picture
Upload 13 files
47d8ca8 verified
raw
history blame
14.2 kB
"""
多Token管理器模块
支持多个refresh token的管理、负载均衡和自动刷新
"""
import time
import threading
import requests
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
from config import Config
from utils import Utils
@dataclass()
class TokenInfo:
    """State tracked for a single refresh token.

    Only ``refresh_token`` and ``token_id`` are required when constructing an
    instance; every other field starts empty/zero/True and is updated by the
    token manager as refreshes succeed or fail.
    """

    # The raw refresh token string.
    refresh_token: str
    # Identifier for this token (the manager logs and reports this instead of the raw token).
    token_id: str
    # Most recently obtained access token; None until a refresh succeeds.
    access_token: Optional[str] = None
    # Wall-clock timestamp of the last successful refresh, if any.
    last_refresh_time: Optional[float] = None
    # How many refreshes have succeeded so far.
    refresh_count: int = 0
    # Failed-refresh counter; the manager resets it to 0 after a success.
    error_count: int = 0
    # Whether the token may currently be handed out for use.
    is_active: bool = True
class MultiTokenManager:
    """Manage several refresh tokens with round-robin selection and periodic refresh.

    Tokens live in ``self.tokens`` as ``TokenInfo`` objects keyed by the raw
    refresh-token string. Access tokens are obtained by POSTing a refresh token
    to ``Config.GOOGLE_TOKEN_URL``. ``get_current_access_token`` rotates over
    every token that is active and currently holds an access token. Shared
    token state is guarded by the re-entrant ``token_lock``.
    """

    # Consecutive refresh failures after which a token is marked inactive.
    # A later successful refresh resets the counter and re-activates the token.
    MAX_CONSECUTIVE_ERRORS = 3

    def __init__(self):
        """Create the manager and load any refresh tokens from the environment."""
        # Token storage: raw refresh token -> TokenInfo.
        self.tokens: Dict[str, TokenInfo] = {}
        self.token_lock = threading.RLock()
        # Round-robin load-balancing state.
        self.current_index = 0
        self.usage_count = defaultdict(int)
        # Auto-refresh state.
        self.refresh_timer = None
        self.refresh_executor = self._new_refresh_executor()
        self.session = requests.Session()
        # Serializes timer (re)scheduling against start/stop_auto_refresh().
        self._timer_lock = threading.Lock()
        # Set by stop_auto_refresh() so that a refresh cycle already in
        # progress does not re-arm the timer after the service was stopped.
        self._stop_event = threading.Event()
        # Load the tokens configured in the environment.
        self._initialize_tokens()

    @staticmethod
    def _new_refresh_executor() -> ThreadPoolExecutor:
        """Build the thread pool used to refresh tokens concurrently."""
        return ThreadPoolExecutor(max_workers=5, thread_name_prefix="token-refresh")

    def _initialize_tokens(self):
        """Add the refresh tokens returned by ``Config.get_refresh_tokens()``.

        Tokens that fail ``Utils.validate_refresh_token_format`` are skipped and
        counted in a warning. If none are configured, logs how to provide them.
        """
        env_tokens = Config.get_refresh_tokens()
        if not env_tokens:
            logger.warning("⚠️ 未找到有效的refresh token,请通过环境变量 WARP_REFRESH_TOKEN 设置或在管理界面添加")
            logger.info("💡 环境变量支持多个token,使用分号(;)分割:token1;token2;token3")
            return

        valid_tokens = [token for token in env_tokens if Utils.validate_refresh_token_format(token)]
        if not valid_tokens:
            logger.warning("⚠️ 环境变量中的所有token格式都无效")
            return

        self.add_refresh_tokens(valid_tokens)
        logger.info(f"✅ 从环境变量初始化了 {len(valid_tokens)} 个有效token")
        invalid_count = len(env_tokens) - len(valid_tokens)
        if invalid_count > 0:
            logger.warning(f"⚠️ 跳过了 {invalid_count} 个格式无效的token")

    def add_refresh_token(self, refresh_token: str) -> bool:
        """Add one refresh token; return True if it was newly added."""
        return self.add_refresh_tokens([refresh_token])

    def add_refresh_tokens(self, refresh_tokens: List[str]) -> bool:
        """Add several refresh tokens.

        Tokens with an invalid format, and tokens that are already managed, are
        skipped. Returns True if at least one token was newly added.
        """
        with self.token_lock:
            added_count = 0
            for token in refresh_tokens:
                if not Utils.validate_refresh_token_format(token):
                    # str() so that a non-string entry (e.g. None) cannot crash the log line.
                    logger.warning(f"⚠️ Token格式无效: {str(token)[:20]}...")
                    continue
                if token in self.tokens:
                    logger.info(f"⚠️ Token已存在,跳过: {self.tokens[token].token_id}")
                    continue
                token_id = Utils.generate_token_id(token)
                self.tokens[token] = TokenInfo(
                    refresh_token=token,
                    token_id=token_id
                )
                added_count += 1
                logger.info(f"✅ 添加新token: {token_id}")
            logger.info(f"📊 成功添加 {added_count} 个新token,当前总数: {len(self.tokens)}")
            return added_count > 0

    def remove_refresh_token(self, refresh_token: str) -> bool:
        """Remove a refresh token; return True if it was present."""
        with self.token_lock:
            token_info = self.tokens.pop(refresh_token, None)
            if token_info is None:
                return False
            # Drop its usage statistics too, so removed tokens do not linger in
            # usage_count and a re-added token starts counting from zero.
            self.usage_count.pop(refresh_token, None)
        logger.info(f"🗑️ 移除token: {token_info.token_id}")
        return True

    def get_refresh_token_from_env(self, refresh_token: str) -> bool:
        """Return True if *refresh_token* is in ``Config.get_refresh_tokens()``."""
        # A missing (None/empty) list means "not from env" rather than a TypeError.
        env_tokens = Config.get_refresh_tokens() or []
        return refresh_token in env_tokens

    def remove_duplicate_tokens(self, new_tokens: List[str]) -> List[str]:
        """Filter *new_tokens* against the tokens already managed.

        - Already managed AND listed in the environment: dropped from the
          result (the environment copy takes precedence).
        - Already managed but not from the environment: removed from the
          manager and kept in the result so the caller can re-add it.
        - Not yet managed: kept.

        Repeated entries within *new_tokens* appear at most once in the result.
        """
        with self.token_lock:
            unique_tokens = []
            seen = set()
            for token in new_tokens:
                if token in seen:
                    continue
                seen.add(token)
                if token in self.tokens:
                    if self.get_refresh_token_from_env(token):
                        logger.info(f"⚠️ 跳过重复token(环境变量优先): {self.tokens[token].token_id}")
                        continue
                    # The new copy supersedes a non-environment duplicate.
                    self.remove_refresh_token(token)
                unique_tokens.append(token)
            return unique_tokens

    def _token_id(self, refresh_token: str) -> str:
        """Return the ``token_id`` of a managed token, or '' if it is unknown."""
        with self.token_lock:
            token_info = self.tokens.get(refresh_token)
        return token_info.token_id if token_info is not None else ""

    def get_access_token(self, refresh_token: str) -> Optional[str]:
        """Exchange *refresh_token* for a new access token.

        POSTs a ``grant_type=refresh_token`` form to ``Config.GOOGLE_TOKEN_URL``
        (30 s timeout) with the Warp client headers from ``Config``. Returns the
        ``access_token`` field of a 200 JSON response. Returns None on a non-200
        status, a missing field, or any exception; each failure is logged.
        """
        token_id = self._token_id(refresh_token)
        try:
            headers = {
                'Accept-Encoding': 'gzip, br',
                'Content-Type': 'application/x-www-form-urlencoded',
                'x-warp-client-version': Config.WARP_CLIENT_VERSION,
                'x-warp-os-category': Config.WARP_OS_CATEGORY,
                'x-warp-os-name': Config.WARP_OS_NAME,
                'x-warp-os-version': Config.WARP_OS_VERSION
            }
            data = {
                'grant_type': 'refresh_token',
                'refresh_token': refresh_token
            }
            logger.debug(f"🔄 正在刷新access token: {token_id}")
            response = self.session.post(
                Config.GOOGLE_TOKEN_URL,
                headers=headers,
                data=data,
                timeout=30
            )
            if response.status_code == 200:
                access_token = response.json().get('access_token')
                if access_token:
                    logger.success(f"✅ Access token获取成功: {token_id}")
                    return access_token
                logger.error(f"❌ 响应中未找到access_token字段: {token_id}")
            else:
                logger.error(f"❌ 获取token失败,状态码 {response.status_code}: {token_id}")
        except Exception as e:
            # Deliberate best-effort: callers treat None as "refresh failed".
            logger.error(f"❌ 获取access token时出错: {token_id} - {e}")
        return None

    def _record_refresh_failure(self, token_info: "TokenInfo") -> None:
        """Count one failed refresh; deactivate the token once the limit is reached."""
        with self.token_lock:
            token_info.error_count += 1
            if token_info.error_count >= self.MAX_CONSECUTIVE_ERRORS:
                token_info.is_active = False
                logger.warning(
                    f"⚠️ Token连续失败{self.MAX_CONSECUTIVE_ERRORS}次,标记为不可用: {token_info.token_id}"
                )

    def refresh_single_token(self, refresh_token: str) -> bool:
        """Refresh the access token of one managed refresh token.

        On success, stores the new access token and refresh timestamp,
        increments ``refresh_count``, resets ``error_count`` and re-activates
        the token. On failure, counts the error via ``_record_refresh_failure``
        (which may deactivate the token). Returns False for unknown tokens.
        """
        with self.token_lock:
            token_info = self.tokens.get(refresh_token)
        if token_info is None:
            return False

        try:
            access_token = self.get_access_token(refresh_token)
        except Exception as e:
            # get_access_token handles its own errors; this only guards the unexpected.
            logger.error(f"❌ 刷新token时出错: {token_info.token_id} - {e}")
            self._record_refresh_failure(token_info)
            return False

        if not access_token:
            self._record_refresh_failure(token_info)
            return False

        with self.token_lock:
            token_info.access_token = access_token
            token_info.last_refresh_time = time.time()
            token_info.refresh_count += 1
            token_info.error_count = 0  # a success ends the failure streak
            token_info.is_active = True
            logger.success(f"🔄 Token刷新成功: {token_info.token_id} (第{token_info.refresh_count}次)")
        return True

    def refresh_all_tokens(self):
        """Refresh every managed token concurrently on ``refresh_executor`` and log a summary."""
        with self.token_lock:
            refresh_tokens = list(self.tokens)
        if not refresh_tokens:
            logger.warning("⚠️ 没有可刷新的token")
            return

        logger.info(f"🔄 开始并发刷新 {len(refresh_tokens)} 个token")
        futures = {}
        try:
            for token in refresh_tokens:
                futures[self.refresh_executor.submit(self.refresh_single_token, token)] = token
        except RuntimeError as e:
            # Executor.submit() raises RuntimeError once the pool has been shut
            # down (e.g. by stop_auto_refresh()); collect what was submitted.
            logger.warning(f"⚠️ Token刷新线程池已关闭,跳过剩余刷新: {e}")

        success_count = 0
        for future in as_completed(futures):
            token = futures[future]
            try:
                # as_completed() only yields finished futures, so no timeout is needed.
                if future.result():
                    success_count += 1
            except Exception as e:
                logger.error(f"❌ 刷新token时出错: {self._token_id(token)} - {e}")
        logger.info(f"📊 Token刷新完成: {success_count}/{len(refresh_tokens)} 成功")

    def get_current_access_token(self) -> Optional[str]:
        """Return an access token chosen round-robin from the usable tokens.

        A token is usable when it is active and holds an access token. Each
        selection increments that token's ``usage_count``. Returns None, and
        logs setup hints, when no token is usable.
        """
        with self.token_lock:
            active_tokens = [
                (token, info) for token, info in self.tokens.items()
                if info.is_active and info.access_token
            ]
            if not active_tokens:
                logger.warning("⚠️ 没有可用的access token!")
                logger.warning("💡 请确保:")
                logger.warning(" 1. 设置了有效的 WARP_REFRESH_TOKEN 环境变量")
                logger.warning(" 2. 或通过管理界面添加有效的 refresh token")
                logger.warning(" 3. refresh token 已成功刷新获得 access token")
                return None

            # The usable set can shrink between calls, so clamp the cursor first.
            if self.current_index >= len(active_tokens):
                self.current_index = 0
            selected_token, token_info = active_tokens[self.current_index]
            self.current_index = (self.current_index + 1) % len(active_tokens)

            self.usage_count[selected_token] += 1
            logger.info(f"🎯 轮询选择token: {token_info.token_id} (使用次数: {self.usage_count[selected_token]})")
            return token_info.access_token

    def get_token_status(self) -> Dict[str, Any]:
        """Return a status snapshot of all tokens.

        Per-token entries report ``token_id`` under the ``'refresh_token'`` key,
        so the raw refresh token is never exposed.
        """
        with self.token_lock:
            status = {
                'total_tokens': len(self.tokens),
                'active_tokens': sum(1 for info in self.tokens.values() if info.is_active),
                'tokens_with_access': sum(1 for info in self.tokens.values() if info.access_token),
                'tokens': []
            }
            for token, info in self.tokens.items():
                status['tokens'].append({
                    'refresh_token': info.token_id,
                    'has_access_token': bool(info.access_token),
                    'is_active': info.is_active,
                    'refresh_count': info.refresh_count,
                    'error_count': info.error_count,
                    'last_refresh_time': info.last_refresh_time,
                    'usage_count': self.usage_count.get(token, 0)
                })
            return status

    def start_auto_refresh(self):
        """Start periodic refreshing; the first refresh runs at once in a daemon thread.

        May be called again after ``stop_auto_refresh``. A new executor is
        created then, because a shut-down executor cannot accept work.
        """
        logger.info(f"🚀 启动Token自动刷新服务,间隔: {Config.TOKEN_REFRESH_INTERVAL // 60} 分钟")
        with self._timer_lock:
            if self._stop_event.is_set():
                self.refresh_executor = self._new_refresh_executor()
                self._stop_event.clear()
        # Run the initial refresh off the caller's thread.
        refresh_thread = threading.Thread(target=self._initial_refresh, daemon=True)
        refresh_thread.start()

    def _initial_refresh(self):
        """Run the first refresh cycle (background thread), then schedule the next one."""
        if self._stop_event.is_set():
            return
        logger.info("🔄 正在后台执行初始token刷新...")
        self._refresh_then_reschedule()

    def _refresh_then_reschedule(self):
        """Refresh all tokens, then arm the next timer even if the refresh raised.

        Without the guard, one unexpected exception would end the periodic
        refresh cycle permanently, because the next timer would never be armed.
        """
        try:
            self.refresh_all_tokens()
        except Exception as e:
            logger.error(f"❌ 自动刷新token时出错: {e}")
        self._schedule_next_refresh()

    def _schedule_next_refresh(self):
        """Arm a one-shot timer for the next refresh cycle; no-op once stopped."""
        with self._timer_lock:
            if self._stop_event.is_set():
                # stop_auto_refresh() ran while this cycle was in progress.
                return
            if self.refresh_timer:
                self.refresh_timer.cancel()
            interval = Config.TOKEN_REFRESH_INTERVAL
            self.refresh_timer = threading.Timer(interval, self._auto_refresh_callback)
            self.refresh_timer.daemon = True
            self.refresh_timer.start()

        next_refresh_time = time.time() + interval
        next_refresh_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(next_refresh_time))
        logger.info(f"⏰ 下次token刷新时间: {next_refresh_str}")

    def _auto_refresh_callback(self):
        """Timer callback: refresh all tokens, then schedule the next cycle."""
        if self._stop_event.is_set():
            return
        logger.info("🔄 开始定时自动刷新所有token...")
        self._refresh_then_reschedule()

    def stop_auto_refresh(self):
        """Stop periodic refreshing and shut down the refresh executor (without waiting)."""
        with self._timer_lock:
            # Set first so an in-flight cycle cannot re-arm the timer afterwards.
            self._stop_event.set()
            if self.refresh_timer:
                self.refresh_timer.cancel()
                self.refresh_timer = None
        self.refresh_executor.shutdown(wait=False)
        logger.info("⏹️ Token自动刷新服务已停止")