test-w / warp_client.py
letterm's picture
Upload 13 files
47d8ca8 verified
raw
history blame
9.01 kB
"""
Warp客户端模块
处理与Warp API的通信,包括protobuf数据处理和HTTP请求
"""
import requests
import json
from typing import Generator, List, Optional
from loguru import logger
from config import Config
from request_converter import OpenAIMessage
from utils import Utils
from protobuf_manager import ProtobufManager
class WarpClient:
"""Warp API客户端"""
def __init__(self, token_manager=None):
self.token_manager = token_manager
self.session = requests.Session()
def get_auth_token(self) -> Optional[str]:
"""获取当前认证token"""
if self.token_manager:
token = self.token_manager.get_current_access_token()
if token:
return token
else:
logger.error("❌ 无法获取有效的认证token")
return None
else:
logger.error("❌ Token管理器未初始化")
return None
def create_protobuf_data(self, messages: List[OpenAIMessage], model: str = "gemini-2.0-flash") -> Optional[bytes]:
"""创建protobuf数据"""
return ProtobufManager.create_chat_request(messages, model)
def parse_protobuf_response(self, base64_data: str) -> str:
"""解析protobuf响应数据"""
return ProtobufManager.parse_chat_response(base64_data)
def _check_and_handle_401_error(self, response_text: str, current_access_token: str = None, current_refresh_token: str = None) -> bool:
"""检查并处理401错误,返回True表示需要重试"""
try:
# 尝试解析响应为JSON
error_data = json.loads(response_text)
error_message = error_data.get("error", "")
# 检查是否是目标错误消息
if "Unauthorized: User not in context: Not found: no rows in result set" in error_message:
logger.warning("⚠️ 检测到用户未创建错误,尝试用当前token创建用户")
if self.token_manager and current_access_token:
# 使用当前的access_token尝试创建用户
logger.info(f"🔄 使用当前access_token创建用户: {current_access_token[:20]}...")
# 尝试创建用户
from login_client import LoginClient
login_client = LoginClient()
if login_client.create_user_with_access_token(current_access_token):
logger.success("✅ 用户创建成功,可以重试请求")
return True
else:
logger.error("❌ 当前token创建用户失败")
# 如果用当前token创建用户失败,移除这个refresh_token
if current_refresh_token:
logger.info(f"🗑️ 移除创建用户失败的refresh_token")
self.token_manager.remove_refresh_token(current_refresh_token)
# 尝试获取下一个token
logger.info("🔄 尝试获取下一个access_token...")
next_token = self.token_manager.get_current_access_token()
if next_token and next_token != current_access_token:
logger.info(f"🔄 获取到下一个access_token: {next_token[:20]}...")
return True # 返回True表示可以用新token重试
else:
logger.error("❌ 无法获取下一个有效的access_token")
return False
else:
logger.error("❌ Token管理器未初始化或当前access_token为空")
return False
return False
except (json.JSONDecodeError, KeyError, TypeError):
# 如果不是JSON格式或没有error字段,则不是目标错误
return False
def send_request(self, protobuf_data: bytes) -> Generator[str, None, None]:
"""发送请求到Warp API并返回流式响应"""
url = f"{Config.WARP_BASE_URL}{Config.WARP_AI_ENDPOINT}"
max_retries = 2 # 最多重试2次
retry_count = 0
while retry_count <= max_retries:
# 获取认证token
auth_token = self.get_auth_token()
if not auth_token:
logger.error("❌ 无法获取认证token,请求终止")
return
# 获取当前使用的refresh_token(用于错误时移除)
current_refresh_token = None
if self.token_manager:
with self.token_manager.token_lock:
for refresh_token, token_info in self.token_manager.tokens.items():
if token_info.access_token == auth_token:
current_refresh_token = refresh_token
break
headers = {
'Accept': 'text/event-stream',
'Accept-Encoding': 'gzip, br',
'Content-Type': 'application/x-protobuf',
'x-warp-client-version': Config.WARP_CLIENT_VERSION,
'x-warp-os-category': Config.WARP_OS_CATEGORY,
'x-warp-os-name': Config.WARP_OS_NAME,
'x-warp-os-version': Config.WARP_OS_VERSION,
'authorization': f'Bearer {auth_token}'
}
if retry_count > 0:
logger.info(f"🔄 第{retry_count}次重试请求到Warp API: {url}")
else:
logger.info(f"🌐 发送请求到Warp API: {url}")
logger.debug(f"📦 Protobuf数据大小: {len(protobuf_data)} 字节")
try:
response = self.session.post(
url,
headers=headers,
data=protobuf_data,
stream=True,
timeout=Config.REQUEST_TIMEOUT,
verify=False
)
if response.status_code == 200:
logger.success("✅ 请求成功,开始接收流式响应")
chunk_count = 0
for line in response.iter_lines(decode_unicode=True):
if line and line.startswith('data:'):
data = line[5:].strip()
text = self.parse_protobuf_response(data)
if text:
chunk_count += 1
logger.debug(f"📨 接收到响应块 #{chunk_count}: {len(text)} 字符")
yield text
logger.success(f"🎉 流式响应接收完成,总共 {chunk_count} 个块")
return # 成功,退出重试循环
elif response.status_code == 401:
logger.error(f"❌ Warp API请求失败,状态码: {response.status_code}")
if response.text:
logger.error(f"❌ 错误详情: {response.text}")
# 检查并处理特定的401错误
should_retry = self._check_and_handle_401_error(response.text, auth_token, current_refresh_token)
if should_retry and retry_count < max_retries:
retry_count += 1
logger.info(f"🔄 准备第{retry_count}次重试...")
continue
else:
logger.error(f"❌ 401错误处理失败或已达到最大重试次数")
return
else:
logger.error("❌ 401错误,但无响应内容")
return
else:
logger.error(f"❌ Warp API请求失败,状态码: {response.status_code}")
if response.text:
logger.error(f"❌ 错误详情: {response.text}")
return
except requests.Timeout:
logger.error(f"❌ 请求超时 ({Config.REQUEST_TIMEOUT}秒)")
return
except Exception as e:
logger.error(f"❌ 发送请求时出错: {e}")
import traceback
traceback.print_exc()
return