Spaces:
Paused
Paused
Upload 31 files
Browse files- .env.example +60 -0
- Dockerfile +10 -1
- app/__init__.py +6 -0
- app/core/__init__.py +6 -0
- app/core/config.py +206 -0
- app/core/openai.py +268 -0
- app/core/zai_transformer.py +777 -0
- app/models/__init__.py +6 -0
- app/models/schemas.py +145 -0
- app/providers/__init__.py +26 -0
- app/providers/base.py +268 -0
- app/providers/k2think_provider.py +509 -0
- app/providers/longcat_provider.py +466 -0
- app/providers/provider_factory.py +196 -0
- app/providers/zai_provider.py +764 -0
- app/utils/__init__.py +6 -0
- app/utils/logger.py +105 -0
- app/utils/reload_config.py +89 -0
- app/utils/sse_tool_handler.py +612 -0
- app/utils/token_pool.py +455 -0
- app/utils/user_agent.py +133 -0
- longcat_tokens.txt.example +26 -0
- main.py +98 -0
- pyproject.toml +67 -0
- requirements.txt +11 -0
- tests/test_comprehensive_fix.py +289 -0
- tests/test_done_phase.py +231 -0
- tests/test_longcat_connection.py +166 -0
- tests/test_multiple_tools.py +133 -0
- tests/test_simple_performance.py +178 -0
- tokens.txt.example +26 -0
.env.example
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 代理服务配置文件示例
|
| 2 |
+
# 复制此文件为 .env 并根据需要修改配置值
|
| 3 |
+
|
| 4 |
+
# ========== API 基础配置 ==========
|
| 5 |
+
# 客户端认证密钥(您自定义的 API 密钥,用于客户端访问本服务)
|
| 6 |
+
AUTH_TOKEN=sk-your-api-key
|
| 7 |
+
|
| 8 |
+
# 跳过客户端认证(仅开发环境使用)
|
| 9 |
+
SKIP_AUTH_TOKEN=false
|
| 10 |
+
|
| 11 |
+
# ========== Z.ai Token池配置 ==========
|
| 12 |
+
# Token失败阈值(失败多少次后标记为不可用)
|
| 13 |
+
TOKEN_FAILURE_THRESHOLD=3
|
| 14 |
+
|
| 15 |
+
# Token恢复超时时间(秒,失败token在此时间后重新尝试)
|
| 16 |
+
TOKEN_RECOVERY_TIMEOUT=1800
|
| 17 |
+
|
| 18 |
+
# Token健康检查间隔(秒,定期检查token状态)
|
| 19 |
+
TOKEN_HEALTH_CHECK_INTERVAL=300
|
| 20 |
+
|
| 21 |
+
# Z.AI 匿名用户模式
|
| 22 |
+
# false: 使用认证 Token 令牌,失败时自动降级为匿名请求
|
| 23 |
+
# true: 自动从 Z.ai 获取临时访问令牌,避免对话历史共享
|
| 24 |
+
ANONYMOUS_MODE=true
|
| 25 |
+
|
| 26 |
+
# ========== Z.ai 认证token配置(可选) ==========
|
| 27 |
+
# 使用独立的token文件配置(可选)
|
| 28 |
+
# 如果需要认证token,在项目根目录创建 tokens.txt 文件,每行一个token或逗号分隔
|
| 29 |
+
# 如果不需要认证token,想走匿名请求模式,可以注释掉或删除此配置项
|
| 30 |
+
# AUTH_TOKENS_FILE=tokens.txt
|
| 31 |
+
|
| 32 |
+
# ========== LongCat 配置 ==========
|
| 33 |
+
# LongCat passport token(单个token)
|
| 34 |
+
# LONGCAT_PASSPORT_TOKEN=your_passport_token_here
|
| 35 |
+
|
| 36 |
+
# LongCat tokens 文件路径(多个token)
|
| 37 |
+
# LONGCAT_TOKENS_FILE=longcat_tokens.txt
|
| 38 |
+
|
| 39 |
+
# ========== 服务器配置 ==========
|
| 40 |
+
# 服务监听端口
|
| 41 |
+
LISTEN_PORT=8080
|
| 42 |
+
|
| 43 |
+
# 服务名称(用于进程唯一性验证)
|
| 44 |
+
SERVICE_NAME=z-ai2api-server
|
| 45 |
+
|
| 46 |
+
# 调试日志
|
| 47 |
+
DEBUG_LOGGING=false
|
| 48 |
+
|
| 49 |
+
# Function Call 功能开关
|
| 50 |
+
TOOL_SUPPORT=true
|
| 51 |
+
|
| 52 |
+
# 工具调用扫描限制(字符数)
|
| 53 |
+
SCAN_LIMIT=200000
|
| 54 |
+
|
| 55 |
+
# ========== Z.AI 错误码400处理 ==========
|
| 56 |
+
|
| 57 |
+
# 重试次数
|
| 58 |
+
MAX_RETRIES=6
|
| 59 |
+
# 初始重试延迟
|
| 60 |
+
RETRY_DELAY=1
|
Dockerfile
CHANGED
|
@@ -1 +1,10 @@
|
|
| 1 |
-
FROM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim

WORKDIR /app

# Install dependencies first so this layer is cached when only source changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Bug fix: "COPY .. ." referenced the parent of the build context, which Docker
# forbids; the application source lives in the context root itself.
COPY . .

CMD ["python", "main.py"]
|
app/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Top-level application package; re-exports its subpackages for convenience."""

from app import core, models, utils

__all__ = ["core", "models", "utils"]
|
app/core/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Core package: configuration, the Z.ai transformer and the OpenAI-compatible API layer."""

from app.core import config, zai_transformer, openai

