| """
|
| Warp客户端模块
|
| 处理与Warp API的通信,包括protobuf数据处理和HTTP请求
|
| """
|
| import requests
|
| import json
|
| from typing import Generator, List, Optional
|
| from loguru import logger
|
|
|
| from config import Config
|
| from request_converter import OpenAIMessage
|
| from utils import Utils
|
| from protobuf_manager import ProtobufManager
|
|
|
|
|
class WarpClient:
    """Client for the Warp API.

    Builds protobuf request payloads, posts them to the Warp AI endpoint and
    yields the decoded text chunks of the streamed response.
    """

    # Error text that the 401 handler treats as "user has not been created
    # yet" and tries to recover from.
    _USER_NOT_CREATED_ERROR = (
        "Unauthorized: User not in context: Not found: no rows in result set"
    )

    def __init__(self, token_manager=None):
        """Store the token manager and create a reusable HTTP session.

        Args:
            token_manager: Object used to obtain access tokens. When it is
                ``None`` no request can be authenticated.
        """
        self.token_manager = token_manager
        self.session = requests.Session()

    def get_auth_token(self) -> Optional[str]:
        """Return the current access token, or ``None`` if unavailable.

        Logs an error and returns ``None`` when no token manager is
        configured or when the manager yields no token.
        """
        if not self.token_manager:
            logger.error("❌ Token管理器未初始化")
            return None

        token = self.token_manager.get_current_access_token()
        if not token:
            logger.error("❌ 无法获取有效的认证token")
            return None
        return token

    def create_protobuf_data(
        self,
        messages: "List[OpenAIMessage]",
        model: str = "gemini-2.0-flash",
    ) -> Optional[bytes]:
        """Build the protobuf request payload for ``messages`` and ``model``.

        Delegates to ``ProtobufManager.create_chat_request``.
        """
        return ProtobufManager.create_chat_request(messages, model)

    def parse_protobuf_response(self, base64_data: str) -> str:
        """Decode one response payload into text.

        Delegates to ``ProtobufManager.parse_chat_response``.
        """
        return ProtobufManager.parse_chat_response(base64_data)

    def _check_and_handle_401_error(
        self,
        response_text: str,
        current_access_token: Optional[str] = None,
        current_refresh_token: Optional[str] = None,
    ) -> bool:
        """Inspect a 401 response body and try to recover from it.

        Only the "user not created" error is handled: the method first tries
        to create the user with the current access token; if that fails it
        removes the associated refresh token (when known) and asks the token
        manager for another access token.

        Args:
            response_text: Raw body of the 401 response.
            current_access_token: Access token used for the failed request.
            current_refresh_token: Refresh token associated with that access
                token, if known.

        Returns:
            ``True`` when the request should be retried, ``False`` otherwise
            (including when the body is not a JSON object, carries a
            different error, or recovery fails).
        """
        try:
            error_data = json.loads(response_text)

            # A JSON array/string/number has no ``.get``; treat it like any
            # other unrecognised error body instead of raising.
            if not isinstance(error_data, dict):
                return False

            error_message = error_data.get("error", "")
            if not isinstance(error_message, str):
                return False
            if self._USER_NOT_CREATED_ERROR not in error_message:
                return False

            logger.warning("⚠️ 检测到用户未创建错误,尝试用当前token创建用户")

            if not (self.token_manager and current_access_token):
                logger.error("❌ Token管理器未初始化或当前access_token为空")
                return False

            logger.info(f"🔄 使用当前access_token创建用户: {current_access_token[:20]}...")

            # Imported at call time, as in the original code.
            from login_client import LoginClient

            login_client = LoginClient()
            if login_client.create_user_with_access_token(current_access_token):
                logger.success("✅ 用户创建成功,可以重试请求")
                return True

            logger.error("❌ 当前token创建用户失败")

            if current_refresh_token:
                logger.info("🗑️ 移除创建用户失败的refresh_token")
                self.token_manager.remove_refresh_token(current_refresh_token)

            logger.info("🔄 尝试获取下一个access_token...")
            next_token = self.token_manager.get_current_access_token()
            if next_token and next_token != current_access_token:
                logger.info(f"🔄 获取到下一个access_token: {next_token[:20]}...")
                return True

            logger.error("❌ 无法获取下一个有效的access_token")
            return False

        except (json.JSONDecodeError, KeyError, TypeError):
            # Non-JSON body (or ``None``): nothing we know how to recover from.
            return False

    def _find_refresh_token(self, access_token: str) -> Optional[str]:
        """Return the refresh token whose entry holds ``access_token``."""
        if not self.token_manager:
            return None

        with self.token_manager.token_lock:
            for refresh_token, token_info in self.token_manager.tokens.items():
                if token_info.access_token == access_token:
                    return refresh_token
        return None

    @staticmethod
    def _build_headers(auth_token: str) -> dict:
        """Build the HTTP headers for a streaming Warp API request."""
        return {
            'Accept': 'text/event-stream',
            'Accept-Encoding': 'gzip, br',
            'Content-Type': 'application/x-protobuf',
            'x-warp-client-version': Config.WARP_CLIENT_VERSION,
            'x-warp-os-category': Config.WARP_OS_CATEGORY,
            'x-warp-os-name': Config.WARP_OS_NAME,
            'x-warp-os-version': Config.WARP_OS_VERSION,
            'authorization': f'Bearer {auth_token}',
        }

    def _iter_stream_text(self, response) -> Generator[str, None, None]:
        """Yield the non-empty decoded text of every ``data:`` line."""
        for raw_line in response.iter_lines(decode_unicode=True):
            if not raw_line:
                continue

            # requests leaves lines as bytes when the response declares no
            # encoding, even with decode_unicode=True.
            if isinstance(raw_line, bytes):
                line = raw_line.decode('utf-8', errors='replace')
            else:
                line = raw_line

            if not line.startswith('data:'):
                continue

            text = self.parse_protobuf_response(line[5:].strip())
            if text:
                yield text

    def send_request(self, protobuf_data: bytes) -> Generator[str, None, None]:
        """Post ``protobuf_data`` to the Warp API and stream back the text.

        The request is attempted up to three times (one initial attempt plus
        two retries); a retry happens only when a 401 response is judged
        recoverable by ``_check_and_handle_401_error``. Failures are logged
        and end the generator instead of raising.

        Args:
            protobuf_data: Serialized request payload.

        Yields:
            Decoded text chunks, in the order they are received.
        """
        url = f"{Config.WARP_BASE_URL}{Config.WARP_AI_ENDPOINT}"
        max_retries = 2

        # create_protobuf_data may return None; fail cleanly instead of
        # raising TypeError from len() below.
        if protobuf_data is None:
            logger.error("❌ Protobuf数据为空,请求终止")
            return

        for retry_count in range(max_retries + 1):
            auth_token = self.get_auth_token()
            if not auth_token:
                logger.error("❌ 无法获取认证token,请求终止")
                return

            # Needed so a token whose user cannot be created can be removed.
            current_refresh_token = self._find_refresh_token(auth_token)
            headers = self._build_headers(auth_token)

            if retry_count > 0:
                logger.info(f"🔄 第{retry_count}次重试请求到Warp API: {url}")
            else:
                logger.info(f"🌐 发送请求到Warp API: {url}")

            logger.debug(f"📦 Protobuf数据大小: {len(protobuf_data)} 字节")

            try:
                # The context manager closes the streamed response on every
                # exit path, including when the consumer stops iterating.
                # NOTE(review): verify=False disables TLS certificate checks
                # while a bearer token is being sent; confirm this is
                # intentional before relying on it outside a trusted network.
                with self.session.post(
                    url,
                    headers=headers,
                    data=protobuf_data,
                    stream=True,
                    timeout=Config.REQUEST_TIMEOUT,
                    verify=False,
                ) as response:
                    if response.status_code == 200:
                        logger.success("✅ 请求成功,开始接收流式响应")
                        chunk_count = 0
                        for text in self._iter_stream_text(response):
                            chunk_count += 1
                            logger.debug(f"📨 接收到响应块 #{chunk_count}: {len(text)} 字符")
                            yield text
                        logger.success(f"🎉 流式响应接收完成,总共 {chunk_count} 个块")
                        return

                    logger.error(f"❌ Warp API请求失败,状态码: {response.status_code}")
                    error_body = response.text

                    if response.status_code != 401:
                        if error_body:
                            logger.error(f"❌ 错误详情: {error_body}")
                        return

                    if not error_body:
                        logger.error("❌ 401错误,但无响应内容")
                        return

                    logger.error(f"❌ 错误详情: {error_body}")
                    should_retry = self._check_and_handle_401_error(
                        error_body, auth_token, current_refresh_token
                    )
                    if not (should_retry and retry_count < max_retries):
                        logger.error("❌ 401错误处理失败或已达到最大重试次数")
                        return

                    logger.info(f"🔄 准备第{retry_count + 1}次重试...")
                    # Fall through: the response is closed and the loop
                    # performs the next attempt.

            except requests.Timeout:
                logger.error(f"❌ 请求超时 ({Config.REQUEST_TIMEOUT}秒)")
                return
            except Exception as e:
                # Top-level boundary of the generator: log with traceback and
                # end the stream rather than propagate to the consumer.
                logger.exception(f"❌ 发送请求时出错: {e}")
                return