__all__ = ["config", "zai_transformer", "openai"]
|
app/core/config.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Application configuration, driven by environment variables and an optional .env file."""

import os
from typing import Dict, List, Optional
from pydantic_settings import BaseSettings
from app.utils.logger import logger


class Settings(BaseSettings):
    """Application settings"""

    # API Configuration
    API_ENDPOINT: str = "https://chat.z.ai/api/chat/completions"
    # Client-facing API key that callers of this service must present.
    AUTH_TOKEN: str = os.getenv("AUTH_TOKEN", "sk-your-api-key")

    # Optional path to a file holding Z.ai auth tokens (newline and/or comma separated).
    AUTH_TOKENS_FILE: Optional[str] = os.getenv("AUTH_TOKENS_FILE")

    # Token pool configuration
    TOKEN_HEALTH_CHECK_INTERVAL: int = int(os.getenv("TOKEN_HEALTH_CHECK_INTERVAL", "300"))  # 5 minutes
    TOKEN_FAILURE_THRESHOLD: int = int(os.getenv("TOKEN_FAILURE_THRESHOLD", "3"))  # failures before a token is marked down
    TOKEN_RECOVERY_TIMEOUT: int = int(os.getenv("TOKEN_RECOVERY_TIMEOUT", "1800"))  # retry failed tokens after 30 minutes

    def _load_tokens_from_file(self, file_path: str) -> List[str]:
        """
        Load a token list from *file_path*.

        Supported formats, which may be mixed in one file:
          1. one token per line,
          2. comma-separated tokens on a single line,
          3. both at once.
        Blank lines and lines starting with '#' are skipped. Returns an empty
        list when the file is missing, empty or unreadable.
        """
        tokens: List[str] = []
        try:
            if not os.path.exists(file_path):
                logger.debug(f"📄 Token文件不存在: {file_path}")
                return tokens

            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            if not content:
                logger.debug(f"📄 Token文件为空: {file_path}")
                return tokens

            for line in content.split('\n'):
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                # A single line may carry several comma-separated tokens.
                for token in line.split(','):
                    token = token.strip()
                    if token:
                        tokens.append(token)

            logger.info(f"📄 从文件加载了 {len(tokens)} 个token: {file_path}")
        except Exception as e:
            logger.error(f"❌ 读取token文件失败 {file_path}: {e}")
        return tokens

    @staticmethod
    def _dedup_tokens(tokens: List[str], label: str = "") -> List[str]:
        """
        Deduplicate *tokens* while preserving first-seen order.

        *label* is interpolated into the duplicate warning so the same helper
        serves both the Z.ai and LongCat token lists.
        """
        seen = set()
        unique_tokens: List[str] = []
        for token in tokens:
            if token not in seen:
                unique_tokens.append(token)
                seen.add(token)

        duplicate_count = len(tokens) - len(unique_tokens)
        if duplicate_count > 0:
            logger.warning(f"⚠️ 检测到 {duplicate_count} 个重复{label}token,已自动去重")
        return unique_tokens

    @property
    def auth_token_list(self) -> List[str]:
        """
        Z.ai auth tokens loaded from AUTH_TOKENS_FILE, deduplicated in order.

        Returns an empty list when no token file is configured.
        """
        if not self.AUTH_TOKENS_FILE:
            logger.debug("📄 未配置AUTH_TOKENS_FILE,跳过token文件加载")
            return []
        return self._dedup_tokens(self._load_tokens_from_file(self.AUTH_TOKENS_FILE))

    @property
    def longcat_token_list(self) -> List[str]:
        """
        LongCat tokens loaded from LONGCAT_TOKENS_FILE, deduplicated in order.

        Returns an empty list when no token file is configured.
        """
        if not self.LONGCAT_TOKENS_FILE:
            logger.debug("📄 未配置LONGCAT_TOKENS_FILE,跳过LongCat token文件加载")
            return []
        return self._dedup_tokens(self._load_tokens_from_file(self.LONGCAT_TOKENS_FILE), label="LongCat ")

    # Model Configuration (public model names exposed via /v1/models)
    PRIMARY_MODEL: str = os.getenv("PRIMARY_MODEL", "GLM-4.5")
    THINKING_MODEL: str = os.getenv("THINKING_MODEL", "GLM-4.5-Thinking")
    SEARCH_MODEL: str = os.getenv("SEARCH_MODEL", "GLM-4.5-Search")
    AIR_MODEL: str = os.getenv("AIR_MODEL", "GLM-4.5-Air")
    GLM46_MODEL: str = os.getenv("GLM46_MODEL", "GLM-4.6")
    GLM46_THINKING_MODEL: str = os.getenv("GLM46_THINKING_MODEL", "GLM-4.6-Thinking")
    GLM46_SEARCH_MODEL: str = os.getenv("GLM46_SEARCH_MODEL", "GLM-4.6-Search")

    # Provider Model Mapping
    @property
    def provider_model_mapping(self) -> Dict[str, str]:
        """Map each public model name to the provider that serves it."""
        return {
            # Z.AI models
            "GLM-4.5": "zai",
            "GLM-4.5-Thinking": "zai",
            "GLM-4.5-Search": "zai",
            "GLM-4.5-Air": "zai",
            "GLM-4.6": "zai",
            "GLM-4.6-Thinking": "zai",
            "GLM-4.6-Search": "zai",
            # K2Think models
            "MBZUAI-IFM/K2-Think": "k2think",
            # LongCat models
            "LongCat-Flash": "longcat",
            "LongCat": "longcat",
            "LongCat-Search": "longcat",
        }

    # Server Configuration
    LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080"))
    # NOTE(review): defaults to "true" here while .env.example ships DEBUG_LOGGING=false — confirm intended default.
    DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true"
    SERVICE_NAME: str = os.getenv("SERVICE_NAME", "z-ai2api-server")

    ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true"
    TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
    SCAN_LIMIT: int = int(os.getenv("SCAN_LIMIT", "200000"))
    SKIP_AUTH_TOKEN: bool = os.getenv("SKIP_AUTH_TOKEN", "false").lower() == "true"

    # LongCat Configuration
    LONGCAT_PASSPORT_TOKEN: Optional[str] = os.getenv("LONGCAT_PASSPORT_TOKEN")
    LONGCAT_TOKENS_FILE: Optional[str] = os.getenv("LONGCAT_TOKENS_FILE")

    # Retry Configuration
    MAX_RETRIES: int = int(os.getenv("MAX_RETRIES", "5"))
    RETRY_DELAY: float = float(os.getenv("RETRY_DELAY", "1.0"))  # initial retry delay (seconds)

    # Browser Headers sent with upstream requests
    CLIENT_HEADERS: Dict[str, str] = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0",
        "Accept-Language": "zh-CN",
        "sec-ch-ua": '"Not;A=Brand";v="99", "Microsoft Edge";v="139", "Chromium";v="139"',
        "sec-ch-ua-mobile": "?0",
        "sec-ch-ua-platform": '"Windows"',
        "X-FE-Version": "prod-fe-1.0.70",
        "Origin": "https://chat.z.ai",
    }

    class Config:
        env_file = ".env"


settings = Settings()
|
app/core/openai.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json
import time
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Header, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

from app.core.config import settings
from app.models.schemas import OpenAIRequest, Message, ModelsResponse, Model, OpenAIResponse, Choice, Usage
from app.utils.logger import get_logger
from app.providers import get_provider_router
from app.utils.token_pool import get_token_pool
| 15 |
+
|
| 16 |
+
logger = get_logger()
router = APIRouter()

# Global provider-router singleton; created lazily by get_provider_router_instance().
provider_router = None
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_provider_router_instance():
    """Return the lazily-created provider router singleton."""
    global provider_router
    if provider_router is not None:
        return provider_router
    provider_router = get_provider_router()
    return provider_router
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_chunk(chat_id: str, model: str, delta: Dict[str, Any], finish_reason: Optional[str] = None) -> Dict[str, Any]:
    """
    Build a standard OpenAI streaming chunk (``chat.completion.chunk``).

    Args:
        chat_id: identifier placed in the chunk's ``id`` field.
        model: model name echoed back to the client.
        delta: incremental payload for the single choice (e.g. ``{"content": ...}``).
        finish_reason: ``"stop"``/``"tool_calls"``/... on the final chunk,
            ``None`` while streaming. (Fixed annotation: the default is None,
            so the type is Optional[str], not str.)

    Returns:
        Dict[str, Any]: a chunk with one choice at index 0 and a fresh
        ``created`` timestamp.
    """
    return {
        "choices": [{
            "delta": delta,
            "finish_reason": finish_reason,
            "index": 0,
            "logprobs": None,
        }],
        "created": int(time.time()),
        "id": chat_id,
        "model": model,
        "object": "chat.completion.chunk",
        "system_fingerprint": "fp_zai_001",
    }
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def handle_non_stream_response(stream_response, request: OpenAIRequest) -> JSONResponse:
    """
    Drain a streaming SSE source and assemble one aggregated completion.

    Args:
        stream_response: either an async generator of SSE lines
            (``"data: {...}"``) or a zero-argument factory returning one.
            chat_completions passes the generator object itself, so the
            previous unconditional call ``stream_response()`` raised
            TypeError; both forms are now accepted.
        request: originating client request (supplies the model name).

    Returns:
        JSONResponse: a single ``chat.completion`` whose content is the
        concatenation of every delta content fragment seen in the stream.
    """
    logger.info("📄 开始处理非流式响应")

    # Accept both an already-created async generator and a factory for one.
    source = stream_response() if callable(stream_response) else stream_response

    # Collect every content delta from the stream.
    full_content = []
    async for chunk_data in source:
        if not chunk_data.startswith("data: "):
            continue
        chunk_str = chunk_data[6:].strip()
        if not chunk_str or chunk_str == "[DONE]":
            continue
        try:
            chunk = json.loads(chunk_str)
        except json.JSONDecodeError:
            # Tolerate malformed keep-alive / partial frames.
            continue
        if "choices" in chunk and chunk["choices"]:
            choice = chunk["choices"][0]
            if "delta" in choice and "content" in choice["delta"]:
                content = choice["delta"]["content"]
                if content:
                    full_content.append(content)

    # Build the aggregated response. Upstream providers do not report token
    # usage here, so the usage counters are zeroed.
    response_data = OpenAIResponse(
        id=f"chatcmpl-{int(time.time())}",
        object="chat.completion",
        created=int(time.time()),
        model=request.model,
        choices=[Choice(
            index=0,
            message=Message(
                role="assistant",
                content="".join(full_content),
                tool_calls=None
            ),
            finish_reason="stop"
        )],
        usage=Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0
        )
    )

    logger.info("✅ 非流式响应处理完成")
    return JSONResponse(content=response_data.model_dump(exclude_none=True))
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@router.get("/v1/models")
|
| 96 |
+
async def list_models():
|
| 97 |
+
"""List available models from all providers"""
|
| 98 |
+
try:
|
| 99 |
+
router_instance = get_provider_router_instance()
|
| 100 |
+
models_data = router_instance.get_models_list()
|
| 101 |
+
return JSONResponse(content=models_data)
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.error(f"❌ 获取模型列表失败: {e}")
|
| 104 |
+
# 返回默认模型列表作为后备
|
| 105 |
+
current_time = int(time.time())
|
| 106 |
+
fallback_response = ModelsResponse(
|
| 107 |
+
data=[
|
| 108 |
+
Model(id=settings.PRIMARY_MODEL, created=current_time, owned_by="z.ai"),
|
| 109 |
+
Model(id=settings.THINKING_MODEL, created=current_time, owned_by="z.ai"),
|
| 110 |
+
Model(id=settings.SEARCH_MODEL, created=current_time, owned_by="z.ai"),
|
| 111 |
+
Model(id=settings.AIR_MODEL, created=current_time, owned_by="z.ai"),
|
| 112 |
+
]
|
| 113 |
+
)
|
| 114 |
+
return fallback_response
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@router.post("/v1/chat/completions")
|
| 118 |
+
async def chat_completions(request: OpenAIRequest, authorization: str = Header(...)):
|
| 119 |
+
"""Handle chat completion requests with multi-provider architecture"""
|
| 120 |
+
role = request.messages[0].role if request.messages else "unknown"
|
| 121 |
+
logger.info(f"😶🌫️ 收到客户端请求 - 模型: {request.model}, 流式: {request.stream}, 消息数: {len(request.messages)}, 角色: {role}, 工具数: {len(request.tools) if request.tools else 0}")
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
# Validate API key (skip if SKIP_AUTH_TOKEN is enabled)
|
| 125 |
+
if not settings.SKIP_AUTH_TOKEN:
|
| 126 |
+
if not authorization.startswith("Bearer "):
|
| 127 |
+
raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
|
| 128 |
+
|
| 129 |
+
api_key = authorization[7:]
|
| 130 |
+
if api_key != settings.AUTH_TOKEN:
|
| 131 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 132 |
+
|
| 133 |
+
# 使用多提供商路由器处理请求
|
| 134 |
+
router_instance = get_provider_router_instance()
|
| 135 |
+
result = await router_instance.route_request(request)
|
| 136 |
+
|
| 137 |
+
# 检查是否有错误
|
| 138 |
+
if isinstance(result, dict) and "error" in result:
|
| 139 |
+
error_info = result["error"]
|
| 140 |
+
if error_info.get("code") == "model_not_found":
|
| 141 |
+
raise HTTPException(status_code=404, detail=error_info["message"])
|
| 142 |
+
else:
|
| 143 |
+
raise HTTPException(status_code=500, detail=error_info["message"])
|
| 144 |
+
|
| 145 |
+
# 处理响应
|
| 146 |
+
if request.stream:
|
| 147 |
+
# 流式响应
|
| 148 |
+
if hasattr(result, '__aiter__'):
|
| 149 |
+
# 结果是异步生成器
|
| 150 |
+
return StreamingResponse(
|
| 151 |
+
result,
|
| 152 |
+
media_type="text/event-stream",
|
| 153 |
+
headers={
|
| 154 |
+
"Cache-Control": "no-cache",
|
| 155 |
+
"Connection": "keep-alive",
|
| 156 |
+
"Access-Control-Allow-Origin": "*",
|
| 157 |
+
}
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
# 结果是字典,可能包含错误
|
| 161 |
+
raise HTTPException(status_code=500, detail="Expected streaming response but got non-streaming result")
|
| 162 |
+
else:
|
| 163 |
+
# 非流式响应
|
| 164 |
+
if isinstance(result, dict):
|
| 165 |
+
return JSONResponse(content=result)
|
| 166 |
+
else:
|
| 167 |
+
# 如果是异步生成器,需要收集所有内容
|
| 168 |
+
return await handle_non_stream_response(result, request)
|
| 169 |
+
|
| 170 |
+
except HTTPException:
|
| 171 |
+
# 重新抛出 HTTP 异常
|
| 172 |
+
raise
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"❌ 请求处理失败: {e}")
|
| 175 |
+
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Token pool management endpoints
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@router.get("/v1/token-pool/status")
|
| 182 |
+
async def get_token_pool_status():
|
| 183 |
+
"""获取token池状态信息"""
|
| 184 |
+
try:
|
| 185 |
+
token_pool = get_token_pool()
|
| 186 |
+
if not token_pool:
|
| 187 |
+
return {
|
| 188 |
+
"status": "disabled",
|
| 189 |
+
"message": "Token池未初始化,当前仅使用匿名模式",
|
| 190 |
+
"anonymous_mode": settings.ANONYMOUS_MODE,
|
| 191 |
+
"auth_tokens_file": settings.AUTH_TOKENS_FILE,
|
| 192 |
+
"auth_tokens_configured": len(settings.auth_token_list) > 0
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
pool_status = token_pool.get_pool_status()
|
| 196 |
+
return {
|
| 197 |
+
"status": "active",
|
| 198 |
+
"pool_info": pool_status,
|
| 199 |
+
"config": {
|
| 200 |
+
"anonymous_mode": settings.ANONYMOUS_MODE,
|
| 201 |
+
"failure_threshold": settings.TOKEN_FAILURE_THRESHOLD,
|
| 202 |
+
"recovery_timeout": settings.TOKEN_RECOVERY_TIMEOUT,
|
| 203 |
+
"health_check_interval": settings.TOKEN_HEALTH_CHECK_INTERVAL
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
except Exception as e:
|
| 207 |
+
logger.error(f"获取token池状态失败: {e}")
|
| 208 |
+
raise HTTPException(status_code=500, detail=f"Failed to get token pool status: {str(e)}")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@router.post("/v1/token-pool/health-check")
|
| 212 |
+
async def trigger_health_check():
|
| 213 |
+
"""手动触发token池健康检查"""
|
| 214 |
+
try:
|
| 215 |
+
token_pool = get_token_pool()
|
| 216 |
+
if not token_pool:
|
| 217 |
+
raise HTTPException(status_code=404, detail="Token池未初始化")
|
| 218 |
+
|
| 219 |
+
start_time = time.time()
|
| 220 |
+
logger.info("🔍 API触发Token池健康检查...")
|
| 221 |
+
await token_pool.health_check_all()
|
| 222 |
+
duration = time.time() - start_time
|
| 223 |
+
|
| 224 |
+
pool_status = token_pool.get_pool_status()
|
| 225 |
+
total_tokens = pool_status['total_tokens']
|
| 226 |
+
healthy_tokens = sum(1 for token_info in pool_status['tokens'] if token_info['is_healthy'])
|
| 227 |
+
|
| 228 |
+
response = {
|
| 229 |
+
"status": "completed",
|
| 230 |
+
"message": f"健康检查已完成,耗时 {duration:.2f} 秒",
|
| 231 |
+
"summary": {
|
| 232 |
+
"total_tokens": total_tokens,
|
| 233 |
+
"healthy_tokens": healthy_tokens,
|
| 234 |
+
"unhealthy_tokens": total_tokens - healthy_tokens,
|
| 235 |
+
"health_rate": f"{(healthy_tokens/total_tokens*100):.1f}%" if total_tokens > 0 else "0%",
|
| 236 |
+
"duration_seconds": round(duration, 2)
|
| 237 |
+
},
|
| 238 |
+
"pool_info": pool_status
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
logger.info(f"✅ API健康检查完成: {healthy_tokens}/{total_tokens} 个token健康")
|
| 242 |
+
return response
|
| 243 |
+
except Exception as e:
|
| 244 |
+
logger.error(f"健康检查失败: {e}")
|
| 245 |
+
raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@router.post("/v1/token-pool/update")
|
| 249 |
+
async def update_token_pool_endpoint(tokens: List[str]):
|
| 250 |
+
"""动态更新token池"""
|
| 251 |
+
try:
|
| 252 |
+
from app.utils.token_pool import update_token_pool
|
| 253 |
+
|
| 254 |
+
valid_tokens = [token.strip() for token in tokens if token.strip()]
|
| 255 |
+
if not valid_tokens:
|
| 256 |
+
raise HTTPException(status_code=400, detail="至少需要提供一个有效的token")
|
| 257 |
+
|
| 258 |
+
update_token_pool(valid_tokens)
|
| 259 |
+
token_pool = get_token_pool()
|
| 260 |
+
|
| 261 |
+
return {
|
| 262 |
+
"status": "updated",
|
| 263 |
+
"message": f"Token池已更新,共 {len(valid_tokens)} 个token",
|
| 264 |
+
"pool_info": token_pool.get_pool_status() if token_pool else None
|
| 265 |
+
}
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.error(f"更新token池失败: {e}")
|
| 268 |
+
raise HTTPException(status_code=500, detail=f"Failed to update token pool: {str(e)}")
|
app/core/zai_transformer.py
ADDED
|
@@ -0,0 +1,777 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import random
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Dict, List, Any, Optional, Generator, AsyncGenerator
|
| 10 |
+
import httpx
|
| 11 |
+
import asyncio
|
| 12 |
+
|
| 13 |
+
from app.core.config import settings
|
| 14 |
+
from app.utils.logger import get_logger
|
| 15 |
+
from app.utils.token_pool import get_token_pool, initialize_token_pool
|
| 16 |
+
from app.utils.user_agent import get_random_user_agent
|
| 17 |
+
|
| 18 |
+
# Module-level logger shared by every helper in this module.
logger = get_logger()
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_zai_dynamic_headers(chat_id: str = "") -> Dict[str, str]:
    """
    Build Z.AI-specific dynamic browser headers with a randomized User-Agent.

    Uses the shared UserAgent utility and layers Z.AI business specifics on
    top (X-FE-Version, Origin, chat-specific Referer, client-hint headers).

    Args:
        chat_id: chat ID used to build the correct Referer; empty → site root.

    Returns:
        Dict[str, str]: headers ready to send to chat.z.ai.
    """
    # Weighted choice biased toward Chrome and Edge to mimic real traffic.
    browser_choices = ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"]
    browser_type = random.choice(browser_choices)

    user_agent = get_random_user_agent(browser_type)

    # Fallback major versions used when parsing the UA string fails.
    chrome_version = "139"
    edge_version = "139"

    if "Chrome/" in user_agent:
        try:
            chrome_version = user_agent.split("Chrome/")[1].split(".")[0]
        except IndexError:
            # Bug fix: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; only the split indexing can fail.
            pass

    if "Edg/" in user_agent:
        try:
            edge_version = user_agent.split("Edg/")[1].split(".")[0]
            sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
        except IndexError:
            sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'
    elif "Firefox/" in user_agent:
        sec_ch_ua = None  # Firefox does not send sec-ch-ua
    else:
        sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'

    # Z.AI-specific base headers
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        "User-Agent": user_agent,
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        "X-FE-Version": "prod-fe-1.0.79",
        "Origin": "https://chat.z.ai",
    }

    # Chromium-family browsers additionally advertise client-hint headers.
    if sec_ch_ua:
        headers["sec-ch-ua"] = sec_ch_ua
        headers["sec-ch-ua-mobile"] = "?0"
        headers["sec-ch-ua-platform"] = '"Windows"'

    # Referer must match the chat page supposedly being viewed.
    if chat_id:
        headers["Referer"] = f"https://chat.z.ai/c/{chat_id}"
    else:
        headers["Referer"] = "https://chat.z.ai/"

    return headers
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def generate_uuid() -> str:
|
| 86 |
+
"""生成UUID v4"""
|
| 87 |
+
return str(uuid.uuid4())
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _fetch_guest_token_sync() -> str:
    """Fetch an anonymous guest token from chat.z.ai synchronously.

    Returns the token string on success, or "" when the endpoint answers
    with a non-200 status or the response carries no token. Network
    errors propagate to the caller.
    """
    headers = get_zai_dynamic_headers()
    with httpx.Client() as client:
        response = client.get("https://chat.z.ai/api/v1/auths/", headers=headers, timeout=10.0)
        if response.status_code == 200:
            data = response.json()
            return data.get("token", "")
    return ""


def get_auth_token_sync() -> str:
    """Synchronously obtain an authentication token (for non-async call sites).

    Resolution order:
      1. Anonymous mode: fetch a guest token only; return "" on failure.
      2. Token pool: return the next available backup token, if any.
      3. Fallback: degrade to an anonymous guest token.

    Returns:
        The token string, or "" when every strategy fails.
    """
    # Anonymous mode only ever uses a guest token.
    if settings.ANONYMOUS_MODE:
        try:
            token = _fetch_guest_token_sync()
            if token:
                logger.debug(f"获取访客令牌成功: {token[:20]}...")
                return token
        except Exception as e:
            logger.warning(f"获取访客令牌失败: {e}")

        logger.error("❌ 匿名模式下获取访客令牌失败")
        return ""

    # Non-anonymous mode: prefer a backup token from the pool.
    token_pool = get_token_pool()
    if token_pool:
        token = token_pool.get_next_token()
        if token:
            logger.debug(f"从token池获取令牌: {token[:20]}...")
            return token

    # No backup token available: degrade to anonymous guest access.
    logger.warning("⚠️ 没有可用的备份token,尝试降级到匿名模式...")
    try:
        token = _fetch_guest_token_sync()
        if token:
            logger.info(f"✅ 降级到匿名模式成功: {token[:20]}...")
            return token
    except Exception as e:
        logger.warning(f"降级到匿名模式失败: {e}")

    # Every strategy failed.
    logger.error("❌ 所有认证方式都失败了")
    return ""
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ZAITransformer:
|
| 140 |
+
"""ZAI转换器类"""
|
| 141 |
+
|
| 142 |
+
def __init__(self):
|
| 143 |
+
"""初始化转换器"""
|
| 144 |
+
self.name = "zai"
|
| 145 |
+
self.base_url = "https://chat.z.ai"
|
| 146 |
+
self.api_url = settings.API_ENDPOINT
|
| 147 |
+
self.auth_url = f"{self.base_url}/api/v1/auths/"
|
| 148 |
+
|
| 149 |
+
# 模型映射
|
| 150 |
+
self.model_mapping = {
|
| 151 |
+
settings.PRIMARY_MODEL: "0727-360B-API", # GLM-4.5
|
| 152 |
+
settings.THINKING_MODEL: "0727-360B-API", # GLM-4.5-Thinking
|
| 153 |
+
settings.SEARCH_MODEL: "0727-360B-API", # GLM-4.5-Search
|
| 154 |
+
settings.AIR_MODEL: "0727-106B-API", # GLM-4.5-Air
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
async def get_token(self) -> str:
|
| 158 |
+
"""异步获取认证令牌"""
|
| 159 |
+
# 如果启用匿名模式,只尝试获取访客令牌
|
| 160 |
+
if settings.ANONYMOUS_MODE:
|
| 161 |
+
try:
|
| 162 |
+
headers = get_zai_dynamic_headers()
|
| 163 |
+
async with httpx.AsyncClient() as client:
|
| 164 |
+
response = await client.get(self.auth_url, headers=headers, timeout=10.0)
|
| 165 |
+
if response.status_code == 200:
|
| 166 |
+
data = response.json()
|
| 167 |
+
token = data.get("token", "")
|
| 168 |
+
if token:
|
| 169 |
+
logger.debug(f"获取访客令牌成功: {token[:20]}...")
|
| 170 |
+
return token
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.warning(f"异步获取访客令牌失败: {e}")
|
| 173 |
+
|
| 174 |
+
# 匿名模式下,如果获取访客令牌失败,直接返回空
|
| 175 |
+
logger.error("❌ 匿名模式下获取访客令牌失败")
|
| 176 |
+
return ""
|
| 177 |
+
|
| 178 |
+
# 非匿名模式:首先使用token池获取备份令牌
|
| 179 |
+
token_pool = get_token_pool()
|
| 180 |
+
if token_pool:
|
| 181 |
+
token = token_pool.get_next_token()
|
| 182 |
+
if token:
|
| 183 |
+
logger.debug(f"从token池获取令牌: {token[:20]}...")
|
| 184 |
+
return token
|
| 185 |
+
|
| 186 |
+
# 如果没有备份token,尝试降级到匿名模式
|
| 187 |
+
logger.warning("⚠️ 没有可用的备份token,尝试降级到匿名模式...")
|
| 188 |
+
try:
|
| 189 |
+
headers = get_zai_dynamic_headers()
|
| 190 |
+
async with httpx.AsyncClient() as client:
|
| 191 |
+
response = await client.get(self.auth_url, headers=headers, timeout=10.0)
|
| 192 |
+
if response.status_code == 200:
|
| 193 |
+
data = response.json()
|
| 194 |
+
token = data.get("token", "")
|
| 195 |
+
if token:
|
| 196 |
+
logger.info(f"✅ 降级到匿名模式成功: {token[:20]}...")
|
| 197 |
+
return token
|
| 198 |
+
except Exception as e:
|
| 199 |
+
logger.warning(f"降级到匿名模式失败: {e}")
|
| 200 |
+
|
| 201 |
+
# 没有可用的token
|
| 202 |
+
logger.error("❌ 所有认证方式都失败了")
|
| 203 |
+
return ""
|
| 204 |
+
|
| 205 |
+
def mark_token_success(self, token: str):
|
| 206 |
+
"""标记token使用成功"""
|
| 207 |
+
token_pool = get_token_pool()
|
| 208 |
+
if token_pool:
|
| 209 |
+
token_pool.mark_token_success(token)
|
| 210 |
+
|
| 211 |
+
def mark_token_failure(self, token: str, error: Exception = None):
|
| 212 |
+
"""标记token使用失败"""
|
| 213 |
+
token_pool = get_token_pool()
|
| 214 |
+
if token_pool:
|
| 215 |
+
token_pool.mark_token_failure(token, error)
|
| 216 |
+
|
| 217 |
+
async def transform_request_in(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
| 218 |
+
"""
|
| 219 |
+
转换OpenAI请求为z.ai格式
|
| 220 |
+
整合现有功能:模型映射、MCP服务器等
|
| 221 |
+
"""
|
| 222 |
+
logger.info(f"🔄 开始转换 OpenAI 请求到 Z.AI 格式: {request.get('model', settings.PRIMARY_MODEL)} -> Z.AI")
|
| 223 |
+
|
| 224 |
+
# 获取认证令牌
|
| 225 |
+
token = await self.get_token()
|
| 226 |
+
logger.debug(f" 使用令牌: {token[:20] if token else 'None'}...")
|
| 227 |
+
|
| 228 |
+
# 检查token是否有效
|
| 229 |
+
if not token:
|
| 230 |
+
# 提供详细的配置建议
|
| 231 |
+
error_msg = "❌ 无法获取有效的认证令牌"
|
| 232 |
+
suggestions = []
|
| 233 |
+
|
| 234 |
+
if not settings.ANONYMOUS_MODE:
|
| 235 |
+
suggestions.append("1. 设置 ANONYMOUS_MODE=true 启用匿名模式")
|
| 236 |
+
|
| 237 |
+
if not settings.AUTH_TOKENS_FILE:
|
| 238 |
+
suggestions.append("2. 配置 AUTH_TOKENS_FILE 并创建对应的token文件")
|
| 239 |
+
elif settings.AUTH_TOKENS_FILE and not settings.auth_token_list:
|
| 240 |
+
suggestions.append(f"3. 检查token文件 '{settings.AUTH_TOKENS_FILE}' 是否存在且包含有效token")
|
| 241 |
+
|
| 242 |
+
if suggestions:
|
| 243 |
+
error_msg += "\n建议的解决方案:\n" + "\n".join(suggestions)
|
| 244 |
+
|
| 245 |
+
logger.error(error_msg)
|
| 246 |
+
raise Exception("无法获取有效的认证令牌,请检查配置")
|
| 247 |
+
|
| 248 |
+
# 确定请求的模型特性
|
| 249 |
+
requested_model = request.get("model", settings.PRIMARY_MODEL)
|
| 250 |
+
is_thinking = requested_model == settings.THINKING_MODEL or request.get("reasoning", False)
|
| 251 |
+
is_search = requested_model == settings.SEARCH_MODEL
|
| 252 |
+
is_air = requested_model == settings.AIR_MODEL
|
| 253 |
+
|
| 254 |
+
# 获取上游模型ID(使用模型映射)
|
| 255 |
+
upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
|
| 256 |
+
logger.debug(f" 模型映射: {requested_model} -> {upstream_model_id}")
|
| 257 |
+
|
| 258 |
+
# 处理消息列表
|
| 259 |
+
logger.debug(f" 开始处理 {len(request.get('messages', []))} 条消息")
|
| 260 |
+
messages = []
|
| 261 |
+
for idx, orig_msg in enumerate(request.get("messages", [])):
|
| 262 |
+
msg = orig_msg.copy()
|
| 263 |
+
|
| 264 |
+
# 处理system角色转换
|
| 265 |
+
if msg.get("role") == "system":
|
| 266 |
+
|
| 267 |
+
msg["role"] = "user"
|
| 268 |
+
content = msg.get("content")
|
| 269 |
+
|
| 270 |
+
if isinstance(content, list):
|
| 271 |
+
msg["content"] = [
|
| 272 |
+
{"type": "text", "text": "This is a system command, you must enforce compliance."}
|
| 273 |
+
] + content
|
| 274 |
+
elif isinstance(content, str):
|
| 275 |
+
msg["content"] = f"This is a system command, you must enforce compliance.{content}"
|
| 276 |
+
|
| 277 |
+
# 处理user角色的图片内容
|
| 278 |
+
elif msg.get("role") == "user":
|
| 279 |
+
content = msg.get("content")
|
| 280 |
+
if isinstance(content, list):
|
| 281 |
+
new_content = []
|
| 282 |
+
for part_idx, part in enumerate(content):
|
| 283 |
+
# 处理图片URL(支持base64和http URL)
|
| 284 |
+
if (
|
| 285 |
+
part.get("type") == "image_url"
|
| 286 |
+
and part.get("image_url", {}).get("url")
|
| 287 |
+
and isinstance(part["image_url"]["url"], str)
|
| 288 |
+
):
|
| 289 |
+
logger.debug(f" 消息[{idx}]内容[{part_idx}]: 检测到图片URL")
|
| 290 |
+
# 直接传递图片内容
|
| 291 |
+
new_content.append(part)
|
| 292 |
+
else:
|
| 293 |
+
new_content.append(part)
|
| 294 |
+
msg["content"] = new_content
|
| 295 |
+
|
| 296 |
+
# 处理assistant消息中的reasoning_content
|
| 297 |
+
elif msg.get("role") == "assistant" and msg.get("reasoning_content"):
|
| 298 |
+
|
| 299 |
+
# 如果有reasoning_content,保留它
|
| 300 |
+
pass
|
| 301 |
+
|
| 302 |
+
messages.append(msg)
|
| 303 |
+
|
| 304 |
+
# 构建MCP服务器列表
|
| 305 |
+
mcp_servers = []
|
| 306 |
+
if is_search:
|
| 307 |
+
mcp_servers.append("deep-web-search")
|
| 308 |
+
logger.info(f"🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")
|
| 309 |
+
else:
|
| 310 |
+
logger.debug(f" 非搜索模型,不添加 MCP 服务器")
|
| 311 |
+
|
| 312 |
+
logger.debug(f" MCP服务器列表: {mcp_servers}")
|
| 313 |
+
|
| 314 |
+
# 构建上游请求体
|
| 315 |
+
chat_id = generate_uuid()
|
| 316 |
+
|
| 317 |
+
body = {
|
| 318 |
+
"stream": True, # 总是使用流式
|
| 319 |
+
"model": upstream_model_id, # 使用映射后的模型ID
|
| 320 |
+
"messages": messages,
|
| 321 |
+
"params": {},
|
| 322 |
+
"features": {
|
| 323 |
+
"image_generation": False,
|
| 324 |
+
"web_search": is_search,
|
| 325 |
+
"auto_web_search": is_search,
|
| 326 |
+
"preview_mode": False,
|
| 327 |
+
"flags": [],
|
| 328 |
+
"features": [],
|
| 329 |
+
"enable_thinking": is_thinking,
|
| 330 |
+
},
|
| 331 |
+
"background_tasks": {
|
| 332 |
+
"title_generation": False,
|
| 333 |
+
"tags_generation": False,
|
| 334 |
+
},
|
| 335 |
+
"mcp_servers": mcp_servers, # 保留MCP服务器支持
|
| 336 |
+
"variables": {
|
| 337 |
+
"{{USER_NAME}}": "Guest",
|
| 338 |
+
"{{USER_LOCATION}}": "Unknown",
|
| 339 |
+
"{{CURRENT_DATETIME}}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 340 |
+
"{{CURRENT_DATE}}": datetime.now().strftime("%Y-%m-%d"),
|
| 341 |
+
"{{CURRENT_TIME}}": datetime.now().strftime("%H:%M:%S"),
|
| 342 |
+
"{{CURRENT_WEEKDAY}}": datetime.now().strftime("%A"),
|
| 343 |
+
"{{CURRENT_TIMEZONE}}": "Asia/Shanghai", # 使用更合适的时区
|
| 344 |
+
"{{USER_LANGUAGE}}": "zh-CN",
|
| 345 |
+
},
|
| 346 |
+
"model_item": {
|
| 347 |
+
"id": upstream_model_id,
|
| 348 |
+
"name": requested_model,
|
| 349 |
+
"owned_by": "z.ai"
|
| 350 |
+
},
|
| 351 |
+
"chat_id": chat_id,
|
| 352 |
+
"id": generate_uuid(),
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
# 处理工具支持
|
| 356 |
+
if settings.TOOL_SUPPORT and not is_thinking and request.get("tools"):
|
| 357 |
+
body["tools"] = request["tools"]
|
| 358 |
+
logger.info(f"启用工具支持: {len(request['tools'])} 个工具")
|
| 359 |
+
else:
|
| 360 |
+
body["tools"] = None
|
| 361 |
+
|
| 362 |
+
# 构建请求配置
|
| 363 |
+
dynamic_headers = get_zai_dynamic_headers(chat_id)
|
| 364 |
+
|
| 365 |
+
config = {
|
| 366 |
+
"url": self.api_url, # 使用原始URL
|
| 367 |
+
"headers": {
|
| 368 |
+
**dynamic_headers, # 使用动态生成的headers
|
| 369 |
+
"Authorization": f"Bearer {token}",
|
| 370 |
+
"Cache-Control": "no-cache",
|
| 371 |
+
"Connection": "keep-alive",
|
| 372 |
+
"Pragma": "no-cache",
|
| 373 |
+
"Sec-Fetch-Dest": "empty",
|
| 374 |
+
"Sec-Fetch-Mode": "cors",
|
| 375 |
+
"Sec-Fetch-Site": "same-origin",
|
| 376 |
+
},
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
logger.info("✅ 请求转换完成")
|
| 380 |
+
|
| 381 |
+
# 记录关键的请求信息用于调试
|
| 382 |
+
logger.debug(f" 📋 发送到Z.AI的关键信息:")
|
| 383 |
+
logger.debug(f" - 上游模型: {body['model']}")
|
| 384 |
+
logger.debug(f" - MCP服务器: {body['mcp_servers']}")
|
| 385 |
+
logger.debug(f" - web_search: {body['features']['web_search']}")
|
| 386 |
+
logger.debug(f" - auto_web_search: {body['features']['auto_web_search']}")
|
| 387 |
+
logger.debug(f" - 消息数量: {len(body['messages'])}")
|
| 388 |
+
tools_count = len(body.get('tools') or [])
|
| 389 |
+
logger.debug(f" - 工具数量: {tools_count}")
|
| 390 |
+
|
| 391 |
+
# 返回转换后的请求数据
|
| 392 |
+
return {
|
| 393 |
+
"body": body,
|
| 394 |
+
"config": config,
|
| 395 |
+
"token": token
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
    async def transform_response_out(
        self, response_stream: Generator, context: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """
        Convert a Z.AI SSE response stream into OpenAI-compatible output.

        Supports both streaming and non-streaming callers: when the original
        request asked for a stream, OpenAI "chat.completion.chunk" SSE lines
        are yielded as they arrive; otherwise deltas are accumulated into a
        single "chat.completion" object yielded once as JSON at the end.

        Args:
            response_stream: async iterable of raw SSE lines from Z.AI.
            context: request context; context["req"]["body"] is assumed to
                carry the original OpenAI request ("stream", "model").

        Yields:
            SSE-formatted strings (streaming) or one JSON string (non-stream).
        """
        is_stream = context.get("req", {}).get("body", {}).get("stream", True)

        # Accumulator used only in non-streaming mode.
        result = {
            "id": "",
            "choices": [
                {
                    "finish_reason": None,
                    "index": 0,
                    "message": {
                        "content": "",
                        "role": "assistant",
                    },
                }
            ],
            "created": int(time.time()),
            "model": context.get("req", {}).get("body", {}).get("model", ""),
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "total_tokens": 0,
            },
        }

        # Parser state carried across SSE chunks.
        current_id = ""                  # upstream completion id (last seen)
        current_model = context.get("req", {}).get("body", {}).get("model", "")
        has_tool_call = False            # a tool_call phase has been seen
        tool_args = ""                   # accumulated args of the current tool call
        tool_id = ""                     # id of the tool call currently open
        tool_call_usage = None           # usage reported in the "other" phase
        content_index = 0                # OpenAI tool_calls index counter
        has_thinking = False             # a thinking phase has been seen

        async for line in response_stream:
            # Skip keep-alive blank lines.
            if not line.strip():
                continue

            if line.startswith("data:"):
                chunk_str = line[5:].strip()
                if not chunk_str:
                    continue

                try:
                    chunk = json.loads(chunk_str)

                    if chunk.get("type") == "chat:completion":
                        data = chunk.get("data", {})

                        # Remember the upstream id/model for chunks we emit.
                        if data.get("id"):
                            current_id = data["id"]
                        if data.get("model"):
                            current_model = data["model"]

                        # Dispatch on the upstream phase.
                        phase = data.get("phase")

                        if phase == "tool_call":
                            # --- tool-call phase ---
                            if not has_tool_call:
                                has_tool_call = True

                                if is_stream:
                                    # Announce the assistant role once.
                                    role_chunk = {
                                        "choices": [
                                            {
                                                "delta": {"role": "assistant"},
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(role_chunk)}\n\n"

                            # Incremental tool-call fields from upstream.
                            tool_call_id = data.get("tool_call", {}).get("id", "")
                            tool_name = data.get("tool_call", {}).get("name", "")
                            delta_args = data.get("delta_tool_call", {}).get("arguments", "")

                            if tool_call_id and tool_call_id != tool_id:
                                # A new tool call begins.
                                if tool_id and is_stream:
                                    # Close the previous tool call with an
                                    # empty-arguments delta before opening a new one.
                                    close_chunk = {
                                        "choices": [
                                            {
                                                "delta": {
                                                    "tool_calls": [
                                                        {"index": content_index, "function": {"arguments": ""}}
                                                    ]
                                                },
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(close_chunk)}\n\n"
                                    content_index += 1

                                tool_id = tool_call_id
                                tool_args = ""

                                if is_stream:
                                    # Open the new tool call (id + function name).
                                    new_tool_chunk = {
                                        "choices": [
                                            {
                                                "delta": {
                                                    "tool_calls": [
                                                        {
                                                            "index": content_index,
                                                            "id": tool_call_id,
                                                            "type": "function",
                                                            "function": {"name": tool_name, "arguments": ""},
                                                        }
                                                    ]
                                                },
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(new_tool_chunk)}\n\n"

                            # Forward argument fragments as they arrive.
                            if delta_args:
                                tool_args += delta_args
                                if is_stream:
                                    args_chunk = {
                                        "choices": [
                                            {
                                                "delta": {
                                                    "tool_calls": [
                                                        {
                                                            "index": content_index,
                                                            "function": {"arguments": delta_args},
                                                        }
                                                    ]
                                                },
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(args_chunk)}\n\n"

                        elif phase == "thinking":
                            # --- reasoning ("thinking") phase ---
                            if not has_thinking:
                                has_thinking = True
                                # Non-streaming: prepare the thinking accumulator.
                                if not is_stream:
                                    result["choices"][0]["message"]["thinking"] = {"content": ""}

                                if is_stream:
                                    # Announce the assistant role once.
                                    role_chunk = {
                                        "choices": [
                                            {
                                                "delta": {"role": "assistant"},
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(role_chunk)}\n\n"

                            delta_content = data.get("delta_content", "")
                            if delta_content:
                                # Strip the "<details>...</summary>" wrapper that
                                # Z.AI puts around the first reasoning fragment.
                                if delta_content.startswith("<details"):
                                    content = (
                                        delta_content.split("</summary>\n>")[-1].strip()
                                        if "</summary>\n>" in delta_content
                                        else delta_content
                                    )
                                else:
                                    content = delta_content

                                if is_stream:
                                    thinking_chunk = {
                                        "choices": [
                                            {
                                                "delta": {"thinking": {"content": content}},
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(thinking_chunk)}\n\n"
                                else:
                                    result["choices"][0]["message"]["thinking"]["content"] += content

                        elif phase == "answer":
                            # --- answer phase ---
                            edit_content = data.get("edit_content", "")
                            delta_content = data.get("delta_content", "")

                            # "</details>\n" in edit_content marks the
                            # transition from thinking to answer.
                            if edit_content and "</details>\n" in edit_content:
                                if has_thinking:
                                    # Stamp the finished thinking block with a
                                    # millisecond-timestamp signature.
                                    signature = str(int(time.time() * 1000))

                                    if is_stream:
                                        sig_chunk = {
                                            "choices": [
                                                {
                                                    "delta": {
                                                        "role": "assistant",
                                                        "thinking": {"content": "", "signature": signature},
                                                    },
                                                    "finish_reason": None,
                                                    "index": 0,
                                                }
                                            ],
                                            "created": int(time.time()),
                                            "id": current_id,
                                            "model": current_model,
                                            "object": "chat.completion.chunk",
                                        }
                                        yield f"data: {json.dumps(sig_chunk)}\n\n"
                                        content_index += 1
                                    else:
                                        result["choices"][0]["message"]["thinking"]["signature"] = signature

                                # Anything after the wrapper is answer text.
                                content_after = edit_content.split("</details>\n")[-1]
                                if content_after:
                                    if is_stream:
                                        content_chunk = {
                                            "choices": [
                                                {
                                                    "delta": {"role": "assistant", "content": content_after},
                                                    "finish_reason": None,
                                                    "index": 0,
                                                }
                                            ],
                                            "created": int(time.time()),
                                            "id": current_id,
                                            "model": current_model,
                                            "object": "chat.completion.chunk",
                                        }
                                        yield f"data: {json.dumps(content_chunk)}\n\n"
                                    else:
                                        result["choices"][0]["message"]["content"] += content_after

                            # Plain incremental answer content.
                            elif delta_content:
                                if is_stream:
                                    # Emit the role first if nothing was sent yet.
                                    if not has_thinking and not has_tool_call:
                                        role_chunk = {
                                            "choices": [
                                                {
                                                    "delta": {"role": "assistant"},
                                                    "finish_reason": None,
                                                    "index": 0,
                                                }
                                            ],
                                            "created": int(time.time()),
                                            "id": current_id,
                                            "model": current_model,
                                            "object": "chat.completion.chunk",
                                        }
                                        yield f"data: {json.dumps(role_chunk)}\n\n"

                                    content_chunk = {
                                        "choices": [
                                            {
                                                "delta": {"role": "assistant", "content": delta_content},
                                                "finish_reason": None,
                                                "index": 0,
                                            }
                                        ],
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(content_chunk)}\n\n"
                                else:
                                    result["choices"][0]["message"]["content"] += delta_content

                            # Usage on an answer chunk signals normal completion.
                            if data.get("usage"):
                                usage = data["usage"]
                                if is_stream:
                                    finish_chunk = {
                                        "choices": [
                                            {
                                                "delta": {"role": "assistant", "content": ""},
                                                "finish_reason": "stop",
                                                "index": 0,
                                            }
                                        ],
                                        "usage": usage,
                                        "created": int(time.time()),
                                        "id": current_id,
                                        "model": current_model,
                                        "object": "chat.completion.chunk",
                                    }
                                    yield f"data: {json.dumps(finish_chunk)}\n\n"
                                    yield "data: [DONE]\n\n"
                                else:
                                    result["id"] = current_id
                                    result["model"] = current_model
                                    result["usage"] = usage
                                    result["choices"][0]["finish_reason"] = "stop"

                        elif phase == "other":
                            # --- misc phase (may carry tool-call usage) ---
                            if data.get("usage"):
                                tool_call_usage = data["usage"]
                                if has_tool_call and is_stream:
                                    # Close the last open tool call and finish
                                    # the stream with finish_reason=tool_calls.
                                    if tool_id:
                                        close_chunk = {
                                            "choices": [
                                                {
                                                    "delta": {
                                                        "tool_calls": [
                                                            {"index": content_index, "function": {"arguments": ""}}
                                                        ]
                                                    },
                                                    "finish_reason": "tool_calls",
                                                    "index": 0,
                                                }
                                            ],
                                            "usage": tool_call_usage,
                                            "created": int(time.time()),
                                            "id": current_id,
                                            "model": current_model,
                                            "object": "chat.completion.chunk",
                                        }
                                        yield f"data: {json.dumps(close_chunk)}\n\n"
                                    yield "data: [DONE]\n\n"

                except json.JSONDecodeError as e:
                    # Malformed SSE payloads are skipped, not fatal.
                    logger.debug(f"JSON解析错误: {e}")
                except Exception as e:
                    logger.error(f"处理chunk错误: {e}")

        # Non-streaming mode: emit the accumulated completion once.
        if not is_stream:
            yield json.dumps(result)
|
app/models/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from app.models import schemas
|
| 5 |
+
|
| 6 |
+
__all__ = ["schemas"]
|
app/models/schemas.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Optional, Any, Union, Literal
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ContentPart(BaseModel):
    """Content part model for OpenAI's multi-part content format.

    One element of a message's content list, e.g. {"type": "text", "text": "..."}.
    """

    # Part discriminator, e.g. "text".
    type: str
    # Text payload, if any.
    text: Optional[str] = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Message(BaseModel):
    """Chat message model (OpenAI-compatible)."""

    # Sender role, e.g. "system", "user" or "assistant".
    role: str
    # Plain string or a list of typed content parts.
    content: Optional[Union[str, List[ContentPart]]] = None
    # Reasoning/thinking text attached to assistant messages, if any.
    reasoning_content: Optional[str] = None
    # OpenAI-style tool-call descriptors, if any.
    tool_calls: Optional[List[Dict[str, Any]]] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class OpenAIRequest(BaseModel):
    """OpenAI-compatible chat-completion request model."""

    # Requested model name (mapped to an upstream id by the provider).
    model: str
    messages: List[Message]
    # Whether the client wants an SSE stream.
    stream: Optional[bool] = False
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # OpenAI tool definitions and tool-choice directive, if any.
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Any] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ModelItem(BaseModel):
    """Model information item (upstream "model_item" payload)."""

    # Upstream model identifier.
    id: str
    # Human-readable model name.
    name: str
    # Owning organisation, e.g. "z.ai".
    owned_by: str
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class UpstreamRequest(BaseModel):
    """Upstream (Z.AI) service request model."""

    stream: bool
    # Upstream model id (already mapped from the public name).
    model: str
    messages: List[Message]
    params: Dict[str, Any] = {}
    # Feature flags such as web_search / enable_thinking.
    features: Dict[str, Any] = {}
    background_tasks: Optional[Dict[str, bool]] = None
    chat_id: Optional[str] = None
    id: Optional[str] = None
    # MCP server names to enable, e.g. ["deep-web-search"].
    mcp_servers: Optional[List[str]] = None
    model_item: Optional[Dict[str, Any]] = {}  # Model item dictionary
    tools: Optional[List[Dict[str, Any]]] = None  # Add tools field for OpenAI compatibility
    # Template variables substituted by the upstream, e.g. {{USER_NAME}}.
    variables: Optional[Dict[str, str]] = None
    # Allow field names starting with "model_" (pydantic v2 reserves that prefix).
    model_config = {"protected_namespaces": ()}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Delta(BaseModel):
    """Stream delta model (one incremental piece of a streamed choice)."""

    role: Optional[str] = None
    # Was `"" or None`, a constant expression that always evaluates to
    # None — spelled plainly here; behavior is unchanged.
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Choice(BaseModel):
    """Response choice model.

    Non-streaming responses populate `message`; streaming chunks
    populate `delta` instead.
    """

    index: int
    message: Optional[Message] = None
    delta: Optional[Delta] = None
    # e.g. "stop" or "tool_calls"; None while the stream is still open.
    finish_reason: Optional[str] = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Usage(BaseModel):
    """Token usage statistics."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class OpenAIResponse(BaseModel):
    """OpenAI-compatible chat-completion response model."""

    id: str
    # "chat.completion" or "chat.completion.chunk".
    object: str
    # Unix timestamp of creation.
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[Usage] = None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class UpstreamError(BaseModel):
    """Upstream error model."""

    # Human-readable error description.
    detail: str
    # Upstream error code.
    code: int
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class UpstreamDataInner(BaseModel):
    """Inner upstream data model (nested error carrier)."""

    error: Optional[UpstreamError] = None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class UpstreamDataData(BaseModel):
    """Upstream data content model (one SSE chunk's "data" payload)."""

    # Incremental content fragment.
    delta_content: str = ""
    # Full replacement/edit content (used at phase transitions).
    edit_content: str = ""
    # Upstream phase, e.g. "thinking", "answer", "tool_call", "other".
    phase: str = ""
    done: bool = False
    usage: Optional[Usage] = None
    error: Optional[UpstreamError] = None
    inner: Optional[UpstreamDataInner] = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class UpstreamData(BaseModel):
    """Upstream data model (one parsed SSE chunk)."""

    # Chunk type, e.g. "chat:completion".
    type: str
    data: UpstreamDataData
    error: Optional[UpstreamError] = None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class Model(BaseModel):
    """Model information for the /models listing."""

    id: str
    object: str = "model"
    # Unix timestamp of creation.
    created: int
    owned_by: str
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ModelsResponse(BaseModel):
    """Models list response model (OpenAI /models format)."""

    object: str = "list"
    data: List[Model]
|
app/providers/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
多提供商架构包
|
| 6 |
+
提供统一的提供商接口和路由机制
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from app.providers.base import BaseProvider, ProviderConfig, provider_registry
|
| 10 |
+
from app.providers.zai_provider import ZAIProvider
|
| 11 |
+
from app.providers.k2think_provider import K2ThinkProvider
|
| 12 |
+
from app.providers.longcat_provider import LongCatProvider
|
| 13 |
+
from app.providers.provider_factory import ProviderFactory, ProviderRouter, get_provider_router, initialize_providers
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"BaseProvider",
|
| 17 |
+
"ProviderConfig",
|
| 18 |
+
"provider_registry",
|
| 19 |
+
"ZAIProvider",
|
| 20 |
+
"K2ThinkProvider",
|
| 21 |
+
"LongCatProvider",
|
| 22 |
+
"ProviderFactory",
|
| 23 |
+
"ProviderRouter",
|
| 24 |
+
"get_provider_router",
|
| 25 |
+
"initialize_providers"
|
| 26 |
+
]
|
app/providers/base.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
基础提供商抽象层
|
| 6 |
+
定义统一的提供商接口规范
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import uuid
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from typing import Dict, List, Any, Optional, AsyncGenerator, Union
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
from app.models.schemas import OpenAIRequest, Message
|
| 17 |
+
from app.utils.logger import get_logger
|
| 18 |
+
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class ProviderConfig:
    """Static configuration for a single upstream provider."""
    name: str  # unique provider identifier; also used in logs and system fingerprints
    api_endpoint: str  # full URL of the provider's chat-completion endpoint
    timeout: int = 30  # request timeout in seconds
    headers: Optional[Dict[str, str]] = None  # default HTTP headers sent with every request
    extra_config: Optional[Dict[str, Any]] = None  # free-form provider-specific settings
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class ProviderResponse:
    """Normalized result of one provider call."""
    success: bool  # True when the upstream call completed without error
    content: str = ""  # assistant text produced by the provider (empty on failure)
    error: Optional[str] = None  # human-readable error description when success is False
    usage: Optional[Dict[str, int]] = None  # token accounting (prompt/completion/total), if reported
    extra_data: Optional[Dict[str, Any]] = None  # provider-specific extras passed through to callers
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BaseProvider(ABC):
    """Abstract base class for upstream chat providers.

    Subclasses implement the three abstract hooks (chat_completion,
    transform_request, transform_response); this base supplies shared
    helpers for building OpenAI-compatible responses/chunks, SSE
    formatting, and uniform logging/error handling.
    """

    def __init__(self, config: ProviderConfig):
        """Store the provider configuration and set up a logger."""
        self.config = config
        self.name = config.name
        self.logger = get_logger()

    @abstractmethod
    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """
        Chat-completion entry point.

        Args:
            request: OpenAI-format request.
            **kwargs: Provider-specific extra parameters.

        Returns:
            Non-streaming: Dict[str, Any] — OpenAI-format response.
            Streaming: AsyncGenerator[str, None] — SSE-formatted chunks.
        """
        pass

    @abstractmethod
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """
        Convert an OpenAI request into the provider's wire format.

        Args:
            request: OpenAI-format request.

        Returns:
            Dict[str, Any]: Provider-specific request payload.
        """
        pass

    @abstractmethod
    async def transform_response(
        self,
        response: Any,
        request: OpenAIRequest
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """
        Convert the provider's raw response into OpenAI format.

        Args:
            response: Raw provider response.
            request: Original request (used to shape the response).

        Returns:
            Union[Dict[str, Any], AsyncGenerator[str, None]]: OpenAI-format response.
        """
        pass

    def get_supported_models(self) -> List[str]:
        """Return the model names this provider serves (default: none)."""
        return []

    def create_chat_id(self) -> str:
        """Generate a unique OpenAI-style chat-completion id."""
        return f"chatcmpl-{uuid.uuid4().hex}"

    def create_openai_chunk(
        self,
        chat_id: str,
        model: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build one OpenAI-format streaming chunk (chat.completion.chunk)."""
        return {
            "id": chat_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
                "logprobs": None,
            }],
            # fingerprint encodes the provider name so responses are traceable
            "system_fingerprint": f"fp_{self.name}_001",
        }

    def create_openai_response(
        self,
        chat_id: str,
        model: str,
        content: str,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a complete OpenAI-format non-streaming response."""
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content
                },
                "finish_reason": "stop",
                "logprobs": None,
            }],
            # zeroed usage when the upstream did not report token counts
            "usage": usage or {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            },
            "system_fingerprint": f"fp_{self.name}_001",
        }

    def create_openai_response_with_reasoning(
        self,
        chat_id: str,
        model: str,
        content: str,
        reasoning_content: Optional[str] = None,
        usage: Optional[Dict[str, int]] = None
    ) -> Dict[str, Any]:
        """Build a non-streaming response that may carry reasoning content."""
        message = {
            "role": "assistant",
            "content": content
        }

        # Only attach reasoning when it exists and is non-blank
        if reasoning_content and reasoning_content.strip():
            message["reasoning_content"] = reasoning_content

        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [{
                "index": 0,
                "message": message,
                "finish_reason": "stop",
                "logprobs": None,
            }],
            "usage": usage or {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            },
            "system_fingerprint": f"fp_{self.name}_001",
        }

    async def format_sse_chunk(self, chunk: Dict[str, Any]) -> str:
        """Serialize a chunk as one SSE `data:` event (non-ASCII kept as-is)."""
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    async def format_sse_done(self) -> str:
        """Return the SSE terminator event."""
        return "data: [DONE]\n\n"

    def log_request(self, request: OpenAIRequest):
        """Log an incoming request (model at info, details at debug)."""
        self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
        self.logger.debug(f" 消息数量: {len(request.messages)}")
        self.logger.debug(f" 流式模式: {request.stream}")

    def log_response(self, success: bool, error: Optional[str] = None):
        """Log the outcome of a provider call."""
        if success:
            self.logger.info(f"✅ {self.name} 响应成功")
        else:
            self.logger.error(f"❌ {self.name} 响应失败: {error}")

    def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
        """Log *error* and wrap it in an OpenAI-style error payload."""
        error_msg = f"{self.name} {context} 错误: {str(error)}"
        self.logger.error(error_msg)

        return {
            "error": {
                "message": error_msg,
                "type": "provider_error",
                "code": "internal_error"
            }
        }
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class ProviderRegistry:
    """Registry resolving model names to the provider instances that serve them.

    Maintains two maps: provider name -> provider instance, and
    model name -> provider name.
    """

    def __init__(self):
        # provider name -> provider instance
        self._providers: Dict[str, "BaseProvider"] = {}
        # model name -> owning provider name
        self._model_mapping: Dict[str, str] = {}

    def register(self, provider: "BaseProvider", models: List[str]):
        """Register *provider* as the handler for every model in *models*."""
        self._providers[provider.name] = provider
        self._model_mapping.update({model: provider.name for model in models})
        logger.info(f"📝 注册提供商: {provider.name}, 模型: {models}")

    def get_provider(self, model: str) -> Optional["BaseProvider"]:
        """Resolve *model* to its provider, or None if the model is unknown."""
        owner = self._model_mapping.get(model)
        return self._providers.get(owner) if owner else None

    def get_provider_by_name(self, name: str) -> Optional["BaseProvider"]:
        """Look a provider up directly by its registered name."""
        return self._providers.get(name)

    def list_models(self) -> List[str]:
        """All model names that currently map to a provider."""
        return list(self._model_mapping)

    def list_providers(self) -> List[str]:
        """Names of every registered provider."""
        return list(self._providers)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Global provider registry shared by the whole application
provider_registry = ProviderRegistry()
|
app/providers/k2think_provider.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
K2Think 提供商适配器
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
import time
|
| 11 |
+
import uuid
|
| 12 |
+
import httpx
|
| 13 |
+
from typing import Dict, List, Any, Optional, AsyncGenerator, Union
|
| 14 |
+
|
| 15 |
+
from app.providers.base import BaseProvider, ProviderConfig
|
| 16 |
+
from app.models.schemas import OpenAIRequest, Message
|
| 17 |
+
from app.utils.logger import get_logger
|
| 18 |
+
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class K2ThinkProvider(BaseProvider):
    """Adapter for the K2Think guest chat API.

    Flow: handshake against the guest page, create a guest conversation,
    then call the completion endpoint and translate K2Think's SSE markup
    (<details type="reasoning">...</details> / <answer>...</answer>) into
    OpenAI-style chunks and responses.
    """

    def __init__(self):
        config = ProviderConfig(
            name="k2think",
            api_endpoint="https://www.k2think.ai/api/guest/chat/completions",
            timeout=30,
            headers={
                'Accept': 'text/event-stream',
                'Accept-Encoding': 'gzip, deflate, br, zstd',
                'Accept-Language': 'en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7',
                'Content-Type': 'application/json',
                'Origin': 'https://www.k2think.ai',
                'Pragma': 'no-cache',
                'Referer': 'https://www.k2think.ai/guest',
                'Sec-Ch-Ua': '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"',
                'Sec-Ch-Ua-Mobile': '?0',
                'Sec-Ch-Ua-Platform': '"macOS"',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'same-origin',
                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
            }
        )
        super().__init__(config)

        # K2Think-specific endpoints: handshake page and conversation factory
        self.handshake_url = "https://www.k2think.ai/guest"
        self.new_chat_url = "https://www.k2think.ai/api/v1/chats/guest/new"

        # Content-extraction regexes -- DOTALL so '.' also matches newlines
        self.reasoning_pattern = re.compile(r'<details type="reasoning"[^>]*>.*?<summary>.*?</summary>(.*?)</details>', re.DOTALL)
        self.answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)

    def get_supported_models(self) -> List[str]:
        """Return the model identifiers this provider serves."""
        return ["MBZUAI-IFM/K2-Think"]

    def parse_cookies(self, headers) -> str:
        """Collect all Set-Cookie values from *headers* into one Cookie string.

        Only the name=value part (before the first ';') of each cookie is kept.
        """
        cookies = []
        for key, value in headers.items():
            if key.lower() == 'set-cookie':
                cookies.append(value.split(';')[0])
        return '; '.join(cookies)

    def extract_reasoning_and_answer(self, content: str) -> tuple[str, str]:
        """Split K2Think markup into (reasoning, answer) text.

        Returns ("", "") when *content* is empty or extraction fails.
        """
        if not content:
            return "", ""

        try:
            reasoning_match = self.reasoning_pattern.search(content)
            reasoning = reasoning_match.group(1).strip() if reasoning_match else ""

            answer_match = self.answer_pattern.search(content)
            answer = answer_match.group(1).strip() if answer_match else ""

            return reasoning, answer
        except Exception as e:
            self.logger.error(f"提取K2内容错误: {e}")
            return "", ""

    def calculate_delta(self, previous: str, current: str) -> str:
        """Return the suffix of *current* that extends *previous* ("" if none).

        K2Think re-sends the full accumulated text each event, so deltas are
        computed by length comparison; a shrinking payload yields "".
        """
        if not previous:
            return current
        if not current or len(current) < len(previous):
            return ""
        return current[len(previous):]

    def parse_api_response(self, obj: Any) -> tuple[str, bool]:
        """Extract (content, done) from one parsed K2Think SSE object."""
        if not obj or not isinstance(obj, dict):
            return "", False

        if obj.get("done") is True:
            return "", True

        # OpenAI-like shape: choices[0].delta.content
        choices = obj.get("choices", [])
        if choices and len(choices) > 0:
            delta = choices[0].get("delta", {})
            return delta.get("content", ""), False

        # Fallback shape: top-level "content"
        content = obj.get("content")
        if isinstance(content, str):
            return content, False

        return "", False

    async def get_k2_auth_data(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Run the guest handshake and conversation setup.

        Returns:
            Dict with the final completion ``payload`` and ``headers``
            (including the combined session cookies).

        Raises:
            Exception: if the handshake, conversation creation, or response
                parsing fails.
        """
        # 1. Handshake -- request plain gzip/deflate to avoid Brotli issues
        headers_for_handshake = {**self.config.headers}
        headers_for_handshake['Accept-Encoding'] = 'gzip, deflate'

        async with httpx.AsyncClient() as client:
            handshake_response = await client.get(
                self.handshake_url,
                headers=headers_for_handshake,
                follow_redirects=True
            )
            if not handshake_response.is_success:
                # BUGFIX: the detailed exception was previously raised inside a
                # try whose own `except Exception` swallowed it, so the response
                # body never reached the caller. Read the text defensively, then
                # raise once with the detail included.
                try:
                    error_text = handshake_response.text[:200]
                except Exception:
                    error_text = ""
                raise Exception(f"K2 握手失败: {handshake_response.status_code} {error_text}")

        initial_cookies = self.parse_cookies(handshake_response.headers)

        # 2. Prepare messages (system prompts folded into the first user turn)
        prepared_messages = self.prepare_k2_messages(request.messages)
        first_user_message = next((m for m in prepared_messages if m["role"] == "user"), None)
        if not first_user_message:
            raise Exception("没有找到用户消息来初始化对话")

        # 3. Create a new guest conversation seeded with the first user message
        message_id = str(uuid.uuid4())
        now = int(time.time() * 1000)
        model_id = request.model or "MBZUAI-IFM/K2-Think"

        new_chat_payload = {
            "chat": {
                "id": "",
                "title": "Guest Chat",
                "models": [model_id],
                "params": {},
                "history": {
                    "messages": {
                        message_id: {
                            "id": message_id,
                            "parentId": None,
                            "childrenIds": [],
                            "role": "user",
                            "content": first_user_message["content"],
                            "timestamp": now // 1000,
                            "models": [model_id]
                        }
                    },
                    "currentId": message_id
                },
                "messages": [{
                    "id": message_id,
                    "parentId": None,
                    "childrenIds": [],
                    "role": "user",
                    "content": first_user_message["content"],
                    "timestamp": now // 1000,
                    "models": [model_id]
                }],
                "tags": [],
                "timestamp": now
            }
        }

        headers_with_cookies = {**self.config.headers, 'Cookie': initial_cookies}
        headers_with_cookies['Accept-Encoding'] = 'gzip, deflate'  # avoid br/zstd

        async with httpx.AsyncClient() as client:
            new_chat_response = await client.post(
                self.new_chat_url,
                headers=headers_with_cookies,
                json=new_chat_payload,
                follow_redirects=True
            )
            if not new_chat_response.is_success:
                try:
                    # httpx handles decompression/encoding for us
                    error_text = new_chat_response.text
                except Exception:
                    error_text = f"Status: {new_chat_response.status_code}"
                raise Exception(f"K2 新对话创建失败: {new_chat_response.status_code} {error_text[:200]}")

        try:
            new_chat_data = new_chat_response.json()
        except Exception as e:
            # Surface whatever the server actually sent to ease debugging;
            # fall back to a lossy manual decode if even .text fails.
            try:
                content_str = new_chat_response.text
            except Exception:
                content_str = new_chat_response.content.decode('utf-8', errors='replace')
            self.logger.debug(f"K2 响应原始内容: {content_str[:500]}")
            raise Exception(f"K2 响应JSON解析失败: {e}, 原始内容: {content_str[:200]}")

        conversation_id = new_chat_data.get("id")
        if not conversation_id:
            raise Exception("无法从K2 /new端点获取conversation_id")

        chat_specific_cookies = self.parse_cookies(new_chat_response.headers)

        # 4. Combine cookies from both responses
        base_cookies = [c for c in (initial_cookies, chat_specific_cookies) if c]
        final_cookie = '; '.join(base_cookies) + '; guest_conversation_count=1'

        # 5. Build the final completion payload
        final_payload = {
            "stream": True,
            "model": model_id,
            "messages": prepared_messages,
            "conversation_id": conversation_id,
            "params": {}
        }

        # Optional sampling parameters
        if request.temperature is not None:
            final_payload["params"]["temperature"] = request.temperature
        if request.max_tokens is not None:
            final_payload["params"]["max_tokens"] = request.max_tokens

        final_headers = {**self.config.headers, 'Cookie': final_cookie}

        return {
            "payload": final_payload,
            "headers": final_headers
        }

    def prepare_k2_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Flatten OpenAI messages into K2Think's format.

        System messages are concatenated and merged into the first user
        message (or prepended as a user message if none exists); multimodal
        content lists are reduced to their text parts.
        """
        result = []
        system_content = ""

        for msg in messages:
            if msg.role == "system":
                system_content = system_content + "\n\n" + msg.content if system_content else msg.content
            else:
                content = msg.content
                if isinstance(content, list):
                    # Multimodal content: keep only the text parts
                    text_parts = [part.text for part in content if hasattr(part, 'text') and part.text]
                    content = "\n".join(text_parts)

                result.append({
                    "role": msg.role,
                    "content": content
                })

        # Fold accumulated system prompt into the first user message
        if system_content:
            first_user_idx = next((i for i, m in enumerate(result) if m["role"] == "user"), -1)
            if first_user_idx >= 0:
                result[first_user_idx]["content"] = f"{system_content}\n\n{result[first_user_idx]['content']}"
            else:
                result.insert(0, {"role": "user", "content": system_content})

        return result

    async def _handle_stream_request(
        self,
        transformed: Dict[str, Any],
        request: OpenAIRequest
    ) -> AsyncGenerator[str, None]:
        """Stream the completion, yielding OpenAI SSE chunks.

        Runs entirely inside the client.stream context so the connection
        stays open while chunks are consumed.
        """
        chat_id = self.create_chat_id()
        model = transformed["model"]

        # Keep compression simple (no br/zstd) for reliable decoding
        headers_for_request = {**transformed["headers"]}
        headers_for_request['Accept-Encoding'] = 'gzip, deflate'

        self.logger.info(f"🌊 开始K2Think流式请求")

        # Use the configured timeout instead of a duplicated literal
        async with httpx.AsyncClient(timeout=float(self.config.timeout)) as client:
            async with client.stream(
                "POST",
                transformed["url"],
                headers=headers_for_request,
                json=transformed["payload"]
            ) as response:
                if not response.is_success:
                    error_msg = f"K2Think API 错误: {response.status_code}"
                    self.log_response(False, error_msg)
                    # Streaming: surface the error as an SSE event
                    yield await self.format_sse_chunk({
                        "error": {
                            "message": error_msg,
                            "type": "provider_error",
                            "code": "api_error"
                        }
                    })
                    return

                # Initial role chunk
                yield await self.format_sse_chunk(
                    self.create_openai_chunk(chat_id, model, {"role": "assistant"})
                )

                # K2Think re-sends the full accumulated markup each event;
                # track previous snapshots to emit only deltas.
                accumulated_content = ""
                previous_reasoning = ""
                previous_answer = ""
                reasoning_phase = True
                chunk_count = 0

                try:
                    async for line in response.aiter_lines():
                        chunk_count += 1
                        self.logger.debug(f"📦 收到数据块 #{chunk_count}: {line[:100]}...")

                        if not line.startswith("data:"):
                            continue

                        data_str = line[5:].strip()
                        if self._is_end_marker(data_str):
                            self.logger.debug(f"🏁 检测到结束标记: {data_str}")
                            continue

                        content = self._parse_data_string(data_str)
                        if not content:
                            continue

                        accumulated_content = content
                        current_reasoning, current_answer = self.extract_reasoning_and_answer(accumulated_content)

                        # Reasoning phase: emit reasoning deltas
                        if reasoning_phase and current_reasoning:
                            delta = self.calculate_delta(previous_reasoning, current_reasoning)
                            if delta.strip():
                                self.logger.debug(f"🧠 推理增量: {delta[:50]}...")
                                yield await self.format_sse_chunk(
                                    self.create_openai_chunk(chat_id, model, {"reasoning_content": delta})
                                )
                            previous_reasoning = current_reasoning

                        # First answer content flips us out of the reasoning phase
                        if current_answer and reasoning_phase:
                            reasoning_phase = False
                            self.logger.debug("🔄 切换到答案阶段")
                            # Flush any trailing reasoning
                            final_reasoning_delta = self.calculate_delta(previous_reasoning, current_reasoning)
                            if final_reasoning_delta.strip():
                                yield await self.format_sse_chunk(
                                    self.create_openai_chunk(chat_id, model, {"reasoning_content": final_reasoning_delta})
                                )

                        # Answer phase: emit content deltas
                        if not reasoning_phase and current_answer:
                            delta = self.calculate_delta(previous_answer, current_answer)
                            if delta.strip():
                                self.logger.debug(f"💬 答案增量: {delta[:50]}...")
                                yield await self.format_sse_chunk(
                                    self.create_openai_chunk(chat_id, model, {"content": delta})
                                )
                            previous_answer = current_answer

                except Exception as e:
                    self.logger.error(f"流式响应处理错误: {e}")
                    yield await self.format_sse_chunk({
                        "error": {
                            "message": f"流式处理错误: {str(e)}",
                            "type": "stream_error",
                            "code": "processing_error"
                        }
                    })
                    return

                # Terminal stop chunk + [DONE]
                self.logger.info(f"✅ K2Think流式响应完成,共处理 {chunk_count} 个数据块")
                yield await self.format_sse_chunk(
                    self.create_openai_chunk(chat_id, model, {}, "stop")
                )
                yield await self.format_sse_done()

    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Convert an OpenAI request into a ready-to-send K2Think request."""
        self.logger.info(f"🔄 转换 OpenAI 请求到 K2Think 格式: {request.model}")

        auth_data = await self.get_k2_auth_data(request)

        return {
            "url": self.config.api_endpoint,
            "headers": auth_data["headers"],
            "payload": auth_data["payload"],
            "model": request.model
        }

    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Chat-completion entry point.

        ``**kwargs`` is accepted (and ignored) for consistency with the
        BaseProvider abstract signature.
        """
        self.log_request(request)

        try:
            transformed = await self.transform_request(request)

            if request.stream:
                # Streaming is handled inside the generator, which manages
                # its own connection lifetime.
                return self._handle_stream_request(transformed, request)

            # Non-streaming: plain POST with compression K2Think decodes reliably
            headers_for_request = {**transformed["headers"]}
            headers_for_request['Accept-Encoding'] = 'gzip, deflate'

            async with httpx.AsyncClient(timeout=float(self.config.timeout)) as client:
                response = await client.post(
                    transformed["url"],
                    headers=headers_for_request,
                    json=transformed["payload"]
                )

                if not response.is_success:
                    error_msg = f"K2Think API 错误: {response.status_code}"
                    self.log_response(False, error_msg)
                    return self.handle_error(Exception(error_msg))

                return await self.transform_response(response, request, transformed)

        except Exception as e:
            self.log_response(False, str(e))
            return self.handle_error(e, "请求处理")

    async def transform_response(
        self,
        response: httpx.Response,
        request: OpenAIRequest,
        transformed: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert a K2Think response to OpenAI format (non-streaming only).

        Streaming responses are produced directly by _handle_stream_request.
        """
        chat_id = self.create_chat_id()
        model = transformed["model"]
        return await self._handle_non_stream_response(response, chat_id, model)

    def _is_end_marker(self, data: str) -> bool:
        """True for empty payloads and K2Think's various end sentinels."""
        return not data or data in ["-1", "[DONE]", "DONE", "done"]

    def _parse_data_string(self, data_str: str) -> str:
        """Parse one SSE data payload; fall back to the raw string on bad JSON."""
        try:
            obj = json.loads(data_str)
            content, is_done = self.parse_api_response(obj)
            return "" if is_done else content
        except (ValueError, TypeError):
            # Narrowed from a bare `except:` -- a malformed payload is treated
            # as plain text, but control-flow exceptions propagate.
            return data_str

    async def _handle_non_stream_response(
        self,
        response: httpx.Response,
        chat_id: str,
        model: str
    ) -> Dict[str, Any]:
        """Aggregate a K2Think SSE body into one OpenAI response.

        K2Think sends the full accumulated content each event, so the last
        non-empty payload is the complete result.
        """
        final_content = ""

        try:
            # aiter_lines() lets httpx handle decompression and encoding
            async for line in response.aiter_lines():
                if not line.startswith("data:"):
                    continue

                data_str = line[5:].strip()
                if self._is_end_marker(data_str):
                    continue

                content = self._parse_data_string(data_str)
                if content:
                    final_content = content

        except Exception as e:
            self.logger.error(f"非流式响应处理错误: {e}")
            raise

        # Split accumulated markup into reasoning and answer parts
        reasoning, answer = self.extract_reasoning_and_answer(final_content)

        # Normalize escaped newlines; fall back to the raw content when no
        # <answer> block was found
        reasoning = reasoning.replace("\\n", "\n") if reasoning else ""
        answer = answer.replace("\\n", "\n") if answer else final_content

        return self.create_openai_response_with_reasoning(chat_id, model, answer, reasoning)
|
app/providers/longcat_provider.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
LongCat 提供商适配器
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
import httpx
|
| 11 |
+
import random
|
| 12 |
+
import asyncio
|
| 13 |
+
from typing import Dict, List, Any, Optional, AsyncGenerator, Union
|
| 14 |
+
|
| 15 |
+
from app.providers.base import BaseProvider, ProviderConfig
|
| 16 |
+
from app.models.schemas import OpenAIRequest, Message
|
| 17 |
+
from app.utils.logger import get_logger
|
| 18 |
+
from app.utils.user_agent import get_dynamic_headers
|
| 19 |
+
from app.core.config import settings
|
| 20 |
+
|
| 21 |
+
logger = get_logger()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class LongCatProvider(BaseProvider):
|
| 25 |
+
"""LongCat 提供商"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
|
| 28 |
+
# 使用动态生成的 headers,不包含 User-Agent(将在请求时动态生成)
|
| 29 |
+
config = ProviderConfig(
|
| 30 |
+
name="longcat",
|
| 31 |
+
api_endpoint="https://longcat.chat/api/v1/chat-completion",
|
| 32 |
+
timeout=30,
|
| 33 |
+
headers={
|
| 34 |
+
'accept': 'text/event-stream,application/json',
|
| 35 |
+
'content-type': 'application/json',
|
| 36 |
+
'origin': 'https://longcat.chat',
|
| 37 |
+
'referer': 'https://longcat.chat/t',
|
| 38 |
+
}
|
| 39 |
+
)
|
| 40 |
+
super().__init__(config)
|
| 41 |
+
self.base_url = "https://longcat.chat"
|
| 42 |
+
self.session_create_url = f"{self.base_url}/api/v1/session-create"
|
| 43 |
+
self.session_delete_url = f"{self.base_url}/api/v1/session-delete"
|
| 44 |
+
|
| 45 |
+
def get_supported_models(self) -> List[str]:
|
| 46 |
+
"""获取支持的模型列表"""
|
| 47 |
+
return ["LongCat-Flash", "LongCat", "LongCat-Search"]
|
| 48 |
+
|
| 49 |
+
def get_passport_token(self) -> Optional[str]:
|
| 50 |
+
"""获取 LongCat passport token"""
|
| 51 |
+
# 优先使用环境变量中的单个token
|
| 52 |
+
if settings.LONGCAT_PASSPORT_TOKEN:
|
| 53 |
+
return settings.LONGCAT_PASSPORT_TOKEN
|
| 54 |
+
|
| 55 |
+
# 从token文件中随机选择一个
|
| 56 |
+
token_list = settings.longcat_token_list
|
| 57 |
+
if token_list:
|
| 58 |
+
return random.choice(token_list)
|
| 59 |
+
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
def create_headers_with_auth(self, token: str, user_agent: str, referer: str = None) -> Dict[str, str]:
|
| 63 |
+
"""创建带认证的请求头"""
|
| 64 |
+
headers = {
|
| 65 |
+
"User-Agent": user_agent,
|
| 66 |
+
"Content-Type": "application/json",
|
| 67 |
+
"x-requested-with": "XMLHttpRequest",
|
| 68 |
+
"X-Client-Language": "zh",
|
| 69 |
+
"Cookie": f"passport_token_key={token}",
|
| 70 |
+
"Accept": "text/event-stream,application/json",
|
| 71 |
+
"Origin": "https://longcat.chat"
|
| 72 |
+
}
|
| 73 |
+
if referer:
|
| 74 |
+
headers["Referer"] = referer
|
| 75 |
+
else:
|
| 76 |
+
headers["Referer"] = f"{self.base_url}/"
|
| 77 |
+
return headers
|
| 78 |
+
|
| 79 |
+
async def create_session(self, token: str, user_agent: str) -> str:
|
| 80 |
+
"""创建会话并返回 conversation_id"""
|
| 81 |
+
headers = self.create_headers_with_auth(token, user_agent)
|
| 82 |
+
data = {"model": "", "agentId": ""}
|
| 83 |
+
|
| 84 |
+
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 85 |
+
response = await client.post(
|
| 86 |
+
self.session_create_url,
|
| 87 |
+
headers=headers,
|
| 88 |
+
json=data
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if response.status_code != 200:
|
| 92 |
+
raise Exception(f"会话创建失败: {response.status_code}")
|
| 93 |
+
|
| 94 |
+
response_data = response.json()
|
| 95 |
+
if response_data.get("code") != 0:
|
| 96 |
+
raise Exception(f"会话创建错误: {response_data.get('message')}")
|
| 97 |
+
|
| 98 |
+
return response_data["data"]["conversationId"]
|
| 99 |
+
|
| 100 |
+
async def delete_session(self, conversation_id: str, token: str, user_agent: str) -> None:
|
| 101 |
+
"""删除会话"""
|
| 102 |
+
try:
|
| 103 |
+
headers = self.create_headers_with_auth(
|
| 104 |
+
token,
|
| 105 |
+
user_agent,
|
| 106 |
+
f"{self.base_url}/c/{conversation_id}"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 110 |
+
url = f"{self.session_delete_url}?conversationId={conversation_id}"
|
| 111 |
+
response = await client.get(url, headers=headers)
|
| 112 |
+
|
| 113 |
+
if response.status_code == 200:
|
| 114 |
+
self.logger.debug(f"成功删除会话 {conversation_id}")
|
| 115 |
+
else:
|
| 116 |
+
self.logger.warning(f"删除会话失败: {response.status_code}")
|
| 117 |
+
except Exception as e:
|
| 118 |
+
self.logger.error(f"删除会话出错: {e}")
|
| 119 |
+
|
| 120 |
+
def schedule_session_deletion(self, conversation_id: str, token: str, user_agent: str):
|
| 121 |
+
"""异步删除会话(不等待)"""
|
| 122 |
+
asyncio.create_task(self.delete_session(conversation_id, token, user_agent))
|
| 123 |
+
|
| 124 |
+
def format_messages_for_longcat(self, messages: List[Message]) -> str:
|
| 125 |
+
"""格式化消息为 LongCat 格式"""
|
| 126 |
+
formatted_messages = []
|
| 127 |
+
for msg in messages:
|
| 128 |
+
content = msg.content
|
| 129 |
+
if isinstance(content, list):
|
| 130 |
+
# 处理多模态内容,提取文本
|
| 131 |
+
text_parts = []
|
| 132 |
+
for part in content:
|
| 133 |
+
if hasattr(part, 'text') and part.text:
|
| 134 |
+
text_parts.append(part.text)
|
| 135 |
+
content = "\n".join(text_parts)
|
| 136 |
+
formatted_messages.append(f"{msg.role}:{content}")
|
| 137 |
+
return ";".join(formatted_messages)
|
| 138 |
+
|
| 139 |
+
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Convert an OpenAI-style request into LongCat wire format.

        Side effect: creates an upstream conversation (network call).
        Returns a dict bundling url/headers/payload plus the session context
        (conversation_id, passport_token, user_agent) needed for cleanup.
        Raises when no passport token is configured.
        """
        # Resolve the auth token (env var first, then the token-file pool).
        passport_token = self.get_passport_token()
        if not passport_token:
            raise Exception("未配置 LongCat passport token,请设置 LONGCAT_PASSPORT_TOKEN 环境变量或 LONGCAT_TOKENS_FILE")

        # Pick a dynamic User-Agent, with a desktop-Chrome-style fallback.
        dynamic_headers = get_dynamic_headers()
        user_agent = dynamic_headers.get("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")

        # Create the upstream conversation for this exchange.
        conversation_id = await self.create_session(passport_token, user_agent)

        # Flatten the chat history into LongCat's single-string format.
        formatted_content = self.format_messages_for_longcat(request.messages)

        # LongCat request payload; search mode is inferred from the model name.
        payload = {
            "conversationId": conversation_id,
            "content": formatted_content,
            "reasonEnabled": 0,
            "searchEnabled": 1 if "search" in request.model.lower() else 0,
            "parentMessageId": 0
        }

        # Authenticated headers with the conversation URL as the referer.
        headers = self.create_headers_with_auth(
            passport_token,
            user_agent,
            f"{self.base_url}/c/{conversation_id}"
        )

        return {
            "url": self.config.api_endpoint,
            "headers": headers,
            "payload": payload,
            "model": request.model,
            "conversation_id": conversation_id,
            "passport_token": passport_token,
            "user_agent": user_agent
        }
|
| 181 |
+
|
| 182 |
+
async def chat_completion(
|
| 183 |
+
self,
|
| 184 |
+
request: OpenAIRequest,
|
| 185 |
+
**kwargs
|
| 186 |
+
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
| 187 |
+
"""聊天完成接口"""
|
| 188 |
+
self.log_request(request)
|
| 189 |
+
|
| 190 |
+
try:
|
| 191 |
+
# 转换请求
|
| 192 |
+
transformed = await self.transform_request(request)
|
| 193 |
+
|
| 194 |
+
# 发送请求
|
| 195 |
+
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 196 |
+
response = await client.post(
|
| 197 |
+
transformed["url"],
|
| 198 |
+
headers=transformed["headers"],
|
| 199 |
+
json=transformed["payload"]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
if not response.is_success:
|
| 203 |
+
error_msg = f"LongCat API 错误: {response.status_code}"
|
| 204 |
+
try:
|
| 205 |
+
error_detail = await response.atext()
|
| 206 |
+
self.logger.error(f"❌ API 错误详情: {error_detail}")
|
| 207 |
+
except:
|
| 208 |
+
pass
|
| 209 |
+
self.log_response(False, error_msg)
|
| 210 |
+
return self.handle_error(Exception(error_msg))
|
| 211 |
+
|
| 212 |
+
# 转换响应
|
| 213 |
+
return await self.transform_response(response, request, transformed)
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
self.logger.error(f"❌ LongCat 请求处理异常: {e}")
|
| 217 |
+
self.log_response(False, str(e))
|
| 218 |
+
return self.handle_error(e, "请求处理")
|
| 219 |
+
|
| 220 |
+
async def transform_response(
|
| 221 |
+
self,
|
| 222 |
+
response: httpx.Response,
|
| 223 |
+
request: OpenAIRequest,
|
| 224 |
+
transformed: Dict[str, Any]
|
| 225 |
+
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
| 226 |
+
"""转换LongCat响应为OpenAI格式"""
|
| 227 |
+
chat_id = self.create_chat_id()
|
| 228 |
+
model = transformed["model"]
|
| 229 |
+
conversation_id = transformed["conversation_id"]
|
| 230 |
+
passport_token = transformed["passport_token"]
|
| 231 |
+
user_agent = transformed["user_agent"]
|
| 232 |
+
|
| 233 |
+
if request.stream:
|
| 234 |
+
return self._handle_stream_response(
|
| 235 |
+
response, chat_id, model, conversation_id, passport_token, user_agent
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
return await self._handle_non_stream_response(
|
| 239 |
+
response, chat_id, model, conversation_id, passport_token, user_agent
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
async def _handle_stream_response(
|
| 243 |
+
self,
|
| 244 |
+
response: httpx.Response,
|
| 245 |
+
chat_id: str,
|
| 246 |
+
model: str,
|
| 247 |
+
conversation_id: str,
|
| 248 |
+
passport_token: str,
|
| 249 |
+
user_agent: str
|
| 250 |
+
) -> AsyncGenerator[str, None]:
|
| 251 |
+
"""处理LongCat流式响应"""
|
| 252 |
+
session_deleted = False
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
# 发送初始角色块
|
| 256 |
+
yield await self.format_sse_chunk(
|
| 257 |
+
self.create_openai_chunk(chat_id, model, {"role": "assistant"})
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
stream_finished = False
|
| 261 |
+
|
| 262 |
+
async for line in response.aiter_lines():
|
| 263 |
+
line = line.strip()
|
| 264 |
+
|
| 265 |
+
# 首先检查是否是错误响应(JSON格式但不是SSE格式)
|
| 266 |
+
if not line.startswith('data:'):
|
| 267 |
+
# 尝试解析为JSON错误响应
|
| 268 |
+
try:
|
| 269 |
+
error_data = json.loads(line)
|
| 270 |
+
if isinstance(error_data, dict) and 'code' in error_data and 'message' in error_data:
|
| 271 |
+
# 这是一个错误响应
|
| 272 |
+
self.logger.error(f"❌ LongCat API 返回错误: {error_data}")
|
| 273 |
+
error_message = error_data.get('message', '未知错误')
|
| 274 |
+
error_code = error_data.get('code', 'unknown')
|
| 275 |
+
|
| 276 |
+
# 使用统一的错误处理函数
|
| 277 |
+
error_exception = Exception(f"LongCat API 错误 ({error_code}): {error_message}")
|
| 278 |
+
error_response = self.handle_error(error_exception, "API响应")
|
| 279 |
+
|
| 280 |
+
# 发送错误响应块
|
| 281 |
+
yield await self.format_sse_chunk(error_response)
|
| 282 |
+
yield await self.format_sse_done()
|
| 283 |
+
|
| 284 |
+
# 清理会话
|
| 285 |
+
if not session_deleted:
|
| 286 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 287 |
+
session_deleted = True
|
| 288 |
+
return
|
| 289 |
+
except json.JSONDecodeError:
|
| 290 |
+
# 不是JSON,跳过这行
|
| 291 |
+
continue
|
| 292 |
+
|
| 293 |
+
# 如果不是错误响应,跳过
|
| 294 |
+
continue
|
| 295 |
+
|
| 296 |
+
data_str = line[5:].strip()
|
| 297 |
+
if data_str == '[DONE]':
|
| 298 |
+
# 如果还没有发送完成块,发送一个
|
| 299 |
+
if not stream_finished:
|
| 300 |
+
yield await self.format_sse_chunk(
|
| 301 |
+
self.create_openai_chunk(chat_id, model, {}, "stop")
|
| 302 |
+
)
|
| 303 |
+
yield await self.format_sse_done()
|
| 304 |
+
|
| 305 |
+
# 清理会话
|
| 306 |
+
if not session_deleted:
|
| 307 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 308 |
+
session_deleted = True
|
| 309 |
+
break
|
| 310 |
+
|
| 311 |
+
try:
|
| 312 |
+
longcat_data = json.loads(data_str)
|
| 313 |
+
|
| 314 |
+
# 获取 delta 内容
|
| 315 |
+
choices = longcat_data.get("choices", [])
|
| 316 |
+
if not choices:
|
| 317 |
+
continue
|
| 318 |
+
|
| 319 |
+
delta = choices[0].get("delta", {})
|
| 320 |
+
content = delta.get("content")
|
| 321 |
+
finish_reason = choices[0].get("finishReason")
|
| 322 |
+
|
| 323 |
+
# 只有当内容不为空时才发送内容块
|
| 324 |
+
if content is not None and content != "":
|
| 325 |
+
openai_chunk = self.create_openai_chunk(
|
| 326 |
+
chat_id,
|
| 327 |
+
model,
|
| 328 |
+
{"content": content}
|
| 329 |
+
)
|
| 330 |
+
yield await self.format_sse_chunk(openai_chunk)
|
| 331 |
+
|
| 332 |
+
# 检查是否为流的结束
|
| 333 |
+
# LongCat 使用 lastOne=true 来标识最后一个块
|
| 334 |
+
if longcat_data.get("lastOne") and not stream_finished:
|
| 335 |
+
yield await self.format_sse_chunk(
|
| 336 |
+
self.create_openai_chunk(chat_id, model, {}, "stop")
|
| 337 |
+
)
|
| 338 |
+
yield await self.format_sse_done()
|
| 339 |
+
stream_finished = True
|
| 340 |
+
|
| 341 |
+
# 清理会话
|
| 342 |
+
if not session_deleted:
|
| 343 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 344 |
+
session_deleted = True
|
| 345 |
+
break
|
| 346 |
+
|
| 347 |
+
# 备用检查:如果有 finishReason 但没有 lastOne,也可能是结束
|
| 348 |
+
elif finish_reason == "stop" and longcat_data.get("contentStatus") == "FINISHED" and not stream_finished:
|
| 349 |
+
yield await self.format_sse_chunk(
|
| 350 |
+
self.create_openai_chunk(chat_id, model, {}, "stop")
|
| 351 |
+
)
|
| 352 |
+
yield await self.format_sse_done()
|
| 353 |
+
stream_finished = True
|
| 354 |
+
|
| 355 |
+
# 清理会话
|
| 356 |
+
if not session_deleted:
|
| 357 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 358 |
+
session_deleted = True
|
| 359 |
+
break
|
| 360 |
+
|
| 361 |
+
except json.JSONDecodeError as e:
|
| 362 |
+
self.logger.error(f"❌ 解析LongCat流数据错误: {e}")
|
| 363 |
+
continue
|
| 364 |
+
except Exception as e:
|
| 365 |
+
self.logger.error(f"❌ 处理LongCat流数据错误: {e}")
|
| 366 |
+
continue
|
| 367 |
+
|
| 368 |
+
except Exception as e:
|
| 369 |
+
self.logger.error(f"❌ LongCat流处理错误: {e}")
|
| 370 |
+
# 发送错误结束块(只有在还没有结束的情况下)
|
| 371 |
+
if not stream_finished:
|
| 372 |
+
yield await self.format_sse_chunk(
|
| 373 |
+
self.create_openai_chunk(chat_id, model, {}, "stop")
|
| 374 |
+
)
|
| 375 |
+
yield await self.format_sse_done()
|
| 376 |
+
finally:
|
| 377 |
+
# 确保会话被清理
|
| 378 |
+
if not session_deleted:
|
| 379 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 380 |
+
|
| 381 |
+
async def _handle_non_stream_response(
|
| 382 |
+
self,
|
| 383 |
+
response: httpx.Response,
|
| 384 |
+
chat_id: str,
|
| 385 |
+
model: str,
|
| 386 |
+
conversation_id: str,
|
| 387 |
+
passport_token: str,
|
| 388 |
+
user_agent: str
|
| 389 |
+
) -> Dict[str, Any]:
|
| 390 |
+
"""处理LongCat非流式响应"""
|
| 391 |
+
full_content = ""
|
| 392 |
+
usage_info = {
|
| 393 |
+
"prompt_tokens": 0,
|
| 394 |
+
"completion_tokens": 0,
|
| 395 |
+
"total_tokens": 0
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
try:
|
| 399 |
+
async for line in response.aiter_lines():
|
| 400 |
+
line = line.strip()
|
| 401 |
+
if not line.startswith('data:'):
|
| 402 |
+
# 检查是否是错误响应
|
| 403 |
+
try:
|
| 404 |
+
error_data = json.loads(line)
|
| 405 |
+
if isinstance(error_data, dict) and 'code' in error_data and 'message' in error_data:
|
| 406 |
+
# 这是一个错误响应
|
| 407 |
+
self.logger.error(f"❌ LongCat API 返回错误: {error_data}")
|
| 408 |
+
error_message = error_data.get('message', '未知错误')
|
| 409 |
+
error_code = error_data.get('code', 'unknown')
|
| 410 |
+
|
| 411 |
+
# 使用统一的错误处理函数
|
| 412 |
+
error_exception = Exception(f"LongCat API 错误 ({error_code}): {error_message}")
|
| 413 |
+
|
| 414 |
+
# 清理会话
|
| 415 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 416 |
+
|
| 417 |
+
return self.handle_error(error_exception, "API响应")
|
| 418 |
+
except json.JSONDecodeError:
|
| 419 |
+
# 不是JSON,跳过这行
|
| 420 |
+
pass
|
| 421 |
+
continue
|
| 422 |
+
|
| 423 |
+
data_str = line[5:].strip()
|
| 424 |
+
if data_str == '[DONE]':
|
| 425 |
+
break
|
| 426 |
+
|
| 427 |
+
try:
|
| 428 |
+
chunk = json.loads(data_str)
|
| 429 |
+
|
| 430 |
+
# 提取内容 - 只有当内容不为空时才添加
|
| 431 |
+
choices = chunk.get("choices", [])
|
| 432 |
+
if choices:
|
| 433 |
+
delta = choices[0].get("delta", {})
|
| 434 |
+
content = delta.get("content")
|
| 435 |
+
if content is not None and content != "":
|
| 436 |
+
full_content += content
|
| 437 |
+
|
| 438 |
+
# 提取使用信息(通常在最后的块中)
|
| 439 |
+
if chunk.get("tokenInfo"):
|
| 440 |
+
token_info = chunk["tokenInfo"]
|
| 441 |
+
usage_info = {
|
| 442 |
+
"prompt_tokens": token_info.get("promptTokens", 0),
|
| 443 |
+
"completion_tokens": token_info.get("completionTokens", 0),
|
| 444 |
+
"total_tokens": token_info.get("totalTokens", 0)
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
# 如果是最后一个块,可以提前结束
|
| 448 |
+
if chunk.get("lastOne"):
|
| 449 |
+
break
|
| 450 |
+
|
| 451 |
+
except json.JSONDecodeError:
|
| 452 |
+
continue
|
| 453 |
+
|
| 454 |
+
except Exception as e:
|
| 455 |
+
self.logger.error(f"❌ 处理LongCat非流式响应错误: {e}")
|
| 456 |
+
full_content = "处理响应时发生错误"
|
| 457 |
+
finally:
|
| 458 |
+
# 清理会话
|
| 459 |
+
self.schedule_session_deletion(conversation_id, passport_token, user_agent)
|
| 460 |
+
|
| 461 |
+
return self.create_openai_response(
|
| 462 |
+
chat_id,
|
| 463 |
+
model,
|
| 464 |
+
full_content.strip(),
|
| 465 |
+
usage_info
|
| 466 |
+
)
|
app/providers/provider_factory.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
提供商工厂和路由机制
|
| 6 |
+
负责根据模型名称自动选择合适的提供商
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
from typing import Dict, List, Optional, Union, AsyncGenerator, Any
|
| 11 |
+
from app.providers.base import BaseProvider, provider_registry
|
| 12 |
+
from app.providers.zai_provider import ZAIProvider
|
| 13 |
+
from app.providers.k2think_provider import K2ThinkProvider
|
| 14 |
+
from app.providers.longcat_provider import LongCatProvider
|
| 15 |
+
from app.models.schemas import OpenAIRequest
|
| 16 |
+
from app.core.config import settings
|
| 17 |
+
from app.utils.logger import get_logger
|
| 18 |
+
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ProviderFactory:
    """Registers all provider adapters and resolves model names to them."""

    def __init__(self):
        self._initialized = False
        self._default_provider = "zai"

    def initialize(self):
        """Register every known provider exactly once (idempotent)."""
        if self._initialized:
            return

        try:
            # Instantiate each adapter and register it under its model list.
            for provider in (ZAIProvider(), K2ThinkProvider(), LongCatProvider()):
                provider_registry.register(
                    provider,
                    provider.get_supported_models()
                )
            self._initialized = True
        except Exception as e:
            logger.error(f"❌ 提供商工厂初始化失败: {e}")
            raise

    def get_provider_for_model(self, model: str) -> Optional[BaseProvider]:
        """Resolve *model* to a provider: mapping → registry → default → None."""
        if not self._initialized:
            self.initialize()

        # 1) Explicit model→provider mapping from settings.
        provider_name = settings.provider_model_mapping.get(model)
        if provider_name:
            mapped = provider_registry.get_provider_by_name(provider_name)
            if mapped:
                logger.debug(f"🎯 模型 {model} 映射到提供商 {provider_name}")
                return mapped

        # 2) Direct registry lookup by model name.
        direct = provider_registry.get_provider(model)
        if direct:
            logger.debug(f"🎯 模型 {model} 找到提供商 {direct.name}")
            return direct

        # 3) Fall back to the configured default provider.
        fallback = provider_registry.get_provider_by_name(self._default_provider)
        if fallback:
            logger.warning(f"⚠️ 模型 {model} 未找到专用提供商,使用默认提供商 {self._default_provider}")
            return fallback

        logger.error(f"❌ 无法为模型 {model} 找到任何提供商")
        return None

    def list_supported_models(self) -> List[str]:
        """All model ids across every registered provider."""
        if not self._initialized:
            self.initialize()
        return provider_registry.list_models()

    def list_providers(self) -> List[str]:
        """Names of every registered provider."""
        if not self._initialized:
            self.initialize()
        return provider_registry.list_providers()

    def get_models_for_provider(self, provider_name: str) -> List[str]:
        """Model ids supported by one provider (empty list if unknown)."""
        if not self._initialized:
            self.initialize()

        provider = provider_registry.get_provider_by_name(provider_name)
        return provider.get_supported_models() if provider else []
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ProviderRouter:
    """Dispatches OpenAI-style requests to the provider that owns the model."""

    def __init__(self):
        self.factory = ProviderFactory()

    async def route_request(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Route *request* to its provider; return the result or an error dict.

        Unknown models yield an OpenAI-style ``model_not_found`` error;
        provider exceptions are converted via the provider's own
        ``handle_error``.
        """
        logger.info(f"🚦 路由请求: 模型={request.model}, 流式={request.stream}")

        provider = self.factory.get_provider_for_model(request.model)
        if not provider:
            error_msg = f"不支持的模型: {request.model}"
            logger.error(f"❌ {error_msg}")
            return {
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_not_found"
                }
            }

        logger.info(f"✅ 使用提供商: {provider.name}")

        try:
            result = await provider.chat_completion(request, **kwargs)
            # BUGFIX: the log literal was mojibake ("请求处理��成");
            # restored to the intended "请求处理完成".
            logger.info(f"🎉 请求处理完成: {provider.name}")
            return result
        except Exception as e:
            error_msg = f"提供商 {provider.name} 处理请求失败: {str(e)}"
            logger.error(f"❌ {error_msg}")
            return provider.handle_error(e, "路由处理")

    def get_models_list(self) -> Dict[str, Any]:
        """OpenAI-style /v1/models listing, grouped by owning provider."""
        current_time = int(time.time())
        models = [
            {
                "id": model,
                "object": "model",
                "created": current_time,
                "owned_by": provider_name
            }
            for provider_name in self.factory.list_providers()
            for model in self.factory.get_models_for_provider(provider_name)
        ]

        return {
            "object": "list",
            "data": models
        }
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# 全局路由器实例
|
| 178 |
+
_router: Optional[ProviderRouter] = None
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def get_provider_router() -> ProviderRouter:
    """Return the process-wide ProviderRouter, creating it on first use."""
    global _router
    if _router is not None:
        return _router

    _router = ProviderRouter()
    # Make sure all providers are registered before the first routing call.
    _router.factory.initialize()
    return _router
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def initialize_providers():
    """Bootstrap the provider system and return the global router.

    Creating the router also initializes the provider factory, so calling
    this once at startup registers every provider.
    """
    logger.info("🚀 初始化提供商系统...")
    router = get_provider_router()
    logger.info("✅ 提供商系统初始化完成")
    return router
|
app/providers/zai_provider.py
ADDED
|
@@ -0,0 +1,764 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Z.AI 提供商适配器
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
import httpx
|
| 12 |
+
import asyncio
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from typing import Dict, List, Any, Optional, AsyncGenerator, Union
|
| 15 |
+
|
| 16 |
+
from app.providers.base import BaseProvider, ProviderConfig
|
| 17 |
+
from app.models.schemas import OpenAIRequest, Message
|
| 18 |
+
from app.core.config import settings
|
| 19 |
+
from app.utils.logger import get_logger
|
| 20 |
+
from app.utils.token_pool import get_token_pool
|
| 21 |
+
from app.core.zai_transformer import generate_uuid, get_zai_dynamic_headers
|
| 22 |
+
from app.utils.sse_tool_handler import SSEToolHandler
|
| 23 |
+
|
| 24 |
+
logger = get_logger()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ZAIProvider(BaseProvider):
|
| 28 |
+
"""Z.AI 提供商"""
|
| 29 |
+
|
| 30 |
+
def __init__(self):
    """Set up the Z.AI provider: base config, endpoints, and model routing."""
    super().__init__(ProviderConfig(
        name="zai",
        api_endpoint=settings.API_ENDPOINT,
        timeout=30,
        headers=get_zai_dynamic_headers(),
    ))

    # Z.AI specific endpoints.
    self.base_url = "https://chat.z.ai"
    self.auth_url = f"{self.base_url}/api/v1/auths/"

    # Public model name -> upstream model id (several names share one id).
    glm45 = "0727-360B-API"
    glm45_air = "0727-106B-API"
    glm46 = "GLM-4-6-API-V1"
    self.model_mapping = {
        settings.PRIMARY_MODEL: glm45,          # GLM-4.5
        settings.THINKING_MODEL: glm45,         # GLM-4.5-Thinking
        settings.SEARCH_MODEL: glm45,           # GLM-4.5-Search
        settings.AIR_MODEL: glm45_air,          # GLM-4.5-Air
        settings.GLM46_MODEL: glm46,            # GLM-4.6
        settings.GLM46_THINKING_MODEL: glm46,   # GLM-4.6-Thinking
        settings.GLM46_SEARCH_MODEL: glm46,     # GLM-4.6-Search
    }
|
| 53 |
+
|
| 54 |
+
def get_supported_models(self) -> List[str]:
    """Return every public model name this provider can serve."""
    setting_names = (
        "PRIMARY_MODEL",
        "THINKING_MODEL",
        "SEARCH_MODEL",
        "AIR_MODEL",
        "GLM46_MODEL",
        "GLM46_THINKING_MODEL",
        "GLM46_SEARCH_MODEL",
    )
    return [getattr(settings, name) for name in setting_names]
|
| 65 |
+
|
| 66 |
+
async def get_token(self) -> str:
    """Resolve an auth token.

    Order: guest token (anonymous mode only), then the rotating token
    pool, then the statically configured ``AUTH_TOKEN``. Returns an
    empty string when no token could be obtained.
    """
    if settings.ANONYMOUS_MODE:
        # Anonymous mode accepts only a guest token from the auth endpoint.
        try:
            request_headers = get_zai_dynamic_headers()
            async with httpx.AsyncClient() as http_client:
                auth_response = await http_client.get(
                    self.auth_url, headers=request_headers, timeout=10.0
                )
                if auth_response.status_code == 200:
                    guest_token = auth_response.json().get("token", "")
                    if guest_token:
                        self.logger.debug(f"获取访客令牌成功: {guest_token[:20]}...")
                        return guest_token
        except Exception as e:
            self.logger.warning(f"异步获取访客令牌失败: {e}")

        # No fallback in anonymous mode: fail with an empty token.
        self.logger.error("❌ 匿名模式下获取访客令牌失败")
        return ""

    # Authenticated mode: prefer the rotating token pool.
    pool = get_token_pool()
    if pool:
        pooled = pool.get_next_token()
        if pooled:
            self.logger.debug(f"从token池获取令牌: {pooled[:20]}...")
            return pooled

    # Fall back to the configured token, ignoring the example placeholder.
    if settings.AUTH_TOKEN and settings.AUTH_TOKEN != "sk-your-api-key":
        self.logger.debug("使用配置的AUTH_TOKEN")
        return settings.AUTH_TOKEN

    self.logger.error("❌ 无法获取有效的认证令牌")
    return ""
|
| 102 |
+
|
| 103 |
+
def mark_token_failure(self, token: str, error: Optional[Exception] = None) -> None:
    """Report a failed token use to the shared token pool.

    No-op when no token pool is configured.

    Args:
        token: The bearer token whose request failed.
        error: Optional exception describing the failure (the previous
            annotation ``Exception = None`` was not Optional-correct).
    """
    token_pool = get_token_pool()
    if token_pool:
        token_pool.mark_token_failure(token, error)
|
| 108 |
+
|
| 109 |
+
async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
    """Convert an OpenAI-style request into the Z.AI upstream payload.

    Returns:
        Dict with ``url``, ``headers``, ``body`` for the upstream call,
        plus ``token``, ``chat_id`` and ``model`` for downstream handlers.
    """
    self.logger.info(f"🔄 转换 OpenAI 请求到 Z.AI 格式: {request.model}")

    # Acquire the auth token (guest / pool / configured); may be "".
    token = await self.get_token()

    # Normalize messages: strings pass through; list content is flattened
    # to text parts (parts lacking type/text attributes are dropped).
    messages = []
    for msg in request.messages:
        if isinstance(msg.content, str):
            messages.append({
                "role": msg.role,
                "content": msg.content
            })
        elif isinstance(msg.content, list):
            content_parts = []
            for part in msg.content:
                if hasattr(part, 'type') and hasattr(part, 'text'):
                    content_parts.append({
                        "type": part.type,
                        "text": part.text
                    })
            messages.append({
                "role": msg.role,
                "content": content_parts
            })

    # Feature flags derived from the requested model name suffix.
    requested_model = request.model
    is_thinking = "-thinking" in requested_model.casefold()
    is_search = "-search" in requested_model.casefold()

    # Unknown model names fall back to the GLM-4.5 upstream id.
    upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")

    # MCP servers: only GLM-4.5 search models get deep-web-search.
    mcp_servers = []
    if is_search and "-4.5" in requested_model:
        mcp_servers.append("deep-web-search")
        self.logger.info("🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")

    chat_id = generate_uuid()

    # FIX: take a single timestamp so all {{CURRENT_*}} variables agree
    # (previously each field called datetime.now() separately and could
    # straddle a second/midnight boundary).
    now = datetime.now()

    body = {
        "stream": True,  # upstream is always consumed as SSE
        "model": upstream_model_id,
        "messages": messages,
        "params": {},
        "features": {
            "image_generation": False,
            "web_search": is_search,
            "auto_web_search": is_search,
            "preview_mode": False,
            "flags": [],
            # Hidden MCP feature entries expected by the upstream API.
            "features": [
                {
                    "type": "mcp",
                    "server": "vibe-coding",
                    "status": "hidden"
                },
                {
                    "type": "mcp",
                    "server": "ppt-maker",
                    "status": "hidden"
                },
                {
                    "type": "mcp",
                    "server": "image-search",
                    "status": "hidden"
                },
                {
                    "type": "mcp",
                    "server": "deep-research",
                    "status": "hidden"
                },
                {
                    "type": "tool_selector",
                    "server": "tool_selector",
                    "status": "hidden"
                },
                {
                    "type": "mcp",
                    "server": "advanced-search",
                    "status": "hidden"
                }
            ],
            "enable_thinking": is_thinking,
        },
        "background_tasks": {
            "title_generation": False,
            "tags_generation": False,
        },
        "mcp_servers": mcp_servers,
        "variables": {
            "{{USER_NAME}}": "Guest",
            "{{USER_LOCATION}}": "Unknown",
            "{{CURRENT_DATETIME}}": now.strftime("%Y-%m-%d %H:%M:%S"),
            "{{CURRENT_DATE}}": now.strftime("%Y-%m-%d"),
            "{{CURRENT_TIME}}": now.strftime("%H:%M:%S"),
            "{{CURRENT_WEEKDAY}}": now.strftime("%A"),
            "{{CURRENT_TIMEZONE}}": "Asia/Shanghai",
            "{{USER_LANGUAGE}}": "zh-CN",
        },
        "model_item": {
            "id": upstream_model_id,
            "name": requested_model,
            "owned_by": "z.ai"
        },
        "chat_id": chat_id,
        "id": generate_uuid(),
    }

    # Tools are forwarded only when enabled and not in thinking mode.
    if settings.TOOL_SUPPORT and not is_thinking and request.tools:
        body["tools"] = request.tools
        self.logger.info(f"启用工具支持: {len(request.tools)} 个工具")
    else:
        body["tools"] = None

    # Optional sampling parameters.
    if request.temperature is not None:
        body["params"]["temperature"] = request.temperature
    if request.max_tokens is not None:
        body["params"]["max_tokens"] = request.max_tokens

    # Request headers; Authorization only when a token was obtained.
    headers = get_zai_dynamic_headers(chat_id)
    if token:
        headers["Authorization"] = f"Bearer {token}"

    # Remember the token so error paths can report it to the pool.
    self._current_token = token

    return {
        "url": self.config.api_endpoint,
        "headers": headers,
        "body": body,
        "token": token,
        "chat_id": chat_id,
        "model": requested_model
    }
|
| 254 |
+
|
| 255 |
+
async def chat_completion(
    self,
    request: OpenAIRequest,
    **kwargs
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Serve a chat completion.

    Streaming requests get the retrying SSE generator; non-streaming
    requests are POSTed once and aggregated into an OpenAI response.
    """
    self.log_request(request)

    try:
        upstream = await self.transform_request(request)

        if request.stream:
            # Streaming path: return the generator unawaited.
            return self._create_stream_response_with_retry(request, upstream)

        # Non-streaming path: single POST, then convert the body.
        async with httpx.AsyncClient(timeout=30.0) as client:
            upstream_response = await client.post(
                upstream["url"],
                headers=upstream["headers"],
                json=upstream["body"],
            )

            if not upstream_response.is_success:
                failure = f"Z.AI API 错误: {upstream_response.status_code}"
                self.log_response(False, failure)
                return self.handle_error(Exception(failure))

            return await self.transform_response(upstream_response, request, upstream)

    except Exception as exc:
        self.log_response(False, str(exc))
        return self.handle_error(exc, "请求处理")
|
| 290 |
+
|
| 291 |
+
async def _create_stream_response_with_retry(
    self,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> AsyncGenerator[str, None]:
    """Yield SSE chunks from Z.AI with up to ``settings.MAX_RETRIES`` retries.

    Retries on upstream HTTP 400 and on any exception raised during the
    stream; each retry marks the previous token as failed (non-anonymous
    mode) and fetches a fresh one. On success, delegates chunk translation
    to ``_handle_stream_response``. On final failure, yields an OpenAI-style
    error object followed by ``[DONE]``.
    """
    retry_count = 0
    last_error = None
    current_token = transformed.get("token", "")

    while retry_count <= settings.MAX_RETRIES:
        try:
            # On a retry: wait, report the failed token, and re-authenticate.
            if retry_count > 0:
                delay = settings.RETRY_DELAY
                self.logger.warning(f"重试请求 ({retry_count}/{settings.MAX_RETRIES}) - 等待 {delay:.1f}s")
                await asyncio.sleep(delay)

                # Mark the previous token as failed (skipped in anonymous mode).
                if current_token and not settings.ANONYMOUS_MODE:
                    self.mark_token_failure(current_token, Exception(f"Retry {retry_count}: {last_error}"))

                # Acquire a fresh token for the retry.
                self.logger.info("🔑 重新获取令牌用于重试...")
                new_token = await self.get_token()
                if not new_token:
                    self.logger.error("❌ 重试时无法获取有效的认证令牌")
                    raise Exception("重试时无法获取有效的认证令牌")
                transformed["headers"]["Authorization"] = f"Bearer {new_token}"
                current_token = new_token

            async with httpx.AsyncClient(timeout=60.0) as client:
                # Open the upstream SSE stream.
                self.logger.info(f"🎯 发送请求到 Z.AI: {transformed['url']}")
                async with client.stream(
                    "POST",
                    transformed["url"],
                    json=transformed["body"],
                    headers=transformed["headers"],
                ) as response:
                    # 400 is considered transient here and triggers a retry.
                    if response.status_code == 400:
                        error_text = await response.aread()
                        error_msg = error_text.decode('utf-8', errors='ignore')
                        self.logger.warning(f"❌ 上游返回 400 错误 (尝试 {retry_count + 1}/{settings.MAX_RETRIES + 1})")

                        retry_count += 1
                        last_error = f"400 Bad Request: {error_msg}"

                        # Retry if budget remains, otherwise emit an error chunk.
                        if retry_count <= settings.MAX_RETRIES:
                            continue
                        else:
                            self.logger.error(f"❌ 达到最大重试次数 ({settings.MAX_RETRIES}),请求失败")
                            error_response = {
                                "error": {
                                    "message": f"Request failed after {settings.MAX_RETRIES} retries: {last_error}",
                                    "type": "upstream_error",
                                    "code": 400
                                }
                            }
                            yield f"data: {json.dumps(error_response)}\n\n"
                            yield "data: [DONE]\n\n"
                            return

                    # Any other non-200 status fails immediately (no retry).
                    elif response.status_code != 200:
                        self.logger.error(f"❌ 上游返回错误: {response.status_code}")
                        error_text = await response.aread()
                        error_msg = error_text.decode('utf-8', errors='ignore')
                        self.logger.error(f"❌ 错误详情: {error_msg}")

                        error_response = {
                            "error": {
                                "message": f"Upstream error: {response.status_code}",
                                "type": "upstream_error",
                                "code": response.status_code
                            }
                        }
                        yield f"data: {json.dumps(error_response)}\n\n"
                        yield "data: [DONE]\n\n"
                        return

                    # 200: stream established.
                    if retry_count > 0:
                        self.logger.info(f"✨ 第 {retry_count} 次重试成功")

                    # Report token success to the pool (non-anonymous mode).
                    if current_token and not settings.ANONYMOUS_MODE:
                        token_pool = get_token_pool()
                        if token_pool:
                            token_pool.mark_token_success(current_token)

                    # Delegate SSE translation and forward every chunk.
                    chat_id = transformed["chat_id"]
                    model = transformed["model"]
                    async for chunk in self._handle_stream_response(response, chat_id, model, request, transformed):
                        yield chunk
                    return

        except Exception as e:
            self.logger.error(f"❌ 流处理错误: {e}")
            import traceback
            self.logger.error(traceback.format_exc())

            # Report the token failure (non-anonymous mode).
            if current_token and not settings.ANONYMOUS_MODE:
                self.mark_token_failure(current_token, e)

            # Consume one retry attempt.
            retry_count += 1
            last_error = str(e)

            if retry_count > settings.MAX_RETRIES:
                # Retry budget exhausted: emit an error chunk and terminate.
                self.logger.error(f"❌ 达到最大重试次数 ({settings.MAX_RETRIES}),流处理失败")
                error_response = {
                    "error": {
                        "message": f"Stream processing failed after {settings.MAX_RETRIES} retries: {last_error}",
                        "type": "stream_error"
                    }
                }
                yield f"data: {json.dumps(error_response)}\n\n"
                yield "data: [DONE]\n\n"
                return
|
| 418 |
+
|
| 419 |
+
async def transform_response(
    self,
    response: httpx.Response,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Dispatch upstream response handling based on the client's stream flag."""
    chat_id = transformed["chat_id"]
    model = transformed["model"]

    if not request.stream:
        # Aggregate the SSE body into a single OpenAI response.
        return await self._handle_non_stream_response(response, chat_id, model)

    # Streaming: hand back the chunk generator (not awaited).
    return self._handle_stream_response(response, chat_id, model, request, transformed)
|
| 433 |
+
|
| 434 |
+
async def _handle_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> AsyncGenerator[str, None]:
    """Translate the Z.AI SSE stream into OpenAI-style SSE chunks.

    Fixes over the previous version:
    - the initial ``{"role": "assistant"}`` chunk is tracked with a
      dedicated ``role_sent`` flag, so it is emitted exactly once
      (previously it was re-emitted on every answer delta whenever no
      thinking phase had occurred);
    - phase-transition logging uses a local ``last_phase`` variable
      instead of ``self._last_phase``, which leaked state between
      concurrent streams sharing this provider instance.
    """
    self.logger.info(f"✅ Z.AI 响应成功,开始处理 SSE 流")

    # Tool handler only when the upstream body carries tools.
    has_tools = transformed["body"].get("tools") is not None
    tool_handler = None
    if has_tools:
        tool_handler = SSEToolHandler(model, stream=True)
        self.logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个工具")

    # Per-stream state.
    has_thinking = False       # saw at least one "thinking" phase chunk
    role_sent = False          # initial assistant-role chunk emitted
    thinking_signature = None  # signature emitted when thinking closes
    last_phase = None          # local phase tracker, logging only

    buffer = ""
    line_count = 0
    self.logger.debug("📡 开始接收 SSE 流数据...")

    try:
        async for line in response.aiter_lines():
            line_count += 1
            if not line:
                continue

            # Accumulate into a buffer and process complete lines.
            buffer += line + "\n"

            while "\n" in buffer:
                current_line, buffer = buffer.split("\n", 1)
                if not current_line.strip():
                    continue

                if current_line.startswith("data:"):
                    chunk_str = current_line[5:].strip()
                    if not chunk_str or chunk_str == "[DONE]":
                        if chunk_str == "[DONE]":
                            yield "data: [DONE]\n\n"
                        continue

                    self.logger.debug(f"📦 解析数据块: {chunk_str[:1000]}..." if len(chunk_str) > 1000 else f"📦 解析数据块: {chunk_str}")

                    try:
                        chunk = json.loads(chunk_str)

                        if chunk.get("type") == "chat:completion":
                            data = chunk.get("data", {})
                            phase = data.get("phase")

                            # Log only when the phase changes (local state).
                            if phase and phase != last_phase:
                                self.logger.info(f"📈 SSE 阶段: {phase}")
                                last_phase = phase

                            if tool_handler:
                                # Tool mode: every phase goes to the handler.
                                sse_chunk = {
                                    "phase": phase,
                                    "edit_content": data.get("edit_content", ""),
                                    "delta_content": data.get("delta_content", ""),
                                    "edit_index": data.get("edit_index"),
                                    "usage": data.get("usage", {})
                                }
                                for output in tool_handler.process_sse_chunk(sse_chunk):
                                    yield output

                            # Non-tool mode: reasoning ("thinking") content.
                            elif phase == "thinking":
                                has_thinking = True
                                if not role_sent:
                                    role_sent = True
                                    role_chunk = self.create_openai_chunk(
                                        chat_id,
                                        model,
                                        {"role": "assistant"}
                                    )
                                    yield await self.format_sse_chunk(role_chunk)

                                delta_content = data.get("delta_content", "")
                                if delta_content:
                                    # Strip the <details><summary> wrapper Z.AI
                                    # puts around reasoning text.
                                    if delta_content.startswith("<details"):
                                        content = (
                                            delta_content.split("</summary>\n>")[-1].strip()
                                            if "</summary>\n>" in delta_content
                                            else delta_content
                                        )
                                    else:
                                        content = delta_content

                                    thinking_chunk = self.create_openai_chunk(
                                        chat_id,
                                        model,
                                        {
                                            "role": "assistant",
                                            "thinking": {"content": content}
                                        }
                                    )
                                    yield await self.format_sse_chunk(thinking_chunk)

                            # Non-tool mode: answer content.
                            elif phase == "answer":
                                edit_content = data.get("edit_content", "")
                                delta_content = data.get("delta_content", "")

                                # Thinking just closed: emit a signature, then
                                # the answer text that follows </details>.
                                if edit_content and "</details>\n" in edit_content:
                                    if has_thinking:
                                        thinking_signature = str(int(time.time() * 1000))
                                        sig_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {
                                                "role": "assistant",
                                                "thinking": {
                                                    "content": "",
                                                    "signature": thinking_signature,
                                                }
                                            }
                                        )
                                        yield await self.format_sse_chunk(sig_chunk)

                                    content_after = edit_content.split("</details>\n")[-1]
                                    if content_after:
                                        content_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {
                                                "role": "assistant",
                                                "content": content_after
                                            }
                                        )
                                        yield await self.format_sse_chunk(content_chunk)

                                elif delta_content:
                                    # FIX: emit the role chunk at most once
                                    # (was re-sent on every delta before).
                                    if not role_sent:
                                        role_sent = True
                                        role_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {"role": "assistant"}
                                        )
                                        yield await self.format_sse_chunk(role_chunk)

                                    content_chunk = self.create_openai_chunk(
                                        chat_id,
                                        model,
                                        {
                                            "role": "assistant",
                                            "content": delta_content
                                        }
                                    )
                                    output_data = await self.format_sse_chunk(content_chunk)
                                    self.logger.debug(f"➡️ 输出内容块到客户端: {output_data}")
                                    yield output_data

                            # Usage arrives on the terminating chunk.
                            if data.get("usage"):
                                self.logger.info(f"📦 完成响应 - 使用统计: {json.dumps(data['usage'])}")

                                # Tool mode emits its own finish signal.
                                if not tool_handler:
                                    finish_chunk = self.create_openai_chunk(
                                        chat_id,
                                        model,
                                        {"role": "assistant", "content": ""},
                                        "stop"
                                    )
                                    finish_chunk["usage"] = data["usage"]

                                    finish_output = await self.format_sse_chunk(finish_chunk)
                                    self.logger.debug(f"➡️ 发送完成信号: {finish_output[:1000]}...")
                                    yield finish_output
                                    self.logger.debug("➡️ 发送 [DONE]")
                                    yield "data: [DONE]\n\n"

                    except json.JSONDecodeError as e:
                        self.logger.debug(f"❌ JSON解析错误: {e}, 内容: {chunk_str[:1000]}")
                    except Exception as e:
                        self.logger.error(f"❌ 处理chunk错误: {e}")

        # The tool handler sends its own terminator; avoid duplicating it.
        if not tool_handler:
            self.logger.debug("📤 发送最终 [DONE] 信号")
            yield "data: [DONE]\n\n"

        self.logger.info(f"✅ SSE 流处理完成,共处理 {line_count} 行数据")

    except Exception as e:
        self.logger.error(f"❌ 流式响应处理错误: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        # Best-effort terminal chunk so the client stream ends cleanly.
        yield await self.format_sse_chunk(
            self.create_openai_chunk(chat_id, model, {}, "stop")
        )
        yield "data: [DONE]\n\n"
|
| 647 |
+
|
| 648 |
+
async def _handle_non_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str
) -> Dict[str, Any]:
    """Aggregate the upstream SSE body into one OpenAI-format response.

    The upstream always replies as SSE (transform_request pins stream=True),
    so this method folds the ``data:`` chunks from ``aiter_lines()`` into
    final answer text, reasoning text, and usage, then returns a single
    OpenAI-style response dict.
    """
    final_content = ""
    reasoning_content = ""
    usage_info: Dict[str, int] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }

    try:
        async for line in response.aiter_lines():
            if not line:
                continue

            line = line.strip()

            # Only "data:" SSE lines are processed; other lines are probed
            # as possible error JSON and otherwise ignored.
            if not line.startswith("data:"):
                # Try to interpret the line as an upstream error payload.
                try:
                    maybe_err = json.loads(line)
                    if isinstance(maybe_err, dict) and (
                        "error" in maybe_err or "code" in maybe_err or "message" in maybe_err
                    ):
                        # Unified error handling: extract the best message.
                        msg = (
                            (maybe_err.get("error") or {}).get("message")
                            if isinstance(maybe_err.get("error"), dict)
                            else maybe_err.get("message")
                        ) or "上游返回错误"
                        return self.handle_error(Exception(msg), "API响应")
                except Exception:
                    pass
                continue

            data_str = line[5:].strip()
            if not data_str or data_str in ("[DONE]", "DONE", "done"):
                continue

            # Parse the SSE data chunk; skip malformed JSON.
            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                continue

            if chunk.get("type") != "chat:completion":
                continue

            data = chunk.get("data", {})
            phase = data.get("phase")
            delta_content = data.get("delta_content", "")
            edit_content = data.get("edit_content", "")

            # Record usage (usually on the last chunk; overwrite each time
            # so the latest value wins).
            if data.get("usage"):
                try:
                    usage_info = data["usage"]
                except Exception:
                    pass

            # Thinking phase: accumulate reasoning, stripping the
            # <details><summary>... wrapper header.
            if phase == "thinking":
                if delta_content:
                    if delta_content.startswith("<details"):
                        cleaned = (
                            delta_content.split("</summary>\n>")[-1].strip()
                            if "</summary>\n>" in delta_content
                            else delta_content
                        )
                    else:
                        cleaned = delta_content
                    reasoning_content += cleaned

            # Answer phase: accumulate answer text.
            elif phase == "answer":
                # When edit_content carries both the thinking terminator and
                # the answer, keep only the part after </details>.
                if edit_content and "</details>\n" in edit_content:
                    content_after = edit_content.split("</details>\n")[-1]
                    if content_after:
                        final_content += content_after
                elif delta_content:
                    final_content += delta_content

    except Exception as e:
        self.logger.error(f"❌ 非流式响应处理错误: {e}")
        import traceback
        self.logger.error(traceback.format_exc())
        # Return a unified error response.
        return self.handle_error(e, "非流式聚合")

    # Trim and return.
    final_content = (final_content or "").strip()
    reasoning_content = (reasoning_content or "").strip()

    # Fallback: no answer text but some reasoning — return the reasoning.
    if not final_content and reasoning_content:
        final_content = reasoning_content

    # Standard response carrying the reasoning (omitted when empty).
    return self.create_openai_response_with_reasoning(
        chat_id,
        model,
        final_content,
        reasoning_content,
        usage_info,
    )
|
app/utils/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from app.utils import sse_tool_handler, reload_config, logger
|
| 5 |
+
|
| 6 |
+
__all__ = ["sse_tool_handler", "reload_config", "logger"]
|
app/utils/logger.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
# Global logger instance
|
| 9 |
+
app_logger = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def setup_logger(log_dir, log_retention_days=7, log_rotation="1 day", debug_mode=False):
    """
    Configure the global loguru logger.

    Parameters:
        log_dir (str): directory for log files (used only in debug mode)
        log_retention_days (int): how many days to keep rotated log files
        log_rotation (str): rotation interval passed to loguru
        debug_mode (bool): enable DEBUG level and an additional file sink
    """
    global app_logger

    try:
        logger.remove()

        level = "DEBUG" if debug_mode else "INFO"

        # Console sink: compact format normally, verbose format in debug mode.
        if debug_mode:
            console_format = (
                "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | "
                "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | <level>{message}</level>"
            )
        else:
            console_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>"

        logger.add(sys.stderr, level=level, format=console_format, colorize=True)

        # File sink only in debug mode: one dated file per day, zipped on rotation.
        if debug_mode:
            directory = Path(log_dir)
            directory.mkdir(parents=True, exist_ok=True)

            logger.add(
                str(directory / "{time:YYYY-MM-DD}.log"),
                level=level,
                format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}",
                rotation=log_rotation,
                retention=f"{log_retention_days} days",
                encoding="utf-8",
                compression="zip",
                enqueue=True,
                catch=True,
            )

        app_logger = logger
        return logger

    except Exception as e:
        # Fall back to a bare ERROR console sink, report, and propagate.
        logger.remove()
        logger.add(sys.stderr, level="ERROR")
        logger.error(f"日志系统配置失败: {e}")
        raise
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_logger():
    """Return the shared logger, lazily installing a default INFO sink on first use."""
    global app_logger
    if app_logger is not None:
        return app_logger

    # setup_logger() was never called: install a sensible default handler.
    logger.remove()
    logger.add(
        sys.stderr,
        level="INFO",
        format=(
            "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | "
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | <level>{message}</level>"
        ),
    )
    app_logger = logger
    return app_logger
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
    """Test the logger"""
    import tempfile

    # Manual smoke test: configure debug logging into a throwaway directory,
    # emit one message per level, and exercise exception logging.
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            setup_logger(temp_dir, debug_mode=True)

            logger.debug("这是一条调试日志")
            logger.info("这是一条信息日志")
            logger.warning("这是一条警告日志")
            logger.error("这是一条错误日志")
            logger.critical("这是一条严重日志")

            # Verify that tracebacks are captured via logger.exception().
            try:
                1 / 0
            except ZeroDivisionError:
                logger.exception("发生了除零异常")

            print("✅ 日志测试完成")

            # Detach sinks so the temp directory can be removed cleanly.
            logger.remove()

        except Exception as e:
            print(f"❌ 日志测试失败: {e}")
            logger.remove()
            raise
|
app/utils/reload_config.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
热重载配置模块
|
| 6 |
+
定义 Granian 服务器热重载时需要忽略的目录和文件模式
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Directories excluded from the hot-reload watcher.
RELOAD_IGNORE_DIRS = [
    "logs",                   # log output
    "storage",                # storage directory
    "__pycache__",            # Python bytecode cache
    ".git",                   # git metadata
    "node_modules",           # node packages
    "migrations",             # database migrations
    ".pytest_cache",          # pytest cache
    ".venv",                  # virtual environment
    "venv",                   # virtual environment
    "env",                    # environment directory
    ".mypy_cache",            # mypy cache
    ".ruff_cache",            # ruff cache
    "dist",                   # build distribution output
    "build",                  # build output
    ".coverage",              # coverage data file
    "htmlcov",                # coverage HTML report
    "tests",                  # test suite
    "z-ai2api-server.pid",    # server PID file
]

# File patterns (regular expressions) excluded from the hot-reload watcher.
RELOAD_IGNORE_PATTERNS = [
    # Log files
    r".*\.log$",
    r".*\.log\.\d+$",
    # Database files
    r".*\.sqlite3.*",
    r".*\.db$",
    r".*\.db-.*$",
    # Python artifacts
    r".*\.pyc$",
    r".*\.pyo$",
    r".*\.pyd$",
    # Temporary/editor swap files
    r".*\.tmp$",
    r".*\.temp$",
    r".*\.swp$",
    r".*\.swo$",
    r".*~$",
    # OS metadata files
    r".*\.DS_Store$",
    r".*Thumbs\.db$",
    r".*\.directory$",
    # IDE/editor directories
    r".*\.vscode.*",
    r".*\.idea.*",
    # Test and coverage artifacts
    r".*\.coverage$",
    r".*\.pytest_cache.*",
    # Build artifacts
    r".*\.egg-info.*",
    r".*\.wheel$",
    r".*\.whl$",
    # Version control files
    r".*\.git.*",
    r".*\.gitignore$",
    r".*\.gitkeep$",
    # Config backups
    r".*\.bak$",
    r".*\.backup$",
    r".*\.orig$",
    # Lock/PID files
    r".*\.lock$",
    r".*\.pid$",
]

# Paths the watcher actively monitors (application code only).
RELOAD_WATCH_PATHS = [
    "app",        # application package
    "main.py",    # entry point
]

# Aggregated hot-reload configuration consumed by the Granian server.
RELOAD_CONFIG = {
    "reload_ignore_dirs": RELOAD_IGNORE_DIRS,
    "reload_ignore_patterns": RELOAD_IGNORE_PATTERNS,
    "reload_paths": RELOAD_WATCH_PATHS,
    "reload_tick": 100,  # watch interval in milliseconds
}
|
app/utils/sse_tool_handler.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
SSE Tool Handler
|
| 6 |
+
|
| 7 |
+
处理 Z.AI SSE 流数据并转换为 OpenAI 兼容格式的工具调用处理器。
|
| 8 |
+
|
| 9 |
+
主要功能:
|
| 10 |
+
- 解析 glm_block 格式的工具调用
|
| 11 |
+
- 从 metadata.arguments 提取完整参数
|
| 12 |
+
- 支持多阶段处理:thinking → tool_call → other → answer
|
| 13 |
+
- 输出符合 OpenAI API 规范的流式响应
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import time
|
| 18 |
+
from typing import Dict, Any, Generator
|
| 19 |
+
from enum import Enum
|
| 20 |
+
|
| 21 |
+
from app.utils.logger import get_logger
|
| 22 |
+
|
| 23 |
+
logger = get_logger()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SSEPhase(Enum):
    """Processing phases of the upstream Z.AI SSE stream."""
    THINKING = "thinking"    # model reasoning text (streamed as deltas)
    TOOL_CALL = "tool_call"  # glm_block tool-call payloads
    OTHER = "other"          # usage stats / tool-call end markers
    ANSWER = "answer"        # final answer text (streamed as deltas)
    DONE = "done"            # stream finished
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SSEToolHandler:
    """Converts Z.AI SSE chunks into OpenAI-compatible streaming responses.

    The handler is stateful: it tracks the current stream phase, the
    in-flight tool call (id / name / accumulated argument text), and a
    small output buffer used to batch answer-text deltas.
    """

    def __init__(self, model: str, stream: bool = True):
        # Model name echoed back in every OpenAI-format chunk.
        self.model = model
        # Whether to emit SSE lines (True) or operate silently (False).
        self.stream = stream

        # Phase tracking across chunks.
        self.current_phase = None
        self.has_tool_call = False

        # In-flight tool-call state; argument text is accumulated as raw
        # string fragments and repaired into JSON on completion.
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""
        self.tool_call_usage = {}
        self.content_index = 0  # tool-call index within the OpenAI delta

        # Performance: buffer answer deltas instead of emitting one SSE
        # line per tiny fragment.
        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()
        self.flush_interval = 0.05  # flush at most every 50 ms
        self.max_buffer_size = 100  # or once 100 chars have accumulated

        logger.debug(f"🔧 初始化工具处理器: model={model}, stream={stream}")
|
| 61 |
+
|
| 62 |
+
def process_sse_chunk(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
    """
    Process one Z.AI SSE data chunk and yield OpenAI-format SSE lines.

    Args:
        chunk_data: parsed Z.AI SSE chunk; relevant keys are ``phase``,
            ``edit_content``, ``delta_content``, ``edit_index``, ``usage``.

    Yields:
        str: ``data: ...\\n\\n`` lines in OpenAI chat.completion.chunk format.

    Errors are logged and swallowed so one bad chunk does not kill the stream.
    """
    try:
        phase = chunk_data.get("phase")
        edit_content = chunk_data.get("edit_content", "")
        delta_content = chunk_data.get("delta_content", "")
        edit_index = chunk_data.get("edit_index")
        usage = chunk_data.get("usage", {})

        # A chunk without a phase cannot be routed.
        if not phase:
            logger.warning("⚠️ 收到无效的 SSE 块:缺少 phase 字段")
            return

        # Phase transition: flush buffered answer text first so output
        # ordering is preserved, then log the transition.
        if phase != self.current_phase:
            if hasattr(self, 'content_buffer') and self.content_buffer:
                yield from self._flush_content_buffer()

            logger.info(f"📈 SSE 阶段变化: {self.current_phase} → {phase}")
            content_preview = edit_content or delta_content
            if content_preview:
                logger.debug(f" 📝 内容预览: {content_preview[:1000]}{'...' if len(content_preview) > 1000 else ''}")
            if edit_index is not None:
                logger.debug(f" 📍 edit_index: {edit_index}")
            self.current_phase = phase

        # Dispatch on phase.
        if phase == SSEPhase.THINKING.value:
            yield from self._process_thinking_phase(delta_content)

        elif phase == SSEPhase.TOOL_CALL.value:
            yield from self._process_tool_call_phase(edit_content)

        elif phase == SSEPhase.OTHER.value:
            yield from self._process_other_phase(usage, edit_content)

        elif phase == SSEPhase.ANSWER.value:
            yield from self._process_answer_phase(delta_content)

        elif phase == SSEPhase.DONE.value:
            yield from self._process_done_phase(chunk_data)
        else:
            logger.warning(f"⚠️ 未知的 SSE 阶段: {phase}")

    except Exception as e:
        logger.error(f"❌ 处理 SSE 块时发生错误: {e}")
        logger.debug(f" 📦 错误块数据: {chunk_data}")
        # Deliberately do not re-raise: keep processing subsequent chunks.
|
| 120 |
+
|
| 121 |
+
def _process_thinking_phase(self, delta_content: str) -> Generator[str, None, None]:
    """Emit reasoning ("thinking") text as an ordinary assistant content delta."""
    if not delta_content:
        return

    logger.debug(f"🤔 思考内容: +{len(delta_content)} 字符")

    # Non-streaming mode produces no SSE output for thinking text.
    if not self.stream:
        return

    payload = json.dumps(self._create_content_chunk(delta_content), ensure_ascii=False)
    yield f"data: {payload}\n\n"
|
| 132 |
+
|
| 133 |
+
def _process_tool_call_phase(self, edit_content: str) -> Generator[str, None, None]:
    """Handle tool_call-phase content: new glm_block payloads or argument fragments."""
    if not edit_content:
        return

    logger.debug(f"🔧 进入工具调用阶段,内容长度: {len(edit_content)}")

    # A "<glm_block " marker introduces one or more tool-call payloads.
    if "<glm_block " in edit_content:
        yield from self._handle_glm_blocks(edit_content)
    else:
        # No marker: this chunk continues the arguments of the tool call
        # already in flight (if any).
        if self.has_tool_call:
            # Accumulate only up to the '", "result"' terminator, which
            # marks the end of the argument string.
            result_pos = edit_content.find('", "result"')
            if result_pos > 0:
                param_fragment = edit_content[:result_pos]
                self.tool_args += param_fragment
                logger.debug(f"📦 累积参数片段: {param_fragment}")
            else:
                # No terminator yet: this is a middle fragment, keep all of it.
                self.tool_args += edit_content
                logger.debug(f"📦 累积参数片段: {edit_content[:100]}...")
|
| 156 |
+
|
| 157 |
+
def _handle_glm_blocks(self, edit_content: str) -> Generator[str, None, None]:
    """Split content on '<glm_block ' markers and process each piece.

    The text before the first marker (index 0) belongs to the tool call
    already in flight; each later piece introduces a new tool call.
    """
    blocks = edit_content.split('<glm_block ')
    logger.debug(f"📦 分割得到 {len(blocks)} 个块")

    for index, block in enumerate(blocks):
        if not block.strip():
            continue

        if index == 0:
            # First piece: trailing argument fragment of the current call.
            if self.has_tool_call:
                logger.debug(f"📦 从第一个块提取参数片段")
                # NOTE(review): searches edit_content (the whole chunk), not
                # this piece — assumes the first '"result"' occurs inside the
                # prefix; confirm against real payloads.
                result_pos = edit_content.find('"result"')
                if result_pos > 0:
                    # Back up 3 chars to drop the trailing '", ' separator.
                    param_fragment = edit_content[:result_pos - 3]
                    self.tool_args += param_fragment
                    logger.debug(f"📦 累积参数片段: {param_fragment}")
            else:
                # No active tool call: nothing to attach the prefix to.
                continue
        else:
            # Later pieces: each should contain a complete glm_block.
            if "</glm_block>" not in block:
                continue

            # Finish the in-flight tool call before starting the next one.
            if self.has_tool_call:
                self.tool_args += '"'  # close the dangling argument string
                yield from self._finish_current_tool()

            # Start the new tool call from this block's metadata.
            yield from self._process_metadata_block(block)
|
| 193 |
+
|
| 194 |
+
def _process_metadata_block(self, block: str) -> Generator[str, None, None]:
    """Parse a glm_block payload and begin a new tool call from its metadata.

    On success this sets ``tool_id``, ``tool_name``, ``has_tool_call`` and
    seeds ``tool_args`` from ``metadata.arguments``. Parse failures are
    logged and ignored. Yields nothing; it is a generator only so callers
    can uniformly ``yield from`` it.
    """
    try:
        # JSON sits between the tag's closing '>' and '</glm_block>'.
        start_pos = block.find('>')
        end_pos = block.rfind('</glm_block>')

        if start_pos == -1 or end_pos == -1:
            logger.warning(f"❌ 无法找到 JSON 内容边界: {block[:1000]}...")
            return

        json_content = block[start_pos + 1:end_pos]
        logger.debug(f"📦 提取的 JSON 内容: {json_content[:1000]}...")

        metadata_obj = json.loads(json_content)

        if "data" in metadata_obj and "metadata" in metadata_obj["data"]:
            metadata = metadata_obj["data"]["metadata"]

            # Begin the new tool call (microsecond-based fallback id).
            self.tool_id = metadata.get("id", f"call_{int(time.time() * 1000000)}")
            self.tool_name = metadata.get("name", "unknown")
            self.has_tool_call = True

            # content_index intentionally not incremented here; the first
            # tool call of a turn uses index 0.

            # Seed the argument accumulator from metadata.arguments.
            if "arguments" in metadata:
                arguments_str = metadata["arguments"]
                # Drop a trailing quote so later fragments can append cleanly.
                self.tool_args = arguments_str[:-1] if arguments_str.endswith('"') else arguments_str
                logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 初始参数: {self.tool_args}")
            else:
                self.tool_args = "{}"
                logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 空参数")

    except (json.JSONDecodeError, KeyError, AttributeError) as e:
        logger.error(f"❌ 解析工具元数据失败: {e}, 块内容: {block[:1000]}...")

    # Unreachable yield: keeps this function a generator.
    if False:  # pragma: no cover
        yield
|
| 238 |
+
|
| 239 |
+
def _process_other_phase(self, usage: Dict[str, Any], edit_content: str = "") -> Generator[str, None, None]:
    """Handle the "other" phase: record usage stats and detect tool-call end."""
    # Keep the latest usage statistics for the tool-finish chunk.
    if usage:
        self.tool_call_usage = usage
        logger.debug(f"📊 保存使用统计: {usage}")

    # An edit_content starting with "null," signals the tool call is done.
    if self.has_tool_call and edit_content and edit_content.startswith("null,"):
        logger.info(f"🏁 检测到工具调用结束标记")

        # Emit the buffered tool call (start + arguments + finish chunks).
        yield from self._finish_current_tool()

        # Terminate the SSE stream.
        if self.stream:
            yield "data: [DONE]\n\n"

        # Clear all handler state for the next turn.
        self._reset_all_state()
|
| 259 |
+
|
| 260 |
+
def _process_answer_phase(self, delta_content: str) -> Generator[str, None, None]:
    """Buffer answer-text deltas and flush on size/time/sentence boundaries."""
    if not delta_content:
        return

    logger.info(f"📝 工具处理器收到答案内容: {delta_content[:50]}...")

    # Accumulate into the output buffer.
    self.content_buffer += delta_content
    self.buffer_size += len(delta_content)

    current_time = time.time()
    time_since_last_flush = current_time - self.last_flush_time

    # Flush when the buffer is full, the flush interval elapsed, or the
    # delta contains a newline / CJK sentence terminator.
    should_flush = (
        self.buffer_size >= self.max_buffer_size or
        time_since_last_flush >= self.flush_interval or
        '\n' in delta_content or
        '。' in delta_content or '!' in delta_content or '?' in delta_content
    )

    if should_flush and self.content_buffer:
        yield from self._flush_content_buffer()
|
| 284 |
+
|
| 285 |
+
def _flush_content_buffer(self) -> Generator[str, None, None]:
    """Emit the buffered answer text as one content chunk, then clear the buffer."""
    if not self.content_buffer:
        return

    logger.info(f"💬 工具处理器刷新缓冲区: {self.buffer_size} 字符 - {self.content_buffer[:50]}...")

    if self.stream:
        serialized = json.dumps(
            self._create_content_chunk(self.content_buffer), ensure_ascii=False
        )
        output_data = f"data: {serialized}\n\n"
        logger.info(f"➡️ 工具处理器输出: {output_data[:100]}...")
        yield output_data

    # Reset buffer bookkeeping regardless of stream mode.
    self.content_buffer = ""
    self.buffer_size = 0
    self.last_flush_time = time.time()
|
| 302 |
+
|
| 303 |
+
def _process_done_phase(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
    """Handle the terminal phase: flush buffers, finish tools, emit stop + [DONE]."""
    logger.info("🏁 对话完成")

    # Flush any answer text still sitting in the buffer.
    if self.content_buffer:
        yield from self._flush_content_buffer()

    # Emit any tool call that was never terminated explicitly.
    if self.has_tool_call:
        yield from self._finish_current_tool()

    # Close the SSE stream with a finish_reason="stop" chunk and [DONE].
    if self.stream:
        final_chunk = {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": self.model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }

        # Attach usage statistics when the upstream chunk carries them.
        if "usage" in chunk_data:
            final_chunk["usage"] = chunk_data["usage"]

        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    # Clear all handler state for the next conversation.
    self._reset_all_state()
|
| 339 |
+
|
| 340 |
+
def _finish_current_tool(self) -> Generator[str, None, None]:
    """Repair accumulated arguments and emit the complete tool-call chunk trio."""
    if not self.has_tool_call:
        return

    # Turn the raw accumulated fragments into valid JSON.
    fixed_args = self._fix_tool_arguments(self.tool_args)
    logger.debug(f"✅ 完成工具调用: {self.tool_name}, 参数: {fixed_args}")

    # Emit start, arguments, and finish chunks in OpenAI order.
    if self.stream:
        start_chunk = self._create_tool_start_chunk()
        yield f"data: {json.dumps(start_chunk, ensure_ascii=False)}\n\n"

        args_chunk = self._create_tool_arguments_chunk(fixed_args)
        yield f"data: {json.dumps(args_chunk, ensure_ascii=False)}\n\n"

        finish_chunk = self._create_tool_finish_chunk()
        yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"

    # Clear per-tool state (content_index is preserved).
    self._reset_tool_state()
|
| 365 |
+
|
| 366 |
+
def _fix_tool_arguments(self, raw_args: str) -> str:
    """Repair accumulated tool-argument text into a valid JSON string.

    Pipeline: local preprocessing -> third-party json-repair -> parse ->
    path/quote post-processing -> re-serialize. Any failure falls back
    to "{}" so the emitted tool call stays well-formed.
    """
    if not raw_args or raw_args == "{}":
        return "{}"

    logger.debug(f"🔧 开始修复参数: {raw_args[:1000]}{'...' if len(raw_args) > 1000 else ''}")

    try:
        # 1. Preprocess: only issues json-repair cannot fix itself.
        processed_args = self._preprocess_json_string(raw_args.strip())

        # 2. Main repair via the json-repair package (third-party).
        from json_repair import repair_json
        repaired_json = repair_json(processed_args)
        logger.debug(f"🔧 json-repair 修复结果: {repaired_json}")

        # 3. Parse and apply path/quote fix-ups.
        args_obj = json.loads(repaired_json)
        args_obj = self._post_process_args(args_obj)

        # 4. Serialize the final result.
        fixed_result = json.dumps(args_obj, ensure_ascii=False)

        return fixed_result

    except Exception as e:
        logger.error(f"❌ JSON 修复失败: {e}, 原始参数: {raw_args[:1000]}..., 使用空参数")
        return "{}"
|
| 395 |
+
|
| 396 |
+
def _post_process_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
    """Apply all post-repair fix-ups (path escaping, command quotes) in order."""
    for fixer in (self._fix_path_escaping_in_args, self._fix_command_quotes):
        args_obj = fixer(args_obj)
    return args_obj
|
| 405 |
+
|
| 406 |
+
def _preprocess_json_string(self, text: str) -> str:
    """Pre-repair fixes for defects that json-repair cannot handle.

    Handles exactly two cases: a payload missing its opening brace, and a
    stray escaped quote (``\\"``) right before ``}``, whitespace or ``,``.
    """
    import re

    # Case 1: payload ends with '}' but never opened — prepend the brace.
    if text.endswith('}') and not text.startswith('{'):
        text = '{' + text
        logger.debug(f"🔧 补全开始括号")

    # Case 2: drop the spurious backslash before a closing quote, e.g.
    # {"url":"https://example.com\"} -> {"url":"https://example.com"}
    stray_escape = r'([^\\])\\"([}\s,])'
    if re.search(stray_escape, text):
        text = re.sub(stray_escape, r'\1"\2', text)
        logger.debug(f"🔧 修复末尾多余的反斜杠")

    return text
|
| 427 |
+
|
| 428 |
+
def _fix_path_escaping_in_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
    """Collapse over-escaped backslashes in Windows path arguments.

    Only string values of known path fields that start with "C:" and
    contain a doubled backslash are touched: the path is rebuilt from its
    non-empty backslash-separated segments, so e.g.
    ``C:\\Users\\\\Documents`` becomes ``C:\\Users\\Documents``.

    Args:
        args_obj: parsed tool-argument dict (mutated in place).

    Returns:
        The same dict, with path fields normalized.
    """
    # Fields that may carry filesystem paths.
    path_fields = ['file_path', 'path', 'directory', 'folder']

    for field in path_fields:
        value = args_obj.get(field)
        if not isinstance(value, str):
            continue

        # Only Windows-style paths with at least one doubled backslash.
        if not (value.startswith('C:') and '\\\\' in value):
            continue

        logger.debug(f"🔍 检查路径字段 {field}: {repr(value)}")

        # Rebuild from non-empty segments; empty segments are the residue
        # of extra backslashes.
        segments = [part for part in value.split('\\') if part]
        fixed_path = value
        if len(segments) > 1:
            fixed_path = '\\'.join(segments)

        logger.debug(f"🔍 修复后路径: {repr(fixed_path)}")

        if fixed_path != value:
            args_obj[field] = fixed_path
            logger.debug(f"🔧 修复字段 {field} 的路径转义: {value} -> {fixed_path}")
        else:
            logger.debug(f"🔍 路径无需修复: {value}")

    return args_obj
|
| 473 |
+
|
| 474 |
+
def _fix_command_quotes(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
    """Strip spurious trailing quotes from the ``command`` argument."""
    import re

    command = args_obj.get('command')
    if not isinstance(command, str):
        return args_obj

    if command.endswith('""'):
        # A doubled closing quote: drop the extra one.
        logger.debug(f"🔧 发现命令末尾多余引号: {command}")
        cleaned = command[:-1]
        args_obj['command'] = cleaned
        logger.debug(f"🔧 修复命令引号: {command} -> {cleaned}")
    elif re.search(r'\\""+$', command):
        # A path-style tail like \"" — collapse to a single escaped quote.
        logger.debug(f"🔧 发现命令末尾引号模式问题: {command}")
        cleaned = re.sub(r'\\""+$', '\\"', command)
        args_obj['command'] = cleaned
        logger.debug(f"🔧 修复命令引号模式: {command} -> {cleaned}")

    return args_obj
|
| 500 |
+
|
| 501 |
+
def _create_content_chunk(self, content: str) -> Dict[str, Any]:
|
| 502 |
+
"""创建内容块"""
|
| 503 |
+
return {
|
| 504 |
+
"id": f"chatcmpl-{int(time.time())}",
|
| 505 |
+
"object": "chat.completion.chunk",
|
| 506 |
+
"created": int(time.time()),
|
| 507 |
+
"model": self.model,
|
| 508 |
+
"choices": [{
|
| 509 |
+
"index": 0,
|
| 510 |
+
"delta": {
|
| 511 |
+
"role": "assistant",
|
| 512 |
+
"content": content
|
| 513 |
+
},
|
| 514 |
+
"finish_reason": None
|
| 515 |
+
}]
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
def _create_tool_start_chunk(self) -> Dict[str, Any]:
|
| 519 |
+
"""创建工具开始块"""
|
| 520 |
+
return {
|
| 521 |
+
"id": f"chatcmpl-{int(time.time())}",
|
| 522 |
+
"object": "chat.completion.chunk",
|
| 523 |
+
"created": int(time.time()),
|
| 524 |
+
"model": self.model,
|
| 525 |
+
"choices": [{
|
| 526 |
+
"index": 0,
|
| 527 |
+
"delta": {
|
| 528 |
+
"role": "assistant",
|
| 529 |
+
"tool_calls": [{
|
| 530 |
+
"index": self.content_index,
|
| 531 |
+
"id": self.tool_id,
|
| 532 |
+
"type": "function",
|
| 533 |
+
"function": {
|
| 534 |
+
"name": self.tool_name,
|
| 535 |
+
"arguments": ""
|
| 536 |
+
}
|
| 537 |
+
}]
|
| 538 |
+
},
|
| 539 |
+
"finish_reason": None
|
| 540 |
+
}]
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
def _create_tool_arguments_chunk(self, arguments: str) -> Dict[str, Any]:
|
| 544 |
+
"""创建工具参数块"""
|
| 545 |
+
return {
|
| 546 |
+
"id": f"chatcmpl-{int(time.time())}",
|
| 547 |
+
"object": "chat.completion.chunk",
|
| 548 |
+
"created": int(time.time()),
|
| 549 |
+
"model": self.model,
|
| 550 |
+
"choices": [{
|
| 551 |
+
"index": 0,
|
| 552 |
+
"delta": {
|
| 553 |
+
"tool_calls": [{
|
| 554 |
+
"index": self.content_index,
|
| 555 |
+
"id": self.tool_id,
|
| 556 |
+
"function": {
|
| 557 |
+
"arguments": arguments
|
| 558 |
+
}
|
| 559 |
+
}]
|
| 560 |
+
},
|
| 561 |
+
"finish_reason": None
|
| 562 |
+
}]
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
def _create_tool_finish_chunk(self) -> Dict[str, Any]:
|
| 566 |
+
"""创建工具完成块"""
|
| 567 |
+
chunk = {
|
| 568 |
+
"id": f"chatcmpl-{int(time.time())}",
|
| 569 |
+
"object": "chat.completion.chunk",
|
| 570 |
+
"created": int(time.time()),
|
| 571 |
+
"model": self.model,
|
| 572 |
+
"choices": [{
|
| 573 |
+
"index": 0,
|
| 574 |
+
"delta": {
|
| 575 |
+
"tool_calls": []
|
| 576 |
+
},
|
| 577 |
+
"finish_reason": "tool_calls"
|
| 578 |
+
}]
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
# 添加使用统计(如果有)
|
| 582 |
+
if self.tool_call_usage:
|
| 583 |
+
chunk["usage"] = self.tool_call_usage
|
| 584 |
+
|
| 585 |
+
return chunk
|
| 586 |
+
|
| 587 |
+
def _reset_tool_state(self):
|
| 588 |
+
"""重置工具状态"""
|
| 589 |
+
self.tool_id = ""
|
| 590 |
+
self.tool_name = ""
|
| 591 |
+
self.tool_args = ""
|
| 592 |
+
self.has_tool_call = False
|
| 593 |
+
# content_index 在单次对话中应该保持不变,只有在新的工具调用开始时才递增
|
| 594 |
+
|
| 595 |
+
def _reset_all_state(self):
|
| 596 |
+
"""重置所有状态"""
|
| 597 |
+
# 先刷新任何剩余的缓冲内容
|
| 598 |
+
if hasattr(self, 'content_buffer') and self.content_buffer:
|
| 599 |
+
list(self._flush_content_buffer()) # 消费生成器
|
| 600 |
+
|
| 601 |
+
self._reset_tool_state()
|
| 602 |
+
self.current_phase = None
|
| 603 |
+
self.tool_call_usage = {}
|
| 604 |
+
|
| 605 |
+
# 重置缓冲区
|
| 606 |
+
self.content_buffer = ""
|
| 607 |
+
self.buffer_size = 0
|
| 608 |
+
self.last_flush_time = time.time()
|
| 609 |
+
|
| 610 |
+
# content_index 重置为 0,为下一轮对话做准备
|
| 611 |
+
self.content_index = 0
|
| 612 |
+
logger.debug("🔄 重置所有处理器状态")
|
app/utils/token_pool.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Token池管理器
|
| 6 |
+
实现AUTH_TOKEN的轮询机制,提供负载均衡和容错功能
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import time
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from threading import Lock
|
| 14 |
+
import httpx
|
| 15 |
+
|
| 16 |
+
from app.utils.logger import logger
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class TokenStatus:
    """Availability and usage statistics for a single auth token."""
    token: str
    is_available: bool = True       # cleared after too many failures
    failure_count: int = 0
    last_failure_time: float = 0.0  # unix timestamp of last failure
    last_success_time: float = 0.0  # unix timestamp of last success
    total_requests: int = 0
    successful_requests: int = 0
    token_type: str = "unknown"     # one of "user", "guest", "unknown"

    @property
    def success_rate(self) -> float:
        """Fraction of requests that succeeded; 1.0 when the token is untried."""
        if not self.total_requests:
            return 1.0
        return self.successful_requests / self.total_requests

    @property
    def is_healthy(self) -> bool:
        """Whether this token should be considered usable.

        Only available, authenticated user tokens qualify. New tokens
        (<= 3 requests) are healthy as long as they have never failed;
        established tokens need a success rate of at least 50%.
        Guest tokens never count as healthy.
        """
        if self.token_type != "user" or not self.is_available:
            return False
        if self.total_requests <= 3:
            # Grace period for brand-new tokens.
            return not self.failure_count
        return self.success_rate >= 0.5
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TokenPool:
    """Round-robin pool of AUTH tokens with failure tracking and recovery."""

    def __init__(self, tokens: List[str], failure_threshold: int = 3, recovery_timeout: int = 1800):
        """
        Initialize the token pool.

        Args:
            tokens: candidate token strings; empties and duplicates are dropped.
            failure_threshold: failures before a token is marked unavailable.
            recovery_timeout: seconds after which a failed token is retried.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._lock = Lock()          # guards token_statuses and _current_index
        self._current_index = 0      # round-robin cursor

        # Per-token status, keyed by token string.
        self.token_statuses: Dict[str, TokenStatus] = {}
        original_count = len(tokens)
        unique_tokens = []

        # Deduplicate while preserving order; skip empty tokens.
        for token in tokens:
            if token and token not in self.token_statuses:
                # Manually configured tokens are assumed to be user tokens.
                self.token_statuses[token] = TokenStatus(token=token, token_type="user")
                unique_tokens.append(token)

        duplicate_count = original_count - len(unique_tokens)
        if duplicate_count > 0:
            logger.warning(f"⚠️ 检测到 {duplicate_count} 个重复token,已自动去重")

        if not self.token_statuses:
            logger.warning("⚠️ Token池为空,将依赖匿名模式")
        # else:
        #     logger.info(f"🔧 初始化Token池,共 {len(self.token_statuses)} 个token")
|
| 108 |
+
|
| 109 |
+
def get_next_token(self) -> Optional[str]:
|
| 110 |
+
"""
|
| 111 |
+
获取下一个可用的token(轮询算法)
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
可用的token,如果没有可用token则返回None
|
| 115 |
+
"""
|
| 116 |
+
with self._lock:
|
| 117 |
+
if not self.token_statuses:
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
available_tokens = self._get_available_tokens()
|
| 121 |
+
if not available_tokens:
|
| 122 |
+
# 尝试恢复过期的失败token
|
| 123 |
+
self._try_recover_failed_tokens()
|
| 124 |
+
available_tokens = self._get_available_tokens()
|
| 125 |
+
|
| 126 |
+
if not available_tokens:
|
| 127 |
+
logger.warning("⚠️ 没有可用的token")
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
# 轮询选择token
|
| 131 |
+
token = available_tokens[self._current_index % len(available_tokens)]
|
| 132 |
+
self._current_index = (self._current_index + 1) % len(available_tokens)
|
| 133 |
+
|
| 134 |
+
return token
|
| 135 |
+
|
| 136 |
+
def _get_available_tokens(self) -> List[str]:
|
| 137 |
+
"""
|
| 138 |
+
获取当前可用的认证用户token列表
|
| 139 |
+
|
| 140 |
+
返回满足以下条件的token:
|
| 141 |
+
1. is_available = True (可用状态)
|
| 142 |
+
2. token_type == "user" (认证用户token)
|
| 143 |
+
|
| 144 |
+
这确保轮询机制只会选择有效的认证用户token,跳过匿名用户token
|
| 145 |
+
"""
|
| 146 |
+
available_user_tokens = [
|
| 147 |
+
status.token for status in self.token_statuses.values()
|
| 148 |
+
if status.is_available and status.token_type == "user"
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
# 检查是否有匿名用户token并给出警告
|
| 152 |
+
if not available_user_tokens and self.token_statuses:
|
| 153 |
+
guest_tokens = [
|
| 154 |
+
status.token for status in self.token_statuses.values()
|
| 155 |
+
if status.token_type == "guest"
|
| 156 |
+
]
|
| 157 |
+
if guest_tokens:
|
| 158 |
+
logger.warning(f"⚠️ 检测到 {len(guest_tokens)} 个匿名用户token,轮询机制将跳过这些token")
|
| 159 |
+
|
| 160 |
+
return available_user_tokens
|
| 161 |
+
|
| 162 |
+
def _try_recover_failed_tokens(self):
|
| 163 |
+
"""尝试恢复失败的token"""
|
| 164 |
+
current_time = time.time()
|
| 165 |
+
recovered_count = 0
|
| 166 |
+
|
| 167 |
+
for status in self.token_statuses.values():
|
| 168 |
+
if (not status.is_available and
|
| 169 |
+
current_time - status.last_failure_time > self.recovery_timeout):
|
| 170 |
+
status.is_available = True
|
| 171 |
+
status.failure_count = 0
|
| 172 |
+
recovered_count += 1
|
| 173 |
+
logger.info(f"🔄 恢复失败token: {status.token[:20]}...")
|
| 174 |
+
|
| 175 |
+
if recovered_count > 0:
|
| 176 |
+
logger.info(f"✅ 恢复了 {recovered_count} 个失败的token")
|
| 177 |
+
|
| 178 |
+
def mark_token_success(self, token: str):
|
| 179 |
+
"""标记token使用成功"""
|
| 180 |
+
with self._lock:
|
| 181 |
+
if token in self.token_statuses:
|
| 182 |
+
status = self.token_statuses[token]
|
| 183 |
+
status.total_requests += 1
|
| 184 |
+
status.successful_requests += 1
|
| 185 |
+
status.last_success_time = time.time()
|
| 186 |
+
status.failure_count = 0 # 重置失败计数
|
| 187 |
+
|
| 188 |
+
if not status.is_available:
|
| 189 |
+
status.is_available = True
|
| 190 |
+
logger.info(f"✅ Token恢复可用: {token[:20]}...")
|
| 191 |
+
|
| 192 |
+
def mark_token_failure(self, token: str, error: Exception = None):
|
| 193 |
+
"""标记token使用失败"""
|
| 194 |
+
with self._lock:
|
| 195 |
+
if token in self.token_statuses:
|
| 196 |
+
status = self.token_statuses[token]
|
| 197 |
+
status.total_requests += 1
|
| 198 |
+
status.failure_count += 1
|
| 199 |
+
status.last_failure_time = time.time()
|
| 200 |
+
|
| 201 |
+
if status.failure_count >= self.failure_threshold:
|
| 202 |
+
status.is_available = False
|
| 203 |
+
logger.warning(f"🚫 Token已禁用: {token[:20]}... (失败 {status.failure_count} 次)")
|
| 204 |
+
|
| 205 |
+
def get_pool_status(self) -> Dict:
|
| 206 |
+
"""获取token池状态信息"""
|
| 207 |
+
with self._lock:
|
| 208 |
+
available_count = len(self._get_available_tokens())
|
| 209 |
+
total_count = len(self.token_statuses)
|
| 210 |
+
|
| 211 |
+
# 统计健康token数量
|
| 212 |
+
healthy_count = sum(1 for status in self.token_statuses.values() if status.is_healthy)
|
| 213 |
+
|
| 214 |
+
status_info = {
|
| 215 |
+
"total_tokens": total_count,
|
| 216 |
+
"available_tokens": available_count,
|
| 217 |
+
"unavailable_tokens": total_count - available_count,
|
| 218 |
+
"healthy_tokens": healthy_count,
|
| 219 |
+
"unhealthy_tokens": total_count - healthy_count,
|
| 220 |
+
"current_index": self._current_index,
|
| 221 |
+
"tokens": []
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
for token, status in self.token_statuses.items():
|
| 225 |
+
status_info["tokens"].append({
|
| 226 |
+
"token": f"{token[:10]}...{token[-10:]}",
|
| 227 |
+
"token_type": status.token_type,
|
| 228 |
+
"is_available": status.is_available,
|
| 229 |
+
"failure_count": status.failure_count,
|
| 230 |
+
"success_count": status.successful_requests,
|
| 231 |
+
"success_rate": f"{status.success_rate:.2%}",
|
| 232 |
+
"total_requests": status.total_requests,
|
| 233 |
+
"is_healthy": status.is_healthy,
|
| 234 |
+
"last_failure_time": status.last_failure_time,
|
| 235 |
+
"last_success_time": status.last_success_time
|
| 236 |
+
})
|
| 237 |
+
|
| 238 |
+
return status_info
|
| 239 |
+
|
| 240 |
+
def update_tokens(self, new_tokens: List[str]):
|
| 241 |
+
"""动态更新token列表"""
|
| 242 |
+
with self._lock:
|
| 243 |
+
# 保留现有token的状态信息
|
| 244 |
+
old_statuses = self.token_statuses.copy()
|
| 245 |
+
self.token_statuses.clear()
|
| 246 |
+
|
| 247 |
+
original_count = len(new_tokens)
|
| 248 |
+
unique_tokens = []
|
| 249 |
+
|
| 250 |
+
# 去重并添加新token,保留已存在token的状态
|
| 251 |
+
for token in new_tokens:
|
| 252 |
+
if token and token not in self.token_statuses: # 过滤空token和重复token
|
| 253 |
+
if token in old_statuses:
|
| 254 |
+
self.token_statuses[token] = old_statuses[token]
|
| 255 |
+
else:
|
| 256 |
+
# 预设为认证用户token,因为这些是用户手动配置的token
|
| 257 |
+
self.token_statuses[token] = TokenStatus(token=token, token_type="user")
|
| 258 |
+
unique_tokens.append(token)
|
| 259 |
+
|
| 260 |
+
# 记录去重信息
|
| 261 |
+
duplicate_count = original_count - len(unique_tokens)
|
| 262 |
+
if duplicate_count > 0:
|
| 263 |
+
logger.warning(f"⚠️ 更新时检测到 {duplicate_count} 个重复token,已自动去重")
|
| 264 |
+
|
| 265 |
+
# 重置索引
|
| 266 |
+
self._current_index = 0
|
| 267 |
+
|
| 268 |
+
logger.info(f"🔄 更新Token池,共 {len(self.token_statuses)} 个token")
|
| 269 |
+
|
| 270 |
+
async def health_check_token(self, token: str, auth_url: str = "https://chat.z.ai/api/v1/auths/") -> bool:
|
| 271 |
+
"""
|
| 272 |
+
异步健康检查单个token
|
| 273 |
+
|
| 274 |
+
使用Z.AI认证API验证token的有效性,通过检查响应内容判断token是否有效
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
token: 要检查的token
|
| 278 |
+
auth_url: 认证URL
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
token是否健康
|
| 282 |
+
"""
|
| 283 |
+
try:
|
| 284 |
+
# 构建完整的请求头,模拟真实浏览器请求
|
| 285 |
+
headers = {
|
| 286 |
+
"Accept": "*/*",
|
| 287 |
+
"Accept-Language": "zh-CN,zh;q=0.9",
|
| 288 |
+
"Authorization": f"Bearer {token}",
|
| 289 |
+
"Connection": "keep-alive",
|
| 290 |
+
"Content-Type": "application/json",
|
| 291 |
+
"DNT": "1",
|
| 292 |
+
"Referer": "https://chat.z.ai/",
|
| 293 |
+
"Sec-Fetch-Dest": "empty",
|
| 294 |
+
"Sec-Fetch-Mode": "cors",
|
| 295 |
+
"Sec-Fetch-Site": "same-origin",
|
| 296 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36",
|
| 297 |
+
"sec-ch-ua": '"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"',
|
| 298 |
+
"sec-ch-ua-mobile": "?0",
|
| 299 |
+
"sec-ch-ua-platform": "Windows"
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
async with httpx.AsyncClient(timeout=15.0) as client:
|
| 303 |
+
response = await client.get(auth_url, headers=headers)
|
| 304 |
+
|
| 305 |
+
# 验证token有效性并获取类型
|
| 306 |
+
token_type, is_healthy = self._validate_token_response(response)
|
| 307 |
+
|
| 308 |
+
# 更新token类型
|
| 309 |
+
if token in self.token_statuses:
|
| 310 |
+
self.token_statuses[token].token_type = token_type
|
| 311 |
+
|
| 312 |
+
if is_healthy:
|
| 313 |
+
self.mark_token_success(token)
|
| 314 |
+
else:
|
| 315 |
+
# 简化错误信息,只记录关键错误类型
|
| 316 |
+
if token_type == "guest":
|
| 317 |
+
error_msg = "匿名用户token"
|
| 318 |
+
elif response.status_code != 200:
|
| 319 |
+
error_msg = f"HTTP {response.status_code}"
|
| 320 |
+
else:
|
| 321 |
+
error_msg = "认证失败"
|
| 322 |
+
|
| 323 |
+
self.mark_token_failure(token, Exception(error_msg))
|
| 324 |
+
|
| 325 |
+
return is_healthy
|
| 326 |
+
|
| 327 |
+
except (httpx.TimeoutException, httpx.ConnectError, Exception) as e:
|
| 328 |
+
self.mark_token_failure(token, e)
|
| 329 |
+
return False
|
| 330 |
+
|
| 331 |
+
def _validate_token_response(self, response: httpx.Response) -> bool:
|
| 332 |
+
"""
|
| 333 |
+
基于Z.AI API响应中的role字段验证token类型
|
| 334 |
+
|
| 335 |
+
验证规则:
|
| 336 |
+
- role: "user" = 认证用户token(有效,可用于AUTH_TOKENS)
|
| 337 |
+
- role: "guest" = 匿名用户token(无效,不应在AUTH_TOKENS中)
|
| 338 |
+
- 无role字段或其他值 = 无效token
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
response: HTTP响应对象
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
token是否为有效的认证用户token
|
| 345 |
+
"""
|
| 346 |
+
# 首先检查HTTP状态码
|
| 347 |
+
if response.status_code != 200:
|
| 348 |
+
return ("unknown", False)
|
| 349 |
+
|
| 350 |
+
try:
|
| 351 |
+
# 尝试解析JSON响应
|
| 352 |
+
response_data = response.json()
|
| 353 |
+
|
| 354 |
+
if not isinstance(response_data, dict):
|
| 355 |
+
return ("unknown", False)
|
| 356 |
+
|
| 357 |
+
# 检查是否包含错误信息
|
| 358 |
+
if "error" in response_data:
|
| 359 |
+
return ("unknown", False)
|
| 360 |
+
|
| 361 |
+
if "message" in response_data and "error" in response_data.get("message", "").lower():
|
| 362 |
+
return ("unknown", False)
|
| 363 |
+
|
| 364 |
+
# 核心验证:检查role字段
|
| 365 |
+
role = response_data.get("role")
|
| 366 |
+
|
| 367 |
+
if role == "user":
|
| 368 |
+
return ("user", True)
|
| 369 |
+
elif role == "guest":
|
| 370 |
+
|
| 371 |
+
if not hasattr(self, '_guest_token_warned'):
|
| 372 |
+
logger.warning("⚠️ 检测到匿名用户token,建议仅在AUTH_TOKENS中配置认证用户token")
|
| 373 |
+
self._guest_token_warned = True
|
| 374 |
+
return ("guest", False)
|
| 375 |
+
else:
|
| 376 |
+
return ("unknown", False)
|
| 377 |
+
|
| 378 |
+
except (ValueError, Exception):
|
| 379 |
+
return ("unknown", False)
|
| 380 |
+
|
| 381 |
+
async def health_check_all(self, auth_url: str = "https://chat.z.ai/api/v1/auths/"):
|
| 382 |
+
"""异步健康检查所有token"""
|
| 383 |
+
if not self.token_statuses:
|
| 384 |
+
logger.warning("⚠️ Token池为空,跳过健康检查")
|
| 385 |
+
return
|
| 386 |
+
|
| 387 |
+
total_tokens = len(self.token_statuses)
|
| 388 |
+
logger.info(f"🔍 开始Token池健康检查... (共 {total_tokens} 个token)")
|
| 389 |
+
|
| 390 |
+
# 并发执行所有token的健康检查
|
| 391 |
+
tasks = []
|
| 392 |
+
token_list = list(self.token_statuses.keys())
|
| 393 |
+
|
| 394 |
+
for token in token_list:
|
| 395 |
+
task = self.health_check_token(token, auth_url)
|
| 396 |
+
tasks.append(task)
|
| 397 |
+
|
| 398 |
+
# 执行并收集结果
|
| 399 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 400 |
+
|
| 401 |
+
# 统计结果
|
| 402 |
+
healthy_count = 0
|
| 403 |
+
failed_count = 0
|
| 404 |
+
exception_count = 0
|
| 405 |
+
|
| 406 |
+
for i, result in enumerate(results):
|
| 407 |
+
if result is True:
|
| 408 |
+
healthy_count += 1
|
| 409 |
+
elif result is False:
|
| 410 |
+
failed_count += 1
|
| 411 |
+
else:
|
| 412 |
+
# 异常情况
|
| 413 |
+
exception_count += 1
|
| 414 |
+
token = token_list[i]
|
| 415 |
+
logger.error(f"💥 Token {token[:20]}... 健康检查异常: {result}")
|
| 416 |
+
|
| 417 |
+
health_rate = (healthy_count / total_tokens) * 100 if total_tokens > 0 else 0
|
| 418 |
+
|
| 419 |
+
if healthy_count == 0 and total_tokens > 0:
|
| 420 |
+
logger.warning(f"⚠️ 健康检查完成: 0/{total_tokens} 个token健康 - 请检查token配置")
|
| 421 |
+
elif failed_count > 0:
|
| 422 |
+
logger.warning(f"⚠️ 健康检查完成: {healthy_count}/{total_tokens} 个token健康 ({health_rate:.1f}%)")
|
| 423 |
+
else:
|
| 424 |
+
logger.info(f"✅ 健康检查完成: {healthy_count}/{total_tokens} 个token健康")
|
| 425 |
+
|
| 426 |
+
if exception_count > 0:
|
| 427 |
+
logger.error(f"💥 {exception_count} 个token检查异常")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# Global token-pool singleton (created lazily by initialize_token_pool)
_token_pool: Optional[TokenPool] = None
# Guards creation/replacement of the global pool instance
_pool_lock = Lock()
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_token_pool() -> Optional[TokenPool]:
    """Return the process-wide TokenPool, or None if not yet initialized."""
    return _token_pool
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def initialize_token_pool(tokens: List[str], failure_threshold: int = 3, recovery_timeout: int = 1800) -> TokenPool:
    """Create the global TokenPool (replacing any existing one) and return it."""
    global _token_pool
    with _pool_lock:
        pool = TokenPool(tokens, failure_threshold, recovery_timeout)
        _token_pool = pool
        return pool
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def update_token_pool(tokens: List[str]):
    """Refresh the global pool's token list, creating the pool on first use."""
    global _token_pool
    with _pool_lock:
        if not _token_pool:
            _token_pool = TokenPool(tokens)
        else:
            _token_pool.update_tokens(tokens)
|
app/utils/user_agent.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
User-agent utilities.

Provides dynamically generated, randomized browser User-Agent strings and
matching HTTP request headers.
"""

import random
from typing import Dict, Optional
from fake_useragent import UserAgent

# Global UserAgent instance (lazy singleton; see get_user_agent_instance)
_user_agent_instance: Optional[UserAgent] = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_user_agent_instance() -> UserAgent:
    """Lazily create and return the shared UserAgent object (singleton)."""
    global _user_agent_instance
    if _user_agent_instance is None:
        _user_agent_instance = UserAgent()
    return _user_agent_instance
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_random_user_agent(browser_type: Optional[str] = None) -> str:
    """
    Return a random User-Agent string.

    Args:
        browser_type: one of 'chrome', 'firefox', 'safari', 'edge'.
            When None, a browser is picked at random (weighted toward
            Chrome and Edge).

    Returns:
        str: the User-Agent string.
    """
    ua = get_user_agent_instance()

    # Weighted random pick, biased toward Chrome and Edge
    if browser_type is None:
        browser_type = random.choice(
            ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"]
        )

    # Dispatch on the requested browser; anything unknown falls back to random
    if browser_type == "chrome":
        return ua.chrome
    if browser_type == "edge":
        return ua.edge
    if browser_type == "firefox":
        return ua.firefox
    if browser_type == "safari":
        return ua.safari
    return ua.random
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Generic User-Agent header builder
def get_dynamic_headers(
    referer: Optional[str] = None,
    origin: Optional[str] = None,
    browser_type: Optional[str] = None,
    additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
    """
    Build browser-like request headers with a random User-Agent.

    Args:
        referer: optional Referer URL
        origin: optional Origin URL
        browser_type: preferred browser for the User-Agent (see
            get_random_user_agent)
        additional_headers: extra headers merged in last (they win on conflict)

    Returns:
        Dict[str, str]: headers including a dynamic User-Agent and, for
        Chromium-based agents, matching ``sec-ch-ua`` client hints.
    """
    user_agent = get_random_user_agent(browser_type)

    # Baseline headers common to all browsers
    headers = {
        "User-Agent": user_agent,
        "Accept": "application/json, text/event-stream",
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        "Accept-Encoding": "gzip, deflate, br",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Pragma": "no-cache",
    }

    # Optional headers
    if referer:
        headers["Referer"] = referer

    if origin:
        headers["Origin"] = origin

    # Chromium-family agents also get sec-ch-ua client hints
    if "Chrome/" in user_agent or "Edg/" in user_agent:
        # Fallback versions used when parsing the UA string fails
        chrome_version = "139"
        edge_version = "139"

        try:
            if "Chrome/" in user_agent:
                chrome_version = user_agent.split("Chrome/")[1].split(".")[0]
        except IndexError:
            # Malformed UA string — keep the fallback version.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            pass

        try:
            if "Edg/" in user_agent:
                edge_version = user_agent.split("Edg/")[1].split(".")[0]
                sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
            else:
                sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'
        except IndexError:
            # Malformed UA string — fall back to the Chrome-style hint
            sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'

        headers.update({
            "sec-ch-ua": sec_ch_ua,
            "sec-ch-ua-mobile": "?0",
            "sec-ch-ua-platform": '"Windows"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
        })

    # Caller-supplied headers override everything above
    if additional_headers:
        headers.update(additional_headers)

    return headers
|
| 132 |
+
|
| 133 |
+
|
longcat_tokens.txt.example
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LongCat Passport Token 配置文件(可选)
|
| 2 |
+
#
|
| 3 |
+
# 说明:
|
| 4 |
+
# 1. 此文件是可选的,如果不需要多个token可以删除此文件
|
| 5 |
+
# 2. 支持两种格式:每行一个token 或 逗号分隔的token
|
| 6 |
+
# 3. 只包含有效的 passport_token_key 值
|
| 7 |
+
# 4. 系统会自动去重和验证token有效性
|
| 8 |
+
# 5. 自动跳过空格、换行符和空token
|
| 9 |
+
# 6. 当设置了 LONGCAT_PASSPORT_TOKEN 环境变量时,优先使用环境变量中的token
|
| 10 |
+
#
|
| 11 |
+
# 格式1:纯换行分隔
|
| 12 |
+
# token1
|
| 13 |
+
# token2
|
| 14 |
+
# token3
|
| 15 |
+
|
| 16 |
+
# 格式2:纯逗号分隔
|
| 17 |
+
# token1,token2,token3
|
| 18 |
+
|
| 19 |
+
# 格式3:混合格式
|
| 20 |
+
# token1,token2
|
| 21 |
+
# token3
|
| 22 |
+
# token4,token5,token6
|
| 23 |
+
# token7
|
| 24 |
+
|
| 25 |
+
# 请在下方添加您的 LongCat passport token(使用任一格式):
|
| 26 |
+
|
main.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import psutil
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from fastapi import FastAPI, Response
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
|
| 11 |
+
from app.core.config import settings
|
| 12 |
+
from app.core import openai
|
| 13 |
+
from app.utils.reload_config import RELOAD_CONFIG
|
| 14 |
+
from app.utils.logger import setup_logger
|
| 15 |
+
from app.utils.token_pool import initialize_token_pool
|
| 16 |
+
from app.providers import initialize_providers
|
| 17 |
+
|
| 18 |
+
from granian import Granian
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Configure the application-wide logger (file output under ./logs; verbosity
# controlled by settings.DEBUG_LOGGING)
logger = setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup: initialize the provider registry and, when auth tokens are
    configured, the global token pool. Shutdown: log the teardown.
    """
    # Initialize the provider system
    initialize_providers()

    # Initialize the token pool (skipped when no tokens are configured).
    # The pool is stored globally; retrieve it later via get_token_pool(),
    # so the previous unused local binding has been dropped.
    token_list = settings.auth_token_list
    if token_list:
        initialize_token_pool(
            tokens=token_list,
            failure_threshold=settings.TOKEN_FAILURE_THRESHOLD,
            recovery_timeout=settings.TOKEN_RECOVERY_TIMEOUT,
        )

    yield

    logger.info("🔄 应用正在关闭...")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Create the FastAPI app; `lifespan` handles startup/shutdown hooks
app = FastAPI(lifespan=lifespan)

# Allow cross-origin browser access to the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# disallowed by the CORS spec; confirm the framework's fallback behavior is
# what is intended here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)

# Mount the OpenAI-compatible API routes
app.include_router(openai.router)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@app.options("/")
async def handle_options():
    """Answer OPTIONS (CORS preflight) requests to the root with an empty 200."""
    return Response(status_code=200)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@app.get("/")
async def root():
    """Root endpoint: report that the service is alive."""
    return {"message": "OpenAI Compatible API Server"}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def run_server():
    """Boot the Granian ASGI server on 0.0.0.0 and block until it stops."""
    service_name = settings.SERVICE_NAME

    # Startup banner
    logger.info(f"🚀 启动 {service_name} 服务...")
    logger.info(f"📡 监听地址: 0.0.0.0:{settings.LISTEN_PORT}")
    logger.info(f"🔧 调试模式: {'开启' if settings.DEBUG_LOGGING else '关闭'}")
    logger.info(f"🔐 匿名模式: {'开启' if settings.ANONYMOUS_MODE else '关闭'}")

    try:
        server = Granian(
            "main:app",
            interface="asgi",
            address="0.0.0.0",
            port=settings.LISTEN_PORT,
            reload=False,  # keep hot reload off in production
            process_name=service_name,  # OS-visible process name
            **RELOAD_CONFIG,  # hot-reload configuration
        )
        server.serve()
    except KeyboardInterrupt:
        logger.info("🛑 收到中断信号,正在关闭服务...")
    except Exception as e:
        logger.error(f"❌ 服务启动失败: {e}")
        sys.exit(1)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Script entry point: start the server only when run directly
if __name__ == "__main__":
    run_server()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "z-ai2api-python"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "一个为 Z.ai 提供 OpenAI 兼容接口的 Python 代理服务"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.9,<3.13"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
authors = [{ name = "Contributors" }]
|
| 13 |
+
classifiers = [
|
| 14 |
+
"Development Status :: 4 - Beta",
|
| 15 |
+
"Intended Audience :: Developers",
|
| 16 |
+
"License :: OSI Approved :: MIT License",
|
| 17 |
+
"Operating System :: OS Independent",
|
| 18 |
+
"Programming Language :: Python :: 3",
|
| 19 |
+
"Programming Language :: Python :: 3.9",
|
| 20 |
+
"Programming Language :: Python :: 3.10",
|
| 21 |
+
"Programming Language :: Python :: 3.11",
|
| 22 |
+
"Programming Language :: Python :: 3.12",
|
| 23 |
+
"Topic :: Internet :: WWW/HTTP :: HTTP Servers",
|
| 24 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
| 25 |
+
]
|
| 26 |
+
dependencies = [
|
| 27 |
+
"fastapi==0.116.1",
|
| 28 |
+
"granian[reload,pname]==2.5.2",
|
| 29 |
+
"httpx==0.28.1",
|
| 30 |
+
"pydantic==2.11.7",
|
| 31 |
+
"pydantic-settings==2.10.1",
|
| 32 |
+
"pydantic-core==2.33.2",
|
| 33 |
+
"typing-inspection==0.4.1",
|
| 34 |
+
"fake-useragent==2.2.0",
|
| 35 |
+
"loguru==0.7.3",
|
| 36 |
+
"psutil>=7.0.0",
|
| 37 |
+
"json-repair==0.44.1"
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
[project.scripts]
|
| 41 |
+
z-ai2api = "main:app"
|
| 42 |
+
|
| 43 |
+
[tool.hatch.build.targets.wheel]
|
| 44 |
+
packages = ["."]
|
| 45 |
+
|
| 46 |
+
[tool.uv]
|
| 47 |
+
dev-dependencies = [
|
| 48 |
+
"pytest>=7.0.0",
|
| 49 |
+
"pytest-asyncio>=0.21.0",
|
| 50 |
+
"requests>=2.30.0",
|
| 51 |
+
"ruff>=0.1.0",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
[tool.ruff]
|
| 55 |
+
line-length = 88
|
| 56 |
+
target-version = "py39"
|
| 57 |
+
select = ["E", "F", "I", "B"]
|
| 58 |
+
ignore = []
|
| 59 |
+
|
| 60 |
+
[tool.ruff.isort]
|
| 61 |
+
known-first-party = []
|
| 62 |
+
|
| 63 |
+
[tool.pytest.ini_options]
|
| 64 |
+
asyncio_mode = "auto"
|
| 65 |
+
testpaths = ["tests"]
|
| 66 |
+
python_files = ["test_*.py"]
|
| 67 |
+
python_functions = ["test_*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.116.1
|
| 2 |
+
granian[reload,pname]==2.5.2
|
| 3 |
+
httpx==0.28.1
|
| 4 |
+
pydantic==2.11.7
|
| 5 |
+
pydantic-settings==2.10.1
|
| 6 |
+
pydantic-core==2.33.2
|
| 7 |
+
typing-inspection==0.4.1
|
| 8 |
+
fake-useragent==2.2.0
|
| 9 |
+
loguru==0.7.3
|
| 10 |
+
psutil>=7.0.0
|
| 11 |
+
json-repair==0.44.1
|
tests/test_comprehensive_fix.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
全面测试 ZAI Provider 修复效果
|
| 6 |
+
验证流式输出、工具调用、思考模式、重试机制等功能
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import json
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
|
# Add the project root (parent of tests/) to sys.path so `app.*` imports
# resolve; the previous single dirname() only added the tests/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
| 16 |
+
|
| 17 |
+
from app.providers.zai_provider import ZAIProvider
|
| 18 |
+
from app.models.schemas import OpenAIRequest, Message
|
| 19 |
+
from app.core.config import settings
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def test_basic_stream():
    """Exercise basic streaming output end-to-end and collect content deltas."""
    print("🧪 测试基本流式输出...")

    provider = ZAIProvider()

    request = OpenAIRequest(
        model=settings.PRIMARY_MODEL,
        messages=[
            Message(role="user", content="你好,请简单介绍一下自己")
        ],
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if hasattr(response, '__aiter__'):
            print("✅ 返回了异步生成器")
            chunk_count = 0
            content_chunks = []

            async for chunk in response:
                chunk_count += 1
                if chunk.startswith("data: ") and not chunk.strip().endswith("[DONE]"):
                    try:
                        chunk_data = json.loads(chunk[6:].strip())
                        if "choices" in chunk_data and chunk_data["choices"]:
                            choice = chunk_data["choices"][0]
                            if "delta" in choice and "content" in choice["delta"]:
                                content = choice["delta"]["content"]
                                if content:
                                    content_chunks.append(content)
                    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
                        # Skip malformed SSE payloads only — the old bare
                        # `except:` also swallowed KeyboardInterrupt etc.
                        pass

                if chunk_count >= 10:  # cap the test length
                    break

            full_content = "".join(content_chunks)
            print(f"✅ 成功处理了 {chunk_count} 个数据块")
            print(f"📝 内容预览: {full_content[:100]}...")
            return len(content_chunks) > 0
        else:
            print("❌ 返回的不是流式响应")
            return False

    except Exception as e:
        print(f"❌ 基本流式测试失败: {e}")
        return False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
async def test_thinking_mode():
    """Smoke-test the thinking model: stream a prompt and verify that both
    thinking deltas and answer content appear in the SSE output."""
    print("\n🧪 测试思考模式...")

    provider = ZAIProvider()

    request = OpenAIRequest(
        model=settings.THINKING_MODEL,
        messages=[
            Message(role="user", content="请解释一下量子计算的基本原理")
        ],
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if hasattr(response, '__aiter__'):
            print("✅ 返回了异步生成器")
            chunk_count = 0
            has_thinking = False
            has_content = False

            async for chunk in response:
                chunk_count += 1

                # Does the raw chunk carry thinking output?
                if 'thinking' in chunk:
                    has_thinking = True
                    print("✅ 检测到思考内容")

                # Does the raw chunk carry plain answer content?
                if '"content"' in chunk and '"thinking"' not in chunk:
                    has_content = True
                    print("✅ 检测到答案内容")

                if chunk_count >= 20:  # cap the test length
                    break

            print(f"✅ 成功处理了 {chunk_count} 个数据块")
            print(f"🤔 思考模式: {'正常' if has_thinking else '未检测到'}")
            print(f"💬 答案内容: {'正常' if has_content else '未检测到'}")
            return True
        else:
            print("❌ 返回的不是流式响应")
            return False

    except Exception as e:
        print(f"❌ 思考模式测试失败: {e}")
        return False
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
async def test_tool_support():
    """Stream a tool-enabled request and report whether a tool call was observed.

    Skips (returning True) when tool support is disabled in settings.
    """
    print("\n🧪 测试工具调用支持...")

    if not settings.TOOL_SUPPORT:
        print("⚠️ 工具支持已禁用,跳过测试")
        return True

    provider = ZAIProvider()

    # Minimal single-tool definition: a weather lookup keyed by city name.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "获取指定城市的天气信息",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "城市名称"
                        }
                    },
                    "required": ["city"]
                }
            }
        }
    ]

    request = OpenAIRequest(
        model=settings.PRIMARY_MODEL,
        messages=[Message(role="user", content="请帮我查询北京的天气")],
        tools=tools,
        stream=True,
    )

    try:
        response = await provider.chat_completion(request)

        if not hasattr(response, '__aiter__'):
            print("❌ 返回的不是流式响应")
            return False

        print("✅ 返回了异步生成器")
        chunk_count = 0
        has_tool_call = False

        async for chunk in response:
            chunk_count += 1

            # A serialized tool call surfaces as a "tool_calls" key in the chunk.
            if 'tool_calls' in chunk:
                has_tool_call = True
                print("✅ 检测到工具调用")

            if chunk_count >= 30:  # cap the stream so the test stays short
                break

        print(f"✅ 成功处理了 {chunk_count} 个数据块")
        print(f"🔧 工具调用: {'正常' if has_tool_call else '未检测到'}")
        return True

    except Exception as e:
        print(f"❌ 工具调用测试失败: {e}")
        return False
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
async def test_error_handling():
    """Provoke an error with an invalid model and confirm it is surfaced gracefully.

    Any of the three outcomes (error chunk in a stream, non-stream error
    response, raised exception) counts as correct handling.
    """
    print("\n🧪 测试错误处理...")

    provider = ZAIProvider()

    # Deliberately invalid model name to trigger the provider's error path.
    request = OpenAIRequest(
        model="invalid-model",
        messages=[Message(role="user", content="测试错误处理")],
        stream=True,
    )

    try:
        response = await provider.chat_completion(request)

        if not hasattr(response, '__aiter__'):
            print("✅ 返回了错误响应(非流式)")
            return True

        chunk_count = 0
        has_error = False

        async for chunk in response:
            chunk_count += 1

            # An error payload should mention "error" somewhere in the chunk.
            if 'error' in chunk:
                has_error = True
                print("✅ 检测到错误处理")

            if chunk_count >= 5:  # cap the stream so the test stays short
                break

        print(f"✅ 错误处理测试完成,处理了 {chunk_count} 个数据块")
        return True

    except Exception as e:
        print(f"✅ 正确捕获了异常: {type(e).__name__}")
        return True
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
async def main():
    """Run every provider test in sequence and print a pass/fail summary."""
    print("🚀 开始全面测试 ZAI Provider 修复效果\n")

    # Show the relevant configuration up front.
    print("📋 当前配置:")
    print(f" - 匿名模式: {settings.ANONYMOUS_MODE}")
    print(f" - 工具支持: {settings.TOOL_SUPPORT}")
    print(f" - 最大重试: {settings.MAX_RETRIES}")
    print(f" - 重试延迟: {settings.RETRY_DELAY}s")
    print()

    tests = [
        ("基本流式输出", test_basic_stream),
        ("思考模式", test_thinking_mode),
        ("工具调用支持", test_tool_support),
        ("错误处理", test_error_handling),
    ]

    passed = 0
    total = len(tests)
    separator = "=" * 50

    for test_name, test_func in tests:
        try:
            print(separator)
            ok = await test_func()
            if ok:
                passed += 1
                print(f"✅ {test_name} 测试通过")
            else:
                print(f"❌ {test_name} 测试失败")
        except Exception as e:
            # A crashing test counts as failed but must not stop the suite.
            print(f"❌ {test_name} 测试异常: {e}")

        print()

    print(separator)
    print(f"📊 测试结果: {passed}/{total} 通过")

    if passed == total:
        print("🎉 所有测试都通过了!ZAI Provider 修复成功")
    elif passed >= total * 0.75:
        print("✅ 大部分测试通过,ZAI Provider 基本修复成功")
    else:
        print("⚠️ 多个测试失败,需要进一步检查")


if __name__ == "__main__":
    asyncio.run(main())
|
tests/test_done_phase.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
测试 done 阶段处理
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
+
|
| 10 |
+
from app.utils.sse_tool_handler import SSEToolHandler
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
def test_done_phase_handling():
    """Replay an answer chunk followed by a done chunk and validate the emitted SSE output.

    Passes when the output contains the answer content, a finish_reason="stop"
    chunk, and the terminating "data: [DONE]" marker.
    """
    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试 done 阶段处理\n")

    # A minimal conversation: one answer-phase chunk, then the done phase
    # carrying usage statistics.
    test_chunks = [
        {
            "phase": "answer",
            "delta_content": "这是回答内容",
            "edit_content": ""
        },
        {
            "phase": "done",
            "done": True,
            "delta_content": "",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }
    ]

    output_chunks = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: phase={chunk['phase']}")

        emitted = list(handler.process_sse_chunk(chunk))
        output_chunks.extend(emitted)

        print(f" 输出数量: {len(emitted)}")
        for j, line in enumerate(emitted):
            if line.strip() == "data: [DONE]":
                print(f" 输出 {j+1}: [DONE] 标记")
            else:
                print(f" 输出 {j+1}: {line[:80]}{'...' if len(line) > 80 else ''}")
        print()

    print(f"📊 测试结果:")
    print(f" 总输出块数量: {len(output_chunks)}")

    # Scan the emitted SSE lines for the four expected features.
    has_content = False
    has_final_chunk = False
    has_done_marker = False
    has_usage = False

    for output in output_chunks:
        if not output.startswith("data: "):
            continue
        json_str = output[6:].strip()
        if json_str == "[DONE]":
            has_done_marker = True
            print(" ✅ 找到 [DONE] 标记")
        elif json_str:
            try:
                data = json.loads(json_str)
                if "choices" in data and data["choices"]:
                    delta = data["choices"][0].get("delta", {})
                    content = delta.get("content", "")
                    finish_reason = data["choices"][0].get("finish_reason")

                    if content:
                        has_content = True
                        print(f" ✅ 找到内容: '{content}'")

                    if finish_reason == "stop":
                        has_final_chunk = True
                        print(" ✅ 找到最终完成块")

                if "usage" in data:
                    has_usage = True
                    print(f" ✅ 找到 usage 信息: {data['usage']}")

            except json.JSONDecodeError as e:
                print(f" ❌ JSON 解析错误: {e}")

    success = has_content and has_final_chunk and has_done_marker

    print(f"\n📋 验证结果:")
    print(f" 包含回答内容: {'✅' if has_content else '❌'}")
    print(f" 包含最终完成块: {'✅' if has_final_chunk else '❌'}")
    print(f" 包含 [DONE] 标记: {'✅' if has_done_marker else '❌'}")
    print(f" 包含 usage 信息: {'✅' if has_usage else '❌'}")

    if success:
        print("\n✅ done 阶段处理测试通过!")
        return True
    print("\n❌ done 阶段处理测试失败!")
    return False
|
| 110 |
+
|
| 111 |
+
def test_done_phase_with_tool_call():
    """Replay tool-call, answer and done chunks and check all three features appear in the output."""
    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试带工具调用的 done 阶段处理\n")

    # Mixed flow: tool call start, tool call end, answer phase, done phase.
    test_chunks = [
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_test", "name": "test_tool", "arguments": "{}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 100
        },
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 200
        },
        {
            "phase": "answer",
            "delta_content": "工具调用完成,这是回答。",
            "edit_content": ""
        },
        {
            "phase": "done",
            "done": True,
            "delta_content": ""
        }
    ]

    output_chunks = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: phase={chunk['phase']}")

        emitted = list(handler.process_sse_chunk(chunk))
        output_chunks.extend(emitted)

        print(f" 输出数量: {len(emitted)}")
        print()

    # Look for the tool call, the answer text and the terminating marker.
    has_tool_call = any("tool_calls" in output for output in output_chunks)
    has_answer_content = any("工具调用完成" in output for output in output_chunks)
    has_done_marker = any(output.strip() == "data: [DONE]" for output in output_chunks)

    print(f"📊 混合流程测试结果:")
    print(f" 包含工具调用: {'✅' if has_tool_call else '❌'}")
    print(f" 包含回答内容: {'✅' if has_answer_content else '❌'}")
    print(f" 包含 [DONE] 标记: {'✅' if has_done_marker else '❌'}")

    if has_tool_call and has_answer_content and has_done_marker:
        print("\n✅ 混合流程 done 阶段测试通过!")
        return True
    print("\n❌ 混合流程 done 阶段测试失败!")
    return False
|
| 175 |
+
|
| 176 |
+
def test_done_phase_warning_fix():
    """Ensure a bare done-phase chunk is processed without raising an exception."""
    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试 done 阶段警告修复\n")

    # A done-phase chunk with no content and no usage payload.
    chunk = {
        "phase": "done",
        "done": True,
        "delta_content": ""
    }

    print("处理 done 阶段块...")

    # We mainly verify that processing does not raise; log output is not captured.
    try:
        emitted = list(handler.process_sse_chunk(chunk))
        print(f" 成功处理,输出 {len(emitted)} 个块")

        # The done phase should still produce the standard terminator.
        saw_done_marker = any(line.strip() == "data: [DONE]" for line in emitted)
        print(f" 包含 [DONE] 标记: {'✅' if saw_done_marker else '❌'}")

        print("\n✅ done 阶段不再产生警告!")
        return True

    except Exception as e:
        print(f"\n❌ 处理 done 阶段时出错: {e}")
        return False
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
    print("🔧 测试 done 阶段处理\n")

    # Run the three scenarios, printing a divider between consecutive runs.
    _cases = (
        test_done_phase_handling,
        test_done_phase_with_tool_call,
        test_done_phase_warning_fix,
    )
    _results = []
    for _pos, _case in enumerate(_cases):
        if _pos:
            print("\n" + "=" * 50 + "\n")
        _results.append(_case())
    test1_success, test2_success, test3_success = _results

    print("\n" + "=" * 50)
    print("🎯 总结:")
    print(f" done 阶段基本处理: {'✅ 通过' if test1_success else '❌ 失败'}")
    print(f" done 阶段混合流程: {'✅ 通过' if test2_success else '❌ 失败'}")
    print(f" done 阶段警告修复: {'✅ 通过' if test3_success else '❌ 失败'}")

    if test1_success and test2_success and test3_success:
        print("\n🎉 所有测试通过!done 阶段处理完善!")
        print("\n💡 修复效果:")
        print(" - 不再显示 '未知的 SSE 阶段: done' 警告")
        print(" - 正确处理对话完成流程")
        print(" - 自动刷新缓冲区和完成工具调用")
        print(" - 发送标准的 OpenAI 完成标记")
    else:
        print("\n❌ 部分测试失败,需要进一步调试")
|
tests/test_longcat_connection.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
测试 LongCat API 连接性
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import httpx
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
# LongCat API 端点
|
| 13 |
+
LONGCAT_API_ENDPOINT = "https://longcat.chat/api/v1/chat-completion-oversea"
|
| 14 |
+
|
| 15 |
+
async def test_longcat_api():
    """Exercise the LongCat SSE endpoint and dump the first few stream lines.

    Returns:
        True when the endpoint responds successfully, False on any failure.

    Fixes over the previous version:
    - ``httpx.Response`` has no ``atext()`` coroutine; that call raised
      ``AttributeError``. Error bodies are now loaded with ``aread()`` and
      read via the ``.text`` property.
    - The request now uses ``client.stream(...)`` so SSE lines are consumed
      incrementally; ``client.post`` would have buffered the entire body
      before ``aiter_lines`` ever ran.
    """
    print(f"🧪 测试 LongCat API 连接...")
    print(f"📡 API 端点: {LONGCAT_API_ENDPOINT}")

    headers = {
        'accept': 'text/event-stream,application/json',
        'content-type': 'application/json',
        'origin': 'https://longcat.chat',
        'referer': 'https://longcat.chat/t',
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36'
    }

    payload = {
        "stream": True,
        "temperature": 0.7,
        "content": "Hello",
        "messages": [
            {
                "role": "user",
                "content": "Hello"
            }
        ]
    }

    print(f"📤 发送请求...")
    print(f"📋 Headers: {json.dumps(headers, indent=2)}")
    print(f"📋 Payload: {json.dumps(payload, indent=2)}")

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            # client.stream keeps the connection open and yields data as it arrives.
            async with client.stream(
                "POST",
                LONGCAT_API_ENDPOINT,
                headers=headers,
                json=payload,
            ) as response:
                print(f"📡 响应状态码: {response.status_code}")
                print(f"📋 响应头: {dict(response.headers)}")

                if not response.is_success:
                    # The body of a streamed response must be read before .text is valid.
                    await response.aread()
                    print(f"❌ API 错误: {response.text}")
                    return False

                print(f"✅ 连接成功,开始读取流数据...")

                line_count = 0
                async for line in response.aiter_lines():
                    line_count += 1
                    line = line.strip()
                    print(f"📥 第 {line_count} 行: {line}")

                    if line_count > 10:  # only inspect the first lines
                        print(f"⏹️ 停止读取(已读取 {line_count} 行)")
                        break

                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                        if data_str == '[DONE]':
                            print(f"🏁 收到结束标记")
                            break

                        try:
                            data = json.loads(data_str)
                            print(f"📦 解析成功: {json.dumps(data, ensure_ascii=False, indent=2)}")
                        except json.JSONDecodeError as e:
                            print(f"❌ JSON 解析失败: {e}")

                print(f"✅ 测试完成,共读取 {line_count} 行")
                return True

    except httpx.TimeoutException:
        print(f"❌ 请求超时")
        return False
    except httpx.ConnectError as e:
        print(f"❌ 连接错误: {e}")
        return False
    except Exception as e:
        print(f"❌ 未知错误: {e}")
        import traceback
        print(f"❌ 错误堆栈: {traceback.format_exc()}")
        return False
|
| 98 |
+
|
| 99 |
+
async def test_simple_request():
    """Send a non-streaming request to LongCat and print the (truncated) response.

    Returns:
        True when the request succeeds, False otherwise.

    Bug fix: ``httpx.Response`` has no ``atext()`` coroutine — the previous
    ``await response.atext()`` raised ``AttributeError``. For a response
    returned by ``client.post`` the body is already read, so the synchronous
    ``.text`` property is the correct accessor.
    """
    print(f"\n🧪 测试简单的非流式请求...")

    headers = {
        'accept': 'application/json',
        'content-type': 'application/json',
        'origin': 'https://longcat.chat',
        'referer': 'https://longcat.chat/t',
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }

    payload = {
        "stream": False,
        "temperature": 0.7,
        "content": "Hello",
        "messages": [
            {
                "role": "user",
                "content": "Hello"
            }
        ]
    }

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                LONGCAT_API_ENDPOINT,
                headers=headers,
                json=payload
            )

            print(f"📡 响应状态码: {response.status_code}")

            if response.is_success:
                content = response.text
                print(f"✅ 响应内容: {content[:500]}...")
                return True
            else:
                error_text = response.text
                print(f"❌ 错误响应: {error_text}")
                return False

    except Exception as e:
        print(f"❌ 请求失败: {e}")
        return False
|
| 145 |
+
|
| 146 |
+
async def main():
    """Run both LongCat connectivity checks and print an overall verdict."""
    print("🚀 开始 LongCat API 连接测试...\n")

    # Streaming check first, then the plain JSON request.
    stream_result = await test_longcat_api()
    simple_result = await test_simple_request()

    print(f"\n📊 测试结果:")
    print(f" 流式请求: {'✅ 成功' if stream_result else '❌ 失败'}")
    print(f" 非流式请求: {'✅ 成功' if simple_result else '❌ 失败'}")

    if stream_result and simple_result:
        print(f"🎉 所有测试通过!")
    else:
        print(f"⚠️ 部分测试失败,请检查网络连接和 API 端点")


if __name__ == "__main__":
    asyncio.run(main())
|
tests/test_multiple_tools.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
测试多个工具调用的处理逻辑
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
+
|
| 10 |
+
from app.utils.sse_tool_handler import SSEToolHandler
|
| 11 |
+
|
| 12 |
+
def test_multiple_tool_calls():
    """Replay a logged multi-tool SSE sequence and verify every tool call completes in order.

    Passes when the handler reports exactly the three expected tool calls
    (navigate, snapshot, navigate) in that order.
    """
    handler = SSEToolHandler("test-model", stream=False)

    print("🧪 测试多个工具调用处理\n")

    # Real chunk sequence captured from logs: three tool calls, two of them
    # the same navigate tool issued with different ids.
    test_chunks = [
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_5y5gir0mygx", "name": "mcp__playwright__browser_navigate", "arguments": "{\\"url\\":\\"https://www.bil", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 24
        },
        {
            "phase": "tool_call",
            "edit_content": 'ibili.com\\"}',
            "edit_index": 194
        },
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 219
        },
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_j8r24x6xtg", "name": "mcp__playwright__browser_snapshot", "arguments": "{}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 406
        },
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 566
        },
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_scvwo0xaoil", "name": "mcp__playwright__browser_navigate", "arguments": "{\\"url\\":\\"https://www.bil", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 753
        },
        {
            "phase": "tool_call",
            "edit_content": 'ibili.com\\"}',
            "edit_index": 925
        },
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 950
        }
    ]

    tool_calls_completed = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: edit_index={chunk['edit_index']}, phase={chunk['phase']}")

        # Snapshot the handler's tool state so transitions can be detected.
        prev_id = handler.tool_id
        prev_name = handler.tool_name
        prev_active = handler.has_tool_call

        list(handler.process_sse_chunk(chunk))

        # A changed, non-empty tool_id marks the start of a new tool call.
        if handler.tool_id != prev_id and handler.tool_id:
            print(f" 🎯 新工具调用开始: {handler.tool_name} (id: {handler.tool_id})")

        # Active -> inactive means the previous tool call finished.
        if prev_active and not handler.has_tool_call:
            tool_calls_completed.append({
                "name": prev_name or "unknown",
                "id": prev_id
            })
            print(f" ✅ 工具调用完成: {prev_name or 'unknown'}")

        print(f" 当前状态: has_tool_call={handler.has_tool_call}, tool_id={handler.tool_id}")
        print()

    print(f"📊 测试结果:")
    print(f" 完成的工具调用数量: {len(tool_calls_completed)}")
    for i, tool in enumerate(tool_calls_completed, 1):
        print(f" {i}. {tool['name']} (id: {tool['id']})")

    # The upstream log contained these three calls, in this order.
    expected_tools = [
        "mcp__playwright__browser_navigate",
        "mcp__playwright__browser_snapshot",
        "mcp__playwright__browser_navigate"
    ]

    completed_tool_names = [tool['name'] for tool in tool_calls_completed]

    if completed_tool_names == expected_tools:
        print("\n✅ 测试通过!正确处理了所有工具调用")
        print("📝 结论:重复的工具调用是上游发送的,我们的处理逻辑是正确的")
        return True

    print(f"\n❌ 测试失败!")
    print(f" 期望: {expected_tools}")
    print(f" 实际: {completed_tool_names}")
    return False
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
    if not test_multiple_tool_calls():
        print("\n❌ 需要进一步调试处理逻辑")
    else:
        # Summarize what passing this replay demonstrates.
        print("\n🎯 总结:")
        print("1. 我们的 API 代理正确处理了每个不同的工具调用")
        print("2. 重复的工具调用是上游 Z.AI 模型发送的,不是我们的问题")
        print("3. 每个工具调用都有不同的 ID,说明这是模型的有意行为")
        print("4. 可能的原因:模型重试、验证操作、或处理复杂任务的策略")
|
tests/test_simple_performance.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
简化的性能测试,避免过多日志输出
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 12 |
+
|
| 13 |
+
# 临时禁用日志以避免性能测试中的噪音
|
| 14 |
+
logging.getLogger().setLevel(logging.CRITICAL)
|
| 15 |
+
|
| 16 |
+
from app.utils.sse_tool_handler import SSEToolHandler
|
| 17 |
+
|
| 18 |
+
def test_optimized_performance():
    """Benchmark SSEToolHandler._fix_tool_arguments over typical broken-JSON inputs.

    Prints per-case and aggregate timing plus a validity check of each
    repaired result.

    Improvements over the previous version:
    - timing now uses ``time.perf_counter()`` (monotonic, high resolution)
      instead of ``time.time()``, which can jump with wall-clock adjustments
    - the bare ``except`` around ``json.loads`` is narrowed to the exceptions
      it can actually raise
    - the unused ``parsed`` binding was dropped
    """
    print("🧪 测试优化后的 JSON 修复性能\n")

    # Benchmark cases: name, raw argument string, iteration count.
    test_cases = [
        {
            "name": "简单JSON",
            "input": '{"command":"echo hello","description":"test"}',
            "iterations": 100
        },
        {
            "name": "复杂命令行参数",
            "input": '{"command":"echo \\"添加更多内容\\uff1a$(date)\\\\\\" >> \\\\\\"C:\\\\\\\\Users\\\\\\\\test\\\\\\\\1.txt\\\\\\"\\"","description":"test"}',
            "iterations": 50
        },
        {
            "name": "缺少开始括号",
            "input": '"command":"echo hello","description":"test"}',
            "iterations": 50
        },
        {
            "name": "Windows路径问题",
            "input": '{"path":"C:\\\\\\\\Users\\\\\\\\Documents","command":"dir"}',
            "iterations": 50
        }
    ]

    handler = SSEToolHandler("test-model", stream=False)

    total_time = 0
    total_iterations = 0

    for test_case in test_cases:
        print(f"测试: {test_case['name']}")
        print(f" 输入长度: {len(test_case['input'])} 字符")
        print(f" 迭代次数: {test_case['iterations']}")

        # Warm up so one-time setup costs don't skew the measurement.
        for _ in range(5):
            handler._fix_tool_arguments(test_case['input'])

        # Timed run.
        start_time = time.perf_counter()
        for _ in range(test_case['iterations']):
            result = handler._fix_tool_arguments(test_case['input'])
        end_time = time.perf_counter()

        duration = end_time - start_time
        if duration > 0:
            avg_time = duration / test_case['iterations'] * 1000  # milliseconds
            throughput = test_case['iterations'] / duration
        else:
            # Guard against a zero-duration reading on coarse clocks.
            avg_time = 0
            throughput = float('inf')

        print(f" 总时间: {duration:.4f}s")
        print(f" 平均时间: {avg_time:.4f}ms")
        print(f" 吞吐量: {throughput:.1f} ops/s")

        total_time += duration
        total_iterations += test_case['iterations']

        # Sanity-check that the repaired string actually parses as JSON.
        try:
            json.loads(result)
            print(f" ✅ 结果有效")
        except (json.JSONDecodeError, TypeError):
            print(f" ❌ 结果无效")

        print()

    print(f"📊 总体性能:")
    print(f" 总时间: {total_time:.4f}s")
    print(f" 总迭代: {total_iterations}")
    if total_time > 0:
        print(f" 平均性能: {total_iterations/total_time:.1f} ops/s")
        print(f" 平均延迟: {total_time/total_iterations*1000:.4f}ms")
    else:
        print(f" 平均性能: ∞ ops/s")
        print(f" 平均延迟: 0.0000ms")
|
| 100 |
+
|
| 101 |
+
def test_code_simplification_benefits():
    """测试代码简化的好处

    Runs the handler's JSON-repair path (`_fix_tool_arguments`) against a
    range of malformed-JSON scenarios and reports, for each one, the repair
    latency, whether the repaired output parses as valid JSON, and the
    length of the result. Purely observational — prints to stdout, returns
    nothing.
    """

    print("\n🧪 测试代码简化的好处\n")

    # Inputs of increasing brokenness: well-formed, escaped quotes,
    # missing opening brace, and a stray trailing quote.
    test_cases = [
        '{"command":"echo hello"}',  # simple, already valid
        '{"command":"echo \\"hello\\"","description":"test"}',  # escaped quotes
        '"command":"echo hello","description":"test"}',  # missing opening brace
        '{"command":"echo hello > file.txt\\"","description":"test"}',  # extra quote
    ]

    handler = SSEToolHandler("test-model", stream=False)

    print("测试各种JSON修复场景:")
    for i, test_input in enumerate(test_cases, 1):
        print(f"\n场景 {i}: {test_input[:50]}{'...' if len(test_input) > 50 else ''}")

        # perf_counter() is monotonic and high-resolution — the right clock
        # for measuring a short repair call (time.time() can jump).
        start_time = time.perf_counter()
        result = handler._fix_tool_arguments(test_input)
        end_time = time.perf_counter()

        duration = (end_time - start_time) * 1000  # milliseconds

        # Only a parse failure means the repair failed; catch the specific
        # exception instead of a bare except that would hide real bugs.
        try:
            json.loads(result)
            status = "✅ 成功"
        except json.JSONDecodeError:
            status = "❌ 失败"

        print(f"  处理时间: {duration:.4f}ms")
        print(f"  修复状态: {status}")
        print(f"  结果长度: {len(result)} 字符")
|
| 135 |
+
|
| 136 |
+
def test_memory_efficiency():
    """测试内存效率

    Measures resident-set-size growth while the handler repairs the same
    small JSON payload 100 times. Requires the third-party ``psutil``
    package; when it is absent the test is skipped with a notice instead
    of failing.
    """

    print("\n🧪 测试内存效率\n")

    try:
        import psutil

        proc = psutil.Process()

        def rss_mb():
            # Current resident set size, converted from bytes to MB.
            return proc.memory_info().rss / 1024 / 1024

        baseline_memory = rss_mb()
        print(f"基线内存: {baseline_memory:.2f} MB")

        handler = SSEToolHandler("test-model", stream=False)
        test_data = '{"command":"echo test","description":"test"}'

        start_memory = rss_mb()
        for _ in range(100):
            handler._fix_tool_arguments(test_data)
        end_memory = rss_mb()

        print(f"处理100次后内存: {end_memory:.2f} MB")
        print(f"内存增长: {end_memory - baseline_memory:.2f} MB")
        print(f"平均每次处理: {(end_memory - start_memory) / 100 * 1024:.2f} KB")

    except ImportError:
        # Best-effort: psutil is optional, so just skip quietly.
        print("psutil 未安装,跳过内存测试")
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
    # Execute each benchmark scenario in order.
    for benchmark in (
        test_optimized_performance,
        test_code_simplification_benefits,
        test_memory_efficiency,
    ):
        benchmark()

    # Final summary of what the optimization achieved.
    print("\n🎯 优化总结:")
    for summary_line in (
        "✅ 简化了预处理逻辑",
        "✅ 统一了修复流程",
        "✅ 减少了代码复杂度",
        "✅ 保持了修复质量",
        "✅ 提高了可维护性",
    ):
        print(summary_line)
|
tokens.txt.example
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 认证Token配置文件(可选)
|
| 2 |
+
#
|
| 3 |
+
# 说明:
|
| 4 |
+
# 1. 此文件是可选的,如果不需要备用token可以删除此文件
|
| 5 |
+
# 2. 支持两种格式:每行一个token 或 逗号分隔的token
|
| 6 |
+
# 3. 只包含认证用户token (role: "user"),不要添加匿名用户token (role: "guest")
|
| 7 |
+
# 4. 系统会自动去重和验证token有效性
|
| 8 |
+
# 5. 自动跳过空格、换行符和空token
|
| 9 |
+
# 6. 当匿名模式正常工作时,此文件中的token不会被使用
|
| 10 |
+
#
|
| 11 |
+
# 格式1:纯换行分隔
|
| 12 |
+
# token1
|
| 13 |
+
# token2
|
| 14 |
+
# token3
|
| 15 |
+
|
| 16 |
+
# 格式2:纯逗号分隔
|
| 17 |
+
# token1,token2,token3
|
| 18 |
+
|
| 19 |
+
# 格式3:混合格式
|
| 20 |
+
# token1,token2
|
| 21 |
+
# token3
|
| 22 |
+
# token4,token5,token6
|
| 23 |
+
# token7
|
| 24 |
+
|
| 25 |
+
# 请在下方添加您的认证用户token(使用任一格式):
|
| 26 |
+
|