keungliang commited on
Commit
fd21f34
·
verified ·
1 Parent(s): 956e544

Upload 31 files

Browse files
.env.example ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 代理服务配置文件示例
2
+ # 复制此文件为 .env 并根据需要修改配置值
3
+
4
+ # ========== API 基础配置 ==========
5
+ # 客户端认证密钥(您自定义的 API 密钥,用于客户端访问本服务)
6
+ AUTH_TOKEN=sk-your-api-key
7
+
8
+ # 跳过客户端认证(仅开发环境使用)
9
+ SKIP_AUTH_TOKEN=false
10
+
11
+ # ========== Z.ai Token池配置 ==========
12
+ # Token失败阈值(失败多少次后标记为不可用)
13
+ TOKEN_FAILURE_THRESHOLD=3
14
+
15
+ # Token恢复超时时间(秒,失败token在此时间后重新尝试)
16
+ TOKEN_RECOVERY_TIMEOUT=1800
17
+
18
+ # Token健康检查间隔(秒,定期检查token状态)
19
+ TOKEN_HEALTH_CHECK_INTERVAL=300
20
+
21
+ # Z.AI 匿名用户模式
22
+ # false: 使用认证 Token 令牌,失败时自动降级为匿名请求
23
+ # true: 自动从 Z.ai 获取临时访问令牌,避免对话历史共享
24
+ ANONYMOUS_MODE=true
25
+
26
+ # ========== Z.ai 认证token配置(可选) ==========
27
+ # 使用独立的token文件配置(可选)
28
+ # 如果需要认证token,在项目根目录创建 tokens.txt 文件,每行一个token或逗号分隔
29
+ # 如果不需要认证token,想走匿名请求模式,可以注释掉或删除此配置项
30
+ # AUTH_TOKENS_FILE=tokens.txt
31
+
32
+ # ========== LongCat 配置 ==========
33
+ # LongCat passport token(单个token)
34
+ # LONGCAT_PASSPORT_TOKEN=your_passport_token_here
35
+
36
+ # LongCat tokens 文件路径(多个token)
37
+ # LONGCAT_TOKENS_FILE=longcat_tokens.txt
38
+
39
+ # ========== 服务器配置 ==========
40
+ # 服务监听端口
41
+ LISTEN_PORT=8080
42
+
43
+ # 服务名称(用于进程唯一性验证)
44
+ SERVICE_NAME=z-ai2api-server
45
+
46
+ # 调试日志
47
+ DEBUG_LOGGING=false
48
+
49
+ # Function Call 功能开关
50
+ TOOL_SUPPORT=true
51
+
52
+ # 工具调用扫描限制(字符数)
53
+ SCAN_LIMIT=200000
54
+
55
+ # ========== Z.AI 错误码400处理 ==========
56
+
57
+ # 重试次数
58
+ MAX_RETRIES=6
59
+ # 初始重试延迟
60
+ RETRY_DELAY=1
Dockerfile CHANGED
@@ -1 +1,10 @@
1
- FROM sxuancoder/z-ai-api-server:latest
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

WORKDIR /app

# Install dependencies first so this layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source. NOTE: the original `COPY .. .` is invalid —
# a Docker build cannot reference paths outside the build context, so the
# parent directory (`..`) is forbidden; the context root (`.`) is intended.
COPY . .

CMD ["python", "main.py"]
app/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from app import core, models, utils
5
+
6
+ __all__ = ["core", "models", "utils"]
app/core/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from app.core import config, zai_transformer, openai
5
+
6
+ __all__ = ["config", "zai_transformer", "openai"]
app/core/config.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ from typing import Dict, List, Optional
6
+ from pydantic_settings import BaseSettings
7
+ from app.utils.logger import logger
8
+
9
+
10
+ class Settings(BaseSettings):
11
+ """Application settings"""
12
+
13
+ # API Configuration
14
+ API_ENDPOINT: str = "https://chat.z.ai/api/chat/completions"
15
+ AUTH_TOKEN: str = os.getenv("AUTH_TOKEN", "sk-your-api-key")
16
+
17
+ # 认证token文件路径(可选)
18
+ AUTH_TOKENS_FILE: Optional[str] = os.getenv("AUTH_TOKENS_FILE")
19
+
20
+ # Token池配置
21
+ TOKEN_HEALTH_CHECK_INTERVAL: int = int(os.getenv("TOKEN_HEALTH_CHECK_INTERVAL", "300")) # 5分钟
22
+ TOKEN_FAILURE_THRESHOLD: int = int(os.getenv("TOKEN_FAILURE_THRESHOLD", "3")) # 失败3次后标记为不可用
23
+ TOKEN_RECOVERY_TIMEOUT: int = int(os.getenv("TOKEN_RECOVERY_TIMEOUT", "1800")) # 30分钟后重试失败的token
24
+
25
+ def _load_tokens_from_file(self, file_path: str) -> List[str]:
26
+ """
27
+ 从文件加载token列表
28
+
29
+ 支持多种格式的混合使用:
30
+ 1. 每行一个token(换行分隔)
31
+ 2. 逗号分隔的token
32
+ 3. 混合格式(同时支持换行和逗号分隔)
33
+ """
34
+ tokens = []
35
+ try:
36
+ if os.path.exists(file_path):
37
+ with open(file_path, 'r', encoding='utf-8') as f:
38
+ content = f.read().strip()
39
+
40
+ if not content:
41
+ logger.debug(f"📄 Token文件为空: {file_path}")
42
+ return tokens
43
+
44
+ # 智能解析:同时支持换行和逗号分隔
45
+ # 1. 先按换行符分割处理每一行
46
+ lines = content.split('\n')
47
+
48
+ for line in lines:
49
+ line = line.strip()
50
+ # 跳过空行和注释行
51
+ if not line or line.startswith('#'):
52
+ continue
53
+
54
+ # 2. 检查当前行是否包含逗号分隔
55
+ if ',' in line:
56
+ # 按逗号分割当前行
57
+ comma_tokens = line.split(',')
58
+ for token in comma_tokens:
59
+ token = token.strip()
60
+ if token: # 跳过空token
61
+ tokens.append(token)
62
+ else:
63
+ # 整行作为一个token
64
+ tokens.append(line)
65
+
66
+ logger.info(f"📄 从文件加载了 {len(tokens)} 个token: {file_path}")
67
+ else:
68
+ logger.debug(f"📄 Token文件不存在: {file_path}")
69
+ except Exception as e:
70
+ logger.error(f"❌ 读取token文件失败 {file_path}: {e}")
71
+ return tokens
72
+
73
+ @property
74
+ def auth_token_list(self) -> List[str]:
75
+ """
76
+ 解析认证token列表
77
+
78
+ 从AUTH_TOKENS_FILE指定的文件加载token(如果配置了文件路径)
79
+ """
80
+ # 如果未配置token文件路径,返回空列表
81
+ if not self.AUTH_TOKENS_FILE:
82
+ logger.debug("📄 未配置AUTH_TOKENS_FILE,跳过token文件加载")
83
+ return []
84
+
85
+ # 从文件加载token
86
+ tokens = self._load_tokens_from_file(self.AUTH_TOKENS_FILE)
87
+
88
+ # 去重,保持顺序
89
+ if tokens:
90
+ seen = set()
91
+ unique_tokens = []
92
+ for token in tokens:
93
+ if token not in seen:
94
+ unique_tokens.append(token)
95
+ seen.add(token)
96
+
97
+ # 记录去重信息
98
+ duplicate_count = len(tokens) - len(unique_tokens)
99
+ if duplicate_count > 0:
100
+ logger.warning(f"⚠️ 检测到 {duplicate_count} 个重复token,已自动去重")
101
+
102
+ return unique_tokens
103
+
104
+ return []
105
+
106
+ @property
107
+ def longcat_token_list(self) -> List[str]:
108
+ """
109
+ 解析 LongCat token 列表
110
+
111
+ 从 LONGCAT_TOKENS_FILE 指定的文件加载 token(如果配置了文件路径)
112
+ """
113
+ # 如果未配置token文件路径,返回空列表
114
+ if not self.LONGCAT_TOKENS_FILE:
115
+ logger.debug("📄 未配置LONGCAT_TOKENS_FILE,跳过LongCat token文件加载")
116
+ return []
117
+
118
+ # 从文件加载token
119
+ tokens = self._load_tokens_from_file(self.LONGCAT_TOKENS_FILE)
120
+
121
+ # 去重,保持顺序
122
+ if tokens:
123
+ seen = set()
124
+ unique_tokens = []
125
+ for token in tokens:
126
+ if token not in seen:
127
+ unique_tokens.append(token)
128
+ seen.add(token)
129
+
130
+ # 记录去重信息
131
+ duplicate_count = len(tokens) - len(unique_tokens)
132
+ if duplicate_count > 0:
133
+ logger.warning(f"⚠️ 检测到 {duplicate_count} 个重复LongCat token,已自动去重")
134
+
135
+ return unique_tokens
136
+
137
+ return []
138
+
139
+ # Model Configuration
140
+ PRIMARY_MODEL: str = os.getenv("PRIMARY_MODEL", "GLM-4.5")
141
+ THINKING_MODEL: str = os.getenv("THINKING_MODEL", "GLM-4.5-Thinking")
142
+ SEARCH_MODEL: str = os.getenv("SEARCH_MODEL", "GLM-4.5-Search")
143
+ AIR_MODEL: str = os.getenv("AIR_MODEL", "GLM-4.5-Air")
144
+ GLM46_MODEL: str = os.getenv("GLM46_MODEL", "GLM-4.6")
145
+ GLM46_THINKING_MODEL: str = os.getenv("GLM46_THINKING_MODEL", "GLM-4.6-Thinking")
146
+ GLM46_SEARCH_MODEL: str = os.getenv("GLM46_SEARCH_MODEL", "GLM-4.6-Search")
147
+
148
+
149
+
150
+ # Provider Model Mapping
151
+ @property
152
+ def provider_model_mapping(self) -> Dict[str, str]:
153
+ """模型到提供商的映射"""
154
+ return {
155
+ # Z.AI models
156
+ "GLM-4.5": "zai",
157
+ "GLM-4.5-Thinking": "zai",
158
+ "GLM-4.5-Search": "zai",
159
+ "GLM-4.5-Air": "zai",
160
+ "GLM-4.6": "zai",
161
+ "GLM-4.6-Thinking": "zai",
162
+ "GLM-4.6-Search": "zai",
163
+ # K2Think models
164
+ "MBZUAI-IFM/K2-Think": "k2think",
165
+ # LongCat models
166
+ "LongCat-Flash": "longcat",
167
+ "LongCat": "longcat",
168
+ "LongCat-Search": "longcat",
169
+ }
170
+
171
+ # Server Configuration
172
+ LISTEN_PORT: int = int(os.getenv("LISTEN_PORT", "8080"))
173
+ DEBUG_LOGGING: bool = os.getenv("DEBUG_LOGGING", "true").lower() == "true"
174
+ SERVICE_NAME: str = os.getenv("SERVICE_NAME", "z-ai2api-server")
175
+
176
+ ANONYMOUS_MODE: bool = os.getenv("ANONYMOUS_MODE", "true").lower() == "true"
177
+ TOOL_SUPPORT: bool = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
178
+ SCAN_LIMIT: int = int(os.getenv("SCAN_LIMIT", "200000"))
179
+ SKIP_AUTH_TOKEN: bool = os.getenv("SKIP_AUTH_TOKEN", "false").lower() == "true"
180
+
181
+ # LongCat Configuration
182
+ LONGCAT_PASSPORT_TOKEN: Optional[str] = os.getenv("LONGCAT_PASSPORT_TOKEN")
183
+ LONGCAT_TOKENS_FILE: Optional[str] = os.getenv("LONGCAT_TOKENS_FILE")
184
+
185
+ # Retry Configuration
186
+ MAX_RETRIES: int = int(os.getenv("MAX_RETRIES", "5"))
187
+ RETRY_DELAY: float = float(os.getenv("RETRY_DELAY", "1.0")) # 初始重试延迟(秒)
188
+
189
+ # Browser Headers
190
+ CLIENT_HEADERS: Dict[str, str] = {
191
+ "Content-Type": "application/json",
192
+ "Accept": "application/json, text/event-stream",
193
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0",
194
+ "Accept-Language": "zh-CN",
195
+ "sec-ch-ua": '"Not;A=Brand";v="99", "Microsoft Edge";v="139", "Chromium";v="139"',
196
+ "sec-ch-ua-mobile": "?0",
197
+ "sec-ch-ua-platform": '"Windows"',
198
+ "X-FE-Version": "prod-fe-1.0.70",
199
+ "Origin": "https://chat.z.ai",
200
+ }
201
+
202
+ class Config:
203
+ env_file = ".env"
204
+
205
+
206
+ settings = Settings()
app/core/openai.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import time
5
+ import json
6
+ from typing import List, Dict, Any
7
+ from fastapi import APIRouter, Header, HTTPException
8
+ from fastapi.responses import StreamingResponse, JSONResponse
9
+
10
+ from app.core.config import settings
11
+ from app.models.schemas import OpenAIRequest, Message, ModelsResponse, Model, OpenAIResponse, Choice, Usage
12
+ from app.utils.logger import get_logger
13
+ from app.providers import get_provider_router
14
+ from app.utils.token_pool import get_token_pool
15
+
16
+ logger = get_logger()
17
+ router = APIRouter()
18
+
19
+ # 全局提供商路由器实例
20
+ provider_router = None
21
+
22
+
23
def get_provider_router_instance():
    """Return the process-wide provider router, creating it on first use."""
    global provider_router
    if provider_router is not None:
        return provider_router
    provider_router = get_provider_router()
    return provider_router
29
+
30
+
31
def create_chunk(chat_id: str, model: str, delta: Dict[str, Any], finish_reason: str = None) -> Dict[str, Any]:
    """Build one OpenAI-compatible ``chat.completion.chunk`` payload."""
    choice = {
        "delta": delta,
        "finish_reason": finish_reason,
        "index": 0,
        "logprobs": None,
    }
    return {
        "choices": [choice],
        "created": int(time.time()),
        "id": chat_id,
        "model": model,
        "object": "chat.completion.chunk",
        "system_fingerprint": "fp_zai_001",
    }
46
+
47
+
48
async def handle_non_stream_response(stream_response, request: OpenAIRequest) -> JSONResponse:
    """Drain a streaming SSE generator and fold it into a single completion."""
    logger.info("📄 开始处理非流式响应")

    # Accumulate every delta.content fragment emitted by the stream
    pieces = []
    async for raw in stream_response():
        if not raw.startswith("data: "):
            continue
        payload = raw[6:].strip()
        if not payload or payload == "[DONE]":
            continue
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError:
            continue
        choices = parsed.get("choices")
        if not choices:
            continue
        delta = choices[0].get("delta")
        if delta and "content" in delta and delta["content"]:
            pieces.append(delta["content"])

    # Wrap the aggregated text in a standard OpenAI completion envelope.
    # Token usage is not tracked for aggregated streams, hence zeros.
    response_data = OpenAIResponse(
        id=f"chatcmpl-{int(time.time())}",
        object="chat.completion",
        created=int(time.time()),
        model=request.model,
        choices=[Choice(
            index=0,
            message=Message(
                role="assistant",
                content="".join(pieces),
                tool_calls=None
            ),
            finish_reason="stop"
        )],
        usage=Usage(
            prompt_tokens=0,
            completion_tokens=0,
            total_tokens=0
        )
    )

    logger.info("✅ 非流式响应处理完成")
    return JSONResponse(content=response_data.model_dump(exclude_none=True))
93
+
94
+
95
+ @router.get("/v1/models")
96
+ async def list_models():
97
+ """List available models from all providers"""
98
+ try:
99
+ router_instance = get_provider_router_instance()
100
+ models_data = router_instance.get_models_list()
101
+ return JSONResponse(content=models_data)
102
+ except Exception as e:
103
+ logger.error(f"❌ 获取模型列表失败: {e}")
104
+ # 返回默认模型列表作为后备
105
+ current_time = int(time.time())
106
+ fallback_response = ModelsResponse(
107
+ data=[
108
+ Model(id=settings.PRIMARY_MODEL, created=current_time, owned_by="z.ai"),
109
+ Model(id=settings.THINKING_MODEL, created=current_time, owned_by="z.ai"),
110
+ Model(id=settings.SEARCH_MODEL, created=current_time, owned_by="z.ai"),
111
+ Model(id=settings.AIR_MODEL, created=current_time, owned_by="z.ai"),
112
+ ]
113
+ )
114
+ return fallback_response
115
+
116
+
117
+ @router.post("/v1/chat/completions")
118
+ async def chat_completions(request: OpenAIRequest, authorization: str = Header(...)):
119
+ """Handle chat completion requests with multi-provider architecture"""
120
+ role = request.messages[0].role if request.messages else "unknown"
121
+ logger.info(f"😶‍🌫️ 收到客户端请求 - 模型: {request.model}, 流式: {request.stream}, 消息数: {len(request.messages)}, 角色: {role}, 工具数: {len(request.tools) if request.tools else 0}")
122
+
123
+ try:
124
+ # Validate API key (skip if SKIP_AUTH_TOKEN is enabled)
125
+ if not settings.SKIP_AUTH_TOKEN:
126
+ if not authorization.startswith("Bearer "):
127
+ raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")
128
+
129
+ api_key = authorization[7:]
130
+ if api_key != settings.AUTH_TOKEN:
131
+ raise HTTPException(status_code=401, detail="Invalid API key")
132
+
133
+ # 使用多提供商路由器处理请求
134
+ router_instance = get_provider_router_instance()
135
+ result = await router_instance.route_request(request)
136
+
137
+ # 检查是否有错误
138
+ if isinstance(result, dict) and "error" in result:
139
+ error_info = result["error"]
140
+ if error_info.get("code") == "model_not_found":
141
+ raise HTTPException(status_code=404, detail=error_info["message"])
142
+ else:
143
+ raise HTTPException(status_code=500, detail=error_info["message"])
144
+
145
+ # 处理响应
146
+ if request.stream:
147
+ # 流式响应
148
+ if hasattr(result, '__aiter__'):
149
+ # 结果是异步生成器
150
+ return StreamingResponse(
151
+ result,
152
+ media_type="text/event-stream",
153
+ headers={
154
+ "Cache-Control": "no-cache",
155
+ "Connection": "keep-alive",
156
+ "Access-Control-Allow-Origin": "*",
157
+ }
158
+ )
159
+ else:
160
+ # 结果是字典,可能包含错误
161
+ raise HTTPException(status_code=500, detail="Expected streaming response but got non-streaming result")
162
+ else:
163
+ # 非流式响应
164
+ if isinstance(result, dict):
165
+ return JSONResponse(content=result)
166
+ else:
167
+ # 如果是异步生成器,需要收集所有内容
168
+ return await handle_non_stream_response(result, request)
169
+
170
+ except HTTPException:
171
+ # 重新抛出 HTTP 异常
172
+ raise
173
+ except Exception as e:
174
+ logger.error(f"❌ 请求处理失败: {e}")
175
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
176
+
177
+
178
+ # Token pool management endpoints
179
+
180
+
181
+ @router.get("/v1/token-pool/status")
182
+ async def get_token_pool_status():
183
+ """获取token池状态信息"""
184
+ try:
185
+ token_pool = get_token_pool()
186
+ if not token_pool:
187
+ return {
188
+ "status": "disabled",
189
+ "message": "Token池未初始化,当前仅使用匿名模式",
190
+ "anonymous_mode": settings.ANONYMOUS_MODE,
191
+ "auth_tokens_file": settings.AUTH_TOKENS_FILE,
192
+ "auth_tokens_configured": len(settings.auth_token_list) > 0
193
+ }
194
+
195
+ pool_status = token_pool.get_pool_status()
196
+ return {
197
+ "status": "active",
198
+ "pool_info": pool_status,
199
+ "config": {
200
+ "anonymous_mode": settings.ANONYMOUS_MODE,
201
+ "failure_threshold": settings.TOKEN_FAILURE_THRESHOLD,
202
+ "recovery_timeout": settings.TOKEN_RECOVERY_TIMEOUT,
203
+ "health_check_interval": settings.TOKEN_HEALTH_CHECK_INTERVAL
204
+ }
205
+ }
206
+ except Exception as e:
207
+ logger.error(f"获取token池状态失败: {e}")
208
+ raise HTTPException(status_code=500, detail=f"Failed to get token pool status: {str(e)}")
209
+
210
+
211
+ @router.post("/v1/token-pool/health-check")
212
+ async def trigger_health_check():
213
+ """手动触发token池健康检查"""
214
+ try:
215
+ token_pool = get_token_pool()
216
+ if not token_pool:
217
+ raise HTTPException(status_code=404, detail="Token池未初始化")
218
+
219
+ start_time = time.time()
220
+ logger.info("🔍 API触发Token池健康检查...")
221
+ await token_pool.health_check_all()
222
+ duration = time.time() - start_time
223
+
224
+ pool_status = token_pool.get_pool_status()
225
+ total_tokens = pool_status['total_tokens']
226
+ healthy_tokens = sum(1 for token_info in pool_status['tokens'] if token_info['is_healthy'])
227
+
228
+ response = {
229
+ "status": "completed",
230
+ "message": f"健康检查已完成,耗时 {duration:.2f} 秒",
231
+ "summary": {
232
+ "total_tokens": total_tokens,
233
+ "healthy_tokens": healthy_tokens,
234
+ "unhealthy_tokens": total_tokens - healthy_tokens,
235
+ "health_rate": f"{(healthy_tokens/total_tokens*100):.1f}%" if total_tokens > 0 else "0%",
236
+ "duration_seconds": round(duration, 2)
237
+ },
238
+ "pool_info": pool_status
239
+ }
240
+
241
+ logger.info(f"✅ API健康检查完成: {healthy_tokens}/{total_tokens} 个token健康")
242
+ return response
243
+ except Exception as e:
244
+ logger.error(f"健康检查失败: {e}")
245
+ raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}")
246
+
247
+
248
+ @router.post("/v1/token-pool/update")
249
+ async def update_token_pool_endpoint(tokens: List[str]):
250
+ """动态更新token池"""
251
+ try:
252
+ from app.utils.token_pool import update_token_pool
253
+
254
+ valid_tokens = [token.strip() for token in tokens if token.strip()]
255
+ if not valid_tokens:
256
+ raise HTTPException(status_code=400, detail="至少需要提供一个有效的token")
257
+
258
+ update_token_pool(valid_tokens)
259
+ token_pool = get_token_pool()
260
+
261
+ return {
262
+ "status": "updated",
263
+ "message": f"Token池已更新,共 {len(valid_tokens)} 个token",
264
+ "pool_info": token_pool.get_pool_status() if token_pool else None
265
+ }
266
+ except Exception as e:
267
+ logger.error(f"更新token池失败: {e}")
268
+ raise HTTPException(status_code=500, detail=f"Failed to update token pool: {str(e)}")
app/core/zai_transformer.py ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import json
5
+ import time
6
+ import uuid
7
+ import random
8
+ from datetime import datetime
9
+ from typing import Dict, List, Any, Optional, Generator, AsyncGenerator
10
+ import httpx
11
+ import asyncio
12
+
13
+ from app.core.config import settings
14
+ from app.utils.logger import get_logger
15
+ from app.utils.token_pool import get_token_pool, initialize_token_pool
16
+ from app.utils.user_agent import get_random_user_agent
17
+
18
+ logger = get_logger()
19
+
20
+
21
def get_zai_dynamic_headers(chat_id: str = "") -> Dict[str, str]:
    """
    Build Z.AI-specific dynamic browser headers with a random User-Agent.

    Uses the generic UserAgent helper but layers Z.AI business logic on top
    (client hints, X-FE-Version, Referer).

    Args:
        chat_id: chat id used to build the Referer; empty means the site root.

    Returns:
        Dict[str, str]: headers tailored to the randomly chosen browser.
    """
    # Pick a browser type, weighted towards Chrome and Edge
    browser_choices = ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"]
    browser_type = random.choice(browser_choices)

    user_agent = get_random_user_agent(browser_type)

    # Extract major version numbers; keep the defaults if parsing fails
    chrome_version = "139"
    edge_version = "139"

    if "Chrome/" in user_agent:
        try:
            chrome_version = user_agent.split("Chrome/")[1].split(".")[0]
        except (IndexError, ValueError):
            # Bug fix: was a bare `except:` which also traps SystemExit /
            # KeyboardInterrupt; only index/parse errors are expected here.
            pass

    if "Edg/" in user_agent:
        try:
            edge_version = user_agent.split("Edg/")[1].split(".")[0]
            sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
        except (IndexError, ValueError):
            # Same bare-except fix; fall back to the generic Chromium hint
            sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'
    elif "Firefox/" in user_agent:
        sec_ch_ua = None  # Firefox does not send sec-ch-ua
    else:
        sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'

    # Z.AI-specific base headers
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        "User-Agent": user_agent,
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        "X-FE-Version": "prod-fe-1.0.79",
        "Origin": "https://chat.z.ai",
    }

    # Chromium-family browsers additionally get client-hint headers
    if sec_ch_ua:
        headers["sec-ch-ua"] = sec_ch_ua
        headers["sec-ch-ua-mobile"] = "?0"
        headers["sec-ch-ua-platform"] = '"Windows"'

    # Referer depends on whether the request belongs to an existing chat
    if chat_id:
        headers["Referer"] = f"https://chat.z.ai/c/{chat_id}"
    else:
        headers["Referer"] = "https://chat.z.ai/"

    return headers
82
+
83
+
84
+
85
def generate_uuid() -> str:
    """Return a freshly generated RFC 4122 version-4 UUID as a string."""
    return f"{uuid.uuid4()}"
88
+
89
+
90
def get_auth_token_sync() -> str:
    """Synchronously obtain an auth token (for non-async call sites).

    Resolution order:
      * anonymous mode: guest token from Z.AI, or "" on failure;
      * otherwise: token pool first, then guest-token fallback, then "".
    """

    def _request_guest_token() -> str:
        # Single guest-token request against the Z.AI auth endpoint.
        headers = get_zai_dynamic_headers()
        with httpx.Client() as client:
            response = client.get("https://chat.z.ai/api/v1/auths/", headers=headers, timeout=10.0)
            if response.status_code == 200:
                return response.json().get("token", "")
        return ""

    if settings.ANONYMOUS_MODE:
        try:
            token = _request_guest_token()
            if token:
                logger.debug(f"获取访客令牌成功: {token[:20]}...")
                return token
        except Exception as e:
            logger.warning(f"获取访客令牌失败: {e}")

        # In anonymous mode there is no further fallback
        logger.error("❌ 匿名模式下获取访客令牌失败")
        return ""

    # Non-anonymous mode: prefer a token from the pool
    token_pool = get_token_pool()
    if token_pool:
        token = token_pool.get_next_token()
        if token:
            logger.debug(f"从token池获取令牌: {token[:20]}...")
            return token

    # No pooled token available — degrade to anonymous access
    logger.warning("⚠️ 没有可用的备份token,尝试降级到匿名模式...")
    try:
        token = _request_guest_token()
        if token:
            logger.info(f"✅ 降级到匿名模式成功: {token[:20]}...")
            return token
    except Exception as e:
        logger.warning(f"降级到匿名模式失败: {e}")

    logger.error("❌ 所有认证方式都失败了")
    return ""
137
+
138
+
139
+ class ZAITransformer:
140
+ """ZAI转换器类"""
141
+
142
+ def __init__(self):
143
+ """初始化转换器"""
144
+ self.name = "zai"
145
+ self.base_url = "https://chat.z.ai"
146
+ self.api_url = settings.API_ENDPOINT
147
+ self.auth_url = f"{self.base_url}/api/v1/auths/"
148
+
149
+ # 模型映射
150
+ self.model_mapping = {
151
+ settings.PRIMARY_MODEL: "0727-360B-API", # GLM-4.5
152
+ settings.THINKING_MODEL: "0727-360B-API", # GLM-4.5-Thinking
153
+ settings.SEARCH_MODEL: "0727-360B-API", # GLM-4.5-Search
154
+ settings.AIR_MODEL: "0727-106B-API", # GLM-4.5-Air
155
+ }
156
+
157
+ async def get_token(self) -> str:
158
+ """异步获取认证令牌"""
159
+ # 如果启用匿名模式,只尝试获取访客令牌
160
+ if settings.ANONYMOUS_MODE:
161
+ try:
162
+ headers = get_zai_dynamic_headers()
163
+ async with httpx.AsyncClient() as client:
164
+ response = await client.get(self.auth_url, headers=headers, timeout=10.0)
165
+ if response.status_code == 200:
166
+ data = response.json()
167
+ token = data.get("token", "")
168
+ if token:
169
+ logger.debug(f"获取访客令牌成功: {token[:20]}...")
170
+ return token
171
+ except Exception as e:
172
+ logger.warning(f"异步获取访客令牌失败: {e}")
173
+
174
+ # 匿名模式下,如果获取访客令牌失败,直接返回空
175
+ logger.error("❌ 匿名模式下获取访客令牌失败")
176
+ return ""
177
+
178
+ # 非匿名模式:首先使用token池获取备份令牌
179
+ token_pool = get_token_pool()
180
+ if token_pool:
181
+ token = token_pool.get_next_token()
182
+ if token:
183
+ logger.debug(f"从token池获取令牌: {token[:20]}...")
184
+ return token
185
+
186
+ # 如果没有备份token,尝试降级到匿名模式
187
+ logger.warning("⚠️ 没有可用的备份token,尝试降级到匿名模式...")
188
+ try:
189
+ headers = get_zai_dynamic_headers()
190
+ async with httpx.AsyncClient() as client:
191
+ response = await client.get(self.auth_url, headers=headers, timeout=10.0)
192
+ if response.status_code == 200:
193
+ data = response.json()
194
+ token = data.get("token", "")
195
+ if token:
196
+ logger.info(f"✅ 降级到匿名模式成功: {token[:20]}...")
197
+ return token
198
+ except Exception as e:
199
+ logger.warning(f"降级到匿名模式失败: {e}")
200
+
201
+ # 没有可用的token
202
+ logger.error("❌ 所有认证方式都失败了")
203
+ return ""
204
+
205
+ def mark_token_success(self, token: str):
206
+ """标记token使用成功"""
207
+ token_pool = get_token_pool()
208
+ if token_pool:
209
+ token_pool.mark_token_success(token)
210
+
211
+ def mark_token_failure(self, token: str, error: Exception = None):
212
+ """标记token使用失败"""
213
+ token_pool = get_token_pool()
214
+ if token_pool:
215
+ token_pool.mark_token_failure(token, error)
216
+
217
    async def transform_request_in(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Convert an OpenAI-format request into the Z.AI upstream format.

        Steps: resolve an auth token, map the public model name to the
        upstream model id, rewrite messages (system -> user prefix, image
        parts passed through), attach MCP servers for search models, and
        build the upstream body plus request headers.

        Args:
            request: OpenAI-style request dict (``model``, ``messages``,
                optionally ``tools`` / ``reasoning``).

        Returns:
            Dict with keys ``body`` (upstream payload), ``config`` (url and
            headers) and ``token`` (the auth token used).

        Raises:
            Exception: when no auth token could be obtained.
        """
        logger.info(f"🔄 开始转换 OpenAI 请求到 Z.AI 格式: {request.get('model', settings.PRIMARY_MODEL)} -> Z.AI")

        # Obtain an auth token (guest / pool / fallback, per configuration)
        token = await self.get_token()
        logger.debug(f"   使用令牌: {token[:20] if token else 'None'}...")

        # Fail fast with actionable configuration hints when no token exists
        if not token:
            error_msg = "❌ 无法获取有效的认证令牌"
            suggestions = []

            if not settings.ANONYMOUS_MODE:
                suggestions.append("1. 设置 ANONYMOUS_MODE=true 启用匿名模式")

            if not settings.AUTH_TOKENS_FILE:
                suggestions.append("2. 配置 AUTH_TOKENS_FILE 并创建对应的token文件")
            elif settings.AUTH_TOKENS_FILE and not settings.auth_token_list:
                suggestions.append(f"3. 检查token文件 '{settings.AUTH_TOKENS_FILE}' 是否存在且包含有效token")

            if suggestions:
                error_msg += "\n建议的解决方案:\n" + "\n".join(suggestions)

            logger.error(error_msg)
            raise Exception("无法获取有效的认证令牌,请检查配置")

        # Determine the requested model's feature flags
        requested_model = request.get("model", settings.PRIMARY_MODEL)
        is_thinking = requested_model == settings.THINKING_MODEL or request.get("reasoning", False)
        is_search = requested_model == settings.SEARCH_MODEL
        is_air = requested_model == settings.AIR_MODEL

        # Resolve the upstream model id via the mapping (360B is the default)
        upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")
        logger.debug(f"   模型映射: {requested_model} -> {upstream_model_id}")

        # Rewrite the message list for the upstream API
        logger.debug(f"   开始处理 {len(request.get('messages', []))} 条消息")
        messages = []
        for idx, orig_msg in enumerate(request.get("messages", [])):
            msg = orig_msg.copy()

            # system role is not supported upstream: convert to a prefixed
            # user message so the instruction still carries weight
            if msg.get("role") == "system":

                msg["role"] = "user"
                content = msg.get("content")

                if isinstance(content, list):
                    msg["content"] = [
                        {"type": "text", "text": "This is a system command, you must enforce compliance."}
                    ] + content
                elif isinstance(content, str):
                    msg["content"] = f"This is a system command, you must enforce compliance.{content}"

            # user messages: pass image parts (base64 or http URLs) through
            elif msg.get("role") == "user":
                content = msg.get("content")
                if isinstance(content, list):
                    new_content = []
                    for part_idx, part in enumerate(content):
                        # Image URL parts are forwarded unchanged
                        if (
                            part.get("type") == "image_url"
                            and part.get("image_url", {}).get("url")
                            and isinstance(part["image_url"]["url"], str)
                        ):
                            logger.debug(f"   消息[{idx}]内容[{part_idx}]: 检测到图片URL")
                            # Forward the image part as-is
                            new_content.append(part)
                        else:
                            new_content.append(part)
                    msg["content"] = new_content

            # assistant messages with reasoning_content are kept untouched
            elif msg.get("role") == "assistant" and msg.get("reasoning_content"):

                # reasoning_content is preserved as-is
                pass

            messages.append(msg)

        # Build the MCP server list (only search models use one)
        mcp_servers = []
        if is_search:
            mcp_servers.append("deep-web-search")
            logger.info(f"🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")
        else:
            logger.debug(f"   非搜索模型,不添加 MCP 服务器")

        logger.debug(f"   MCP服务器列表: {mcp_servers}")

        # Build the upstream request body
        chat_id = generate_uuid()

        body = {
            "stream": True,  # always stream from upstream; aggregation happens downstream
            "model": upstream_model_id,  # mapped upstream model id
            "messages": messages,
            "params": {},
            "features": {
                "image_generation": False,
                "web_search": is_search,
                "auto_web_search": is_search,
                "preview_mode": False,
                "flags": [],
                "features": [],
                "enable_thinking": is_thinking,
            },
            "background_tasks": {
                "title_generation": False,
                "tags_generation": False,
            },
            "mcp_servers": mcp_servers,  # MCP server support preserved
            "variables": {
                "{{USER_NAME}}": "Guest",
                "{{USER_LOCATION}}": "Unknown",
                "{{CURRENT_DATETIME}}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "{{CURRENT_DATE}}": datetime.now().strftime("%Y-%m-%d"),
                "{{CURRENT_TIME}}": datetime.now().strftime("%H:%M:%S"),
                "{{CURRENT_WEEKDAY}}": datetime.now().strftime("%A"),
                "{{CURRENT_TIMEZONE}}": "Asia/Shanghai",  # more appropriate timezone for this service
                "{{USER_LANGUAGE}}": "zh-CN",
            },
            "model_item": {
                "id": upstream_model_id,
                "name": requested_model,
                "owned_by": "z.ai"
            },
            "chat_id": chat_id,
            "id": generate_uuid(),
        }

        # Tool support: thinking models do not take tools
        if settings.TOOL_SUPPORT and not is_thinking and request.get("tools"):
            body["tools"] = request["tools"]
            logger.info(f"启用工具支持: {len(request['tools'])} 个工具")
        else:
            body["tools"] = None

        # Build the request configuration (headers include the auth token)
        dynamic_headers = get_zai_dynamic_headers(chat_id)

        config = {
            "url": self.api_url,  # original endpoint URL
            "headers": {
                **dynamic_headers,  # dynamically generated browser headers
                "Authorization": f"Bearer {token}",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Pragma": "no-cache",
                "Sec-Fetch-Dest": "empty",
                "Sec-Fetch-Mode": "cors",
                "Sec-Fetch-Site": "same-origin",
            },
        }

        logger.info("✅ 请求转换完成")

        # Log the key request facts for debugging
        logger.debug(f"   📋 发送到Z.AI的关键信息:")
        logger.debug(f"   - 上游模型: {body['model']}")
        logger.debug(f"   - MCP服务器: {body['mcp_servers']}")
        logger.debug(f"   - web_search: {body['features']['web_search']}")
        logger.debug(f"   - auto_web_search: {body['features']['auto_web_search']}")
        logger.debug(f"   - 消息数量: {len(body['messages'])}")
        tools_count = len(body.get('tools') or [])
        logger.debug(f"   - 工具数量: {tools_count}")

        # Return the transformed request bundle
        return {
            "body": body,
            "config": config,
            "token": token
        }
397
+
398
+ async def transform_response_out(
399
+ self, response_stream: Generator, context: Dict[str, Any]
400
+ ) -> AsyncGenerator[str, None]:
401
+ """
402
+ 转换z.ai响应为OpenAI格式
403
+ 支持流式和非流式输出
404
+ """
405
+ is_stream = context.get("req", {}).get("body", {}).get("stream", True)
406
+
407
+ # 初始化结果对象(用于非流式)
408
+ result = {
409
+ "id": "",
410
+ "choices": [
411
+ {
412
+ "finish_reason": None,
413
+ "index": 0,
414
+ "message": {
415
+ "content": "",
416
+ "role": "assistant",
417
+ },
418
+ }
419
+ ],
420
+ "created": int(time.time()),
421
+ "model": context.get("req", {}).get("body", {}).get("model", ""),
422
+ "object": "chat.completion",
423
+ "usage": {
424
+ "completion_tokens": 0,
425
+ "prompt_tokens": 0,
426
+ "total_tokens": 0,
427
+ },
428
+ }
429
+
430
+ # 状态变量
431
+ current_id = ""
432
+ current_model = context.get("req", {}).get("body", {}).get("model", "")
433
+ has_tool_call = False
434
+ tool_args = ""
435
+ tool_id = ""
436
+ tool_call_usage = None
437
+ content_index = 0
438
+ has_thinking = False
439
+
440
+ async for line in response_stream:
441
+ if not line.strip():
442
+ continue
443
+
444
+ if line.startswith("data:"):
445
+ chunk_str = line[5:].strip()
446
+ if not chunk_str:
447
+ continue
448
+
449
+ try:
450
+ chunk = json.loads(chunk_str)
451
+
452
+ if chunk.get("type") == "chat:completion":
453
+ data = chunk.get("data", {})
454
+
455
+ # 保存ID和模型信息
456
+ if data.get("id"):
457
+ current_id = data["id"]
458
+ if data.get("model"):
459
+ current_model = data["model"]
460
+
461
+ # 处理不同阶段
462
+ phase = data.get("phase")
463
+
464
+ if phase == "tool_call":
465
+ # 处理工具调用
466
+ if not has_tool_call:
467
+ has_tool_call = True
468
+
469
+ if is_stream:
470
+ # 发送初始角色
471
+ role_chunk = {
472
+ "choices": [
473
+ {
474
+ "delta": {"role": "assistant"},
475
+ "finish_reason": None,
476
+ "index": 0,
477
+ }
478
+ ],
479
+ "created": int(time.time()),
480
+ "id": current_id,
481
+ "model": current_model,
482
+ "object": "chat.completion.chunk",
483
+ }
484
+ yield f"data: {json.dumps(role_chunk)}\n\n"
485
+
486
+ # 处理工具调用块
487
+ tool_call_id = data.get("tool_call", {}).get("id", "")
488
+ tool_name = data.get("tool_call", {}).get("name", "")
489
+ delta_args = data.get("delta_tool_call", {}).get("arguments", "")
490
+
491
+ if tool_call_id and tool_call_id != tool_id:
492
+ # 新工具调用
493
+ if tool_id and is_stream:
494
+ # 关闭前一个工具调用
495
+ close_chunk = {
496
+ "choices": [
497
+ {
498
+ "delta": {
499
+ "tool_calls": [
500
+ {"index": content_index, "function": {"arguments": ""}}
501
+ ]
502
+ },
503
+ "finish_reason": None,
504
+ "index": 0,
505
+ }
506
+ ],
507
+ "created": int(time.time()),
508
+ "id": current_id,
509
+ "model": current_model,
510
+ "object": "chat.completion.chunk",
511
+ }
512
+ yield f"data: {json.dumps(close_chunk)}\n\n"
513
+ content_index += 1
514
+
515
+ tool_id = tool_call_id
516
+ tool_args = ""
517
+
518
+ if is_stream:
519
+ # 发送新工具调用
520
+ new_tool_chunk = {
521
+ "choices": [
522
+ {
523
+ "delta": {
524
+ "tool_calls": [
525
+ {
526
+ "index": content_index,
527
+ "id": tool_call_id,
528
+ "type": "function",
529
+ "function": {"name": tool_name, "arguments": ""},
530
+ }
531
+ ]
532
+ },
533
+ "finish_reason": None,
534
+ "index": 0,
535
+ }
536
+ ],
537
+ "created": int(time.time()),
538
+ "id": current_id,
539
+ "model": current_model,
540
+ "object": "chat.completion.chunk",
541
+ }
542
+ yield f"data: {json.dumps(new_tool_chunk)}\n\n"
543
+
544
+ # 处理参数增量
545
+ if delta_args:
546
+ tool_args += delta_args
547
+ if is_stream:
548
+ args_chunk = {
549
+ "choices": [
550
+ {
551
+ "delta": {
552
+ "tool_calls": [
553
+ {
554
+ "index": content_index,
555
+ "function": {"arguments": delta_args},
556
+ }
557
+ ]
558
+ },
559
+ "finish_reason": None,
560
+ "index": 0,
561
+ }
562
+ ],
563
+ "created": int(time.time()),
564
+ "id": current_id,
565
+ "model": current_model,
566
+ "object": "chat.completion.chunk",
567
+ }
568
+ yield f"data: {json.dumps(args_chunk)}\n\n"
569
+
570
+ elif phase == "thinking":
571
+ # 处理思考内容
572
+ if not has_thinking:
573
+ has_thinking = True
574
+ # 初始化thinking字段
575
+ if not is_stream:
576
+ result["choices"][0]["message"]["thinking"] = {"content": ""}
577
+
578
+ if is_stream:
579
+ # 发送初始角色
580
+ role_chunk = {
581
+ "choices": [
582
+ {
583
+ "delta": {"role": "assistant"},
584
+ "finish_reason": None,
585
+ "index": 0,
586
+ }
587
+ ],
588
+ "created": int(time.time()),
589
+ "id": current_id,
590
+ "model": current_model,
591
+ "object": "chat.completion.chunk",
592
+ }
593
+ yield f"data: {json.dumps(role_chunk)}\n\n"
594
+
595
+ delta_content = data.get("delta_content", "")
596
+ if delta_content:
597
+ # 处理思考内容格式
598
+ if delta_content.startswith("<details"):
599
+ content = (
600
+ delta_content.split("</summary>\n>")[-1].strip()
601
+ if "</summary>\n>" in delta_content
602
+ else delta_content
603
+ )
604
+ else:
605
+ content = delta_content
606
+
607
+ if is_stream:
608
+ thinking_chunk = {
609
+ "choices": [
610
+ {
611
+ "delta": {"thinking": {"content": content}},
612
+ "finish_reason": None,
613
+ "index": 0,
614
+ }
615
+ ],
616
+ "created": int(time.time()),
617
+ "id": current_id,
618
+ "model": current_model,
619
+ "object": "chat.completion.chunk",
620
+ }
621
+ yield f"data: {json.dumps(thinking_chunk)}\n\n"
622
+ else:
623
+ result["choices"][0]["message"]["thinking"]["content"] += content
624
+
625
+ elif phase == "answer":
626
+ # 处理答案内容
627
+ edit_content = data.get("edit_content", "")
628
+ delta_content = data.get("delta_content", "")
629
+
630
+ # 处理思考结束和答案开始
631
+ if edit_content and "</details>\n" in edit_content:
632
+ if has_thinking:
633
+ signature = str(int(time.time() * 1000))
634
+
635
+ if is_stream:
636
+ # 发送思考签名
637
+ sig_chunk = {
638
+ "choices": [
639
+ {
640
+ "delta": {
641
+ "role": "assistant",
642
+ "thinking": {"content": "", "signature": signature},
643
+ },
644
+ "finish_reason": None,
645
+ "index": 0,
646
+ }
647
+ ],
648
+ "created": int(time.time()),
649
+ "id": current_id,
650
+ "model": current_model,
651
+ "object": "chat.completion.chunk",
652
+ }
653
+ yield f"data: {json.dumps(sig_chunk)}\n\n"
654
+ content_index += 1
655
+ else:
656
+ result["choices"][0]["message"]["thinking"]["signature"] = signature
657
+
658
+ # 提取答案内容
659
+ content_after = edit_content.split("</details>\n")[-1]
660
+ if content_after:
661
+ if is_stream:
662
+ content_chunk = {
663
+ "choices": [
664
+ {
665
+ "delta": {"role": "assistant", "content": content_after},
666
+ "finish_reason": None,
667
+ "index": 0,
668
+ }
669
+ ],
670
+ "created": int(time.time()),
671
+ "id": current_id,
672
+ "model": current_model,
673
+ "object": "chat.completion.chunk",
674
+ }
675
+ yield f"data: {json.dumps(content_chunk)}\n\n"
676
+ else:
677
+ result["choices"][0]["message"]["content"] += content_after
678
+
679
+ # 处理增量内容
680
+ elif delta_content:
681
+ if is_stream:
682
+ # 如果还没有发送角色
683
+ if not has_thinking and not has_tool_call:
684
+ role_chunk = {
685
+ "choices": [
686
+ {
687
+ "delta": {"role": "assistant"},
688
+ "finish_reason": None,
689
+ "index": 0,
690
+ }
691
+ ],
692
+ "created": int(time.time()),
693
+ "id": current_id,
694
+ "model": current_model,
695
+ "object": "chat.completion.chunk",
696
+ }
697
+ yield f"data: {json.dumps(role_chunk)}\n\n"
698
+
699
+ content_chunk = {
700
+ "choices": [
701
+ {
702
+ "delta": {"role": "assistant", "content": delta_content},
703
+ "finish_reason": None,
704
+ "index": 0,
705
+ }
706
+ ],
707
+ "created": int(time.time()),
708
+ "id": current_id,
709
+ "model": current_model,
710
+ "object": "chat.completion.chunk",
711
+ }
712
+ yield f"data: {json.dumps(content_chunk)}\n\n"
713
+ else:
714
+ result["choices"][0]["message"]["content"] += delta_content
715
+
716
+ # 处理完成
717
+ if data.get("usage"):
718
+ usage = data["usage"]
719
+ if is_stream:
720
+ finish_chunk = {
721
+ "choices": [
722
+ {
723
+ "delta": {"role": "assistant", "content": ""},
724
+ "finish_reason": "stop",
725
+ "index": 0,
726
+ }
727
+ ],
728
+ "usage": usage,
729
+ "created": int(time.time()),
730
+ "id": current_id,
731
+ "model": current_model,
732
+ "object": "chat.completion.chunk",
733
+ }
734
+ yield f"data: {json.dumps(finish_chunk)}\n\n"
735
+ yield "data: [DONE]\n\n"
736
+ else:
737
+ result["id"] = current_id
738
+ result["model"] = current_model
739
+ result["usage"] = usage
740
+ result["choices"][0]["finish_reason"] = "stop"
741
+
742
+ elif phase == "other":
743
+ # 处理其他阶段(可能包含usage信息)
744
+ if data.get("usage"):
745
+ tool_call_usage = data["usage"]
746
+ if has_tool_call and is_stream:
747
+ # 关闭最后一个工具调用并发送完成
748
+ if tool_id:
749
+ close_chunk = {
750
+ "choices": [
751
+ {
752
+ "delta": {
753
+ "tool_calls": [
754
+ {"index": content_index, "function": {"arguments": ""}}
755
+ ]
756
+ },
757
+ "finish_reason": "tool_calls",
758
+ "index": 0,
759
+ }
760
+ ],
761
+ "usage": tool_call_usage,
762
+ "created": int(time.time()),
763
+ "id": current_id,
764
+ "model": current_model,
765
+ "object": "chat.completion.chunk",
766
+ }
767
+ yield f"data: {json.dumps(close_chunk)}\n\n"
768
+ yield "data: [DONE]\n\n"
769
+
770
+ except json.JSONDecodeError as e:
771
+ logger.debug(f"JSON解析错误: {e}")
772
+ except Exception as e:
773
+ logger.error(f"处理chunk错误: {e}")
774
+
775
+ # 非流式模式返回完整结果
776
+ if not is_stream:
777
+ yield json.dumps(result)
app/models/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from app.models import schemas
5
+
6
+ __all__ = ["schemas"]
app/models/schemas.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from typing import Dict, List, Optional, Any, Union, Literal
5
+ from pydantic import BaseModel
6
+
7
+
8
+ class ContentPart(BaseModel):
9
+ """Content part model for OpenAI's new content format"""
10
+
11
+ type: str
12
+ text: Optional[str] = None
13
+
14
+
15
+ class Message(BaseModel):
16
+ """Chat message model"""
17
+
18
+ role: str
19
+ content: Optional[Union[str, List[ContentPart]]] = None
20
+ reasoning_content: Optional[str] = None
21
+ tool_calls: Optional[List[Dict[str, Any]]] = None
22
+
23
+
24
+ class OpenAIRequest(BaseModel):
25
+ """OpenAI-compatible request model"""
26
+
27
+ model: str
28
+ messages: List[Message]
29
+ stream: Optional[bool] = False
30
+ temperature: Optional[float] = None
31
+ max_tokens: Optional[int] = None
32
+ tools: Optional[List[Dict[str, Any]]] = None
33
+ tool_choice: Optional[Any] = None
34
+
35
+
36
+ class ModelItem(BaseModel):
37
+ """Model information item"""
38
+
39
+ id: str
40
+ name: str
41
+ owned_by: str
42
+
43
+
44
+ class UpstreamRequest(BaseModel):
45
+ """Upstream service request model"""
46
+
47
+ stream: bool
48
+ model: str
49
+ messages: List[Message]
50
+ params: Dict[str, Any] = {}
51
+ features: Dict[str, Any] = {}
52
+ background_tasks: Optional[Dict[str, bool]] = None
53
+ chat_id: Optional[str] = None
54
+ id: Optional[str] = None
55
+ mcp_servers: Optional[List[str]] = None
56
+ model_item: Optional[Dict[str, Any]] = {} # Model item dictionary
57
+ tools: Optional[List[Dict[str, Any]]] = None # Add tools field for OpenAI compatibility
58
+ variables: Optional[Dict[str, str]] = None
59
+ model_config = {"protected_namespaces": ()}
60
+
61
+
62
+ class Delta(BaseModel):
63
+ """Stream delta model"""
64
+
65
+ role: Optional[str] = None
66
+ content: Optional[str] = "" or None
67
+ reasoning_content: Optional[str] = None
68
+ tool_calls: Optional[List[Dict[str, Any]]] = None
69
+
70
+
71
+ class Choice(BaseModel):
72
+ """Response choice model"""
73
+
74
+ index: int
75
+ message: Optional[Message] = None
76
+ delta: Optional[Delta] = None
77
+ finish_reason: Optional[str] = None
78
+
79
+
80
+ class Usage(BaseModel):
81
+ """Token usage statistics"""
82
+
83
+ prompt_tokens: int = 0
84
+ completion_tokens: int = 0
85
+ total_tokens: int = 0
86
+
87
+
88
+ class OpenAIResponse(BaseModel):
89
+ """OpenAI-compatible response model"""
90
+
91
+ id: str
92
+ object: str
93
+ created: int
94
+ model: str
95
+ choices: List[Choice]
96
+ usage: Optional[Usage] = None
97
+
98
+
99
+ class UpstreamError(BaseModel):
100
+ """Upstream error model"""
101
+
102
+ detail: str
103
+ code: int
104
+
105
+
106
+ class UpstreamDataInner(BaseModel):
107
+ """Inner upstream data model"""
108
+
109
+ error: Optional[UpstreamError] = None
110
+
111
+
112
+ class UpstreamDataData(BaseModel):
113
+ """Upstream data content model"""
114
+
115
+ delta_content: str = ""
116
+ edit_content: str = ""
117
+ phase: str = ""
118
+ done: bool = False
119
+ usage: Optional[Usage] = None
120
+ error: Optional[UpstreamError] = None
121
+ inner: Optional[UpstreamDataInner] = None
122
+
123
+
124
+ class UpstreamData(BaseModel):
125
+ """Upstream data model"""
126
+
127
+ type: str
128
+ data: UpstreamDataData
129
+ error: Optional[UpstreamError] = None
130
+
131
+
132
+ class Model(BaseModel):
133
+ """Model information for listing"""
134
+
135
+ id: str
136
+ object: str = "model"
137
+ created: int
138
+ owned_by: str
139
+
140
+
141
+ class ModelsResponse(BaseModel):
142
+ """Models list response model"""
143
+
144
+ object: str = "list"
145
+ data: List[Model]
app/providers/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 多提供商架构包
6
+ 提供统一的提供商接口和路由机制
7
+ """
8
+
9
+ from app.providers.base import BaseProvider, ProviderConfig, provider_registry
10
+ from app.providers.zai_provider import ZAIProvider
11
+ from app.providers.k2think_provider import K2ThinkProvider
12
+ from app.providers.longcat_provider import LongCatProvider
13
+ from app.providers.provider_factory import ProviderFactory, ProviderRouter, get_provider_router, initialize_providers
14
+
15
+ __all__ = [
16
+ "BaseProvider",
17
+ "ProviderConfig",
18
+ "provider_registry",
19
+ "ZAIProvider",
20
+ "K2ThinkProvider",
21
+ "LongCatProvider",
22
+ "ProviderFactory",
23
+ "ProviderRouter",
24
+ "get_provider_router",
25
+ "initialize_providers"
26
+ ]
app/providers/base.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 基础提供商抽象层
6
+ 定义统一的提供商接口规范
7
+ """
8
+
9
+ import json
10
+ import time
11
+ import uuid
12
+ from abc import ABC, abstractmethod
13
+ from typing import Dict, List, Any, Optional, AsyncGenerator, Union
14
+ from dataclasses import dataclass
15
+
16
+ from app.models.schemas import OpenAIRequest, Message
17
+ from app.utils.logger import get_logger
18
+
19
+ logger = get_logger()
20
+
21
+
22
+ @dataclass
23
+ class ProviderConfig:
24
+ """提供商配置"""
25
+ name: str
26
+ api_endpoint: str
27
+ timeout: int = 30
28
+ headers: Optional[Dict[str, str]] = None
29
+ extra_config: Optional[Dict[str, Any]] = None
30
+
31
+
32
+ @dataclass
33
+ class ProviderResponse:
34
+ """提供商响应"""
35
+ success: bool
36
+ content: str = ""
37
+ error: Optional[str] = None
38
+ usage: Optional[Dict[str, int]] = None
39
+ extra_data: Optional[Dict[str, Any]] = None
40
+
41
+
42
+ class BaseProvider(ABC):
43
+ """基础提供商抽象类"""
44
+
45
+ def __init__(self, config: ProviderConfig):
46
+ """初始化提供商"""
47
+ self.config = config
48
+ self.name = config.name
49
+ self.logger = get_logger()
50
+
51
+ @abstractmethod
52
+ async def chat_completion(
53
+ self,
54
+ request: OpenAIRequest,
55
+ **kwargs
56
+ ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
57
+ """
58
+ 聊天完成接口
59
+
60
+ Args:
61
+ request: OpenAI格式的请求
62
+ **kwargs: 额外参数
63
+
64
+ Returns:
65
+ 非流式: Dict[str, Any] - OpenAI格式的响应
66
+ 流式: AsyncGenerator[str, None] - SSE格式的流式响应
67
+ """
68
+ pass
69
+
70
+ @abstractmethod
71
+ async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
72
+ """
73
+ 转换OpenAI请求为提供商特定格式
74
+
75
+ Args:
76
+ request: OpenAI格式的请求
77
+
78
+ Returns:
79
+ Dict[str, Any]: 提供商特定格式的请求
80
+ """
81
+ pass
82
+
83
+ @abstractmethod
84
+ async def transform_response(
85
+ self,
86
+ response: Any,
87
+ request: OpenAIRequest
88
+ ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
89
+ """
90
+ 转换提供商响应为OpenAI格式
91
+
92
+ Args:
93
+ response: 提供商的原始响应
94
+ request: 原始请求(用于构造响应)
95
+
96
+ Returns:
97
+ Union[Dict[str, Any], AsyncGenerator[str, None]]: OpenAI格式的响应
98
+ """
99
+ pass
100
+
101
+ def get_supported_models(self) -> List[str]:
102
+ """获取支持的模型列表"""
103
+ return []
104
+
105
+ def create_chat_id(self) -> str:
106
+ """生成聊天ID"""
107
+ return f"chatcmpl-{uuid.uuid4().hex}"
108
+
109
+ def create_openai_chunk(
110
+ self,
111
+ chat_id: str,
112
+ model: str,
113
+ delta: Dict[str, Any],
114
+ finish_reason: Optional[str] = None
115
+ ) -> Dict[str, Any]:
116
+ """创建OpenAI格式的流式响应块"""
117
+ return {
118
+ "id": chat_id,
119
+ "object": "chat.completion.chunk",
120
+ "created": int(time.time()),
121
+ "model": model,
122
+ "choices": [{
123
+ "index": 0,
124
+ "delta": delta,
125
+ "finish_reason": finish_reason,
126
+ "logprobs": None,
127
+ }],
128
+ "system_fingerprint": f"fp_{self.name}_001",
129
+ }
130
+
131
+ def create_openai_response(
132
+ self,
133
+ chat_id: str,
134
+ model: str,
135
+ content: str,
136
+ usage: Optional[Dict[str, int]] = None
137
+ ) -> Dict[str, Any]:
138
+ """创建OpenAI格式的非流式响应"""
139
+ return {
140
+ "id": chat_id,
141
+ "object": "chat.completion",
142
+ "created": int(time.time()),
143
+ "model": model,
144
+ "choices": [{
145
+ "index": 0,
146
+ "message": {
147
+ "role": "assistant",
148
+ "content": content
149
+ },
150
+ "finish_reason": "stop",
151
+ "logprobs": None,
152
+ }],
153
+ "usage": usage or {
154
+ "prompt_tokens": 0,
155
+ "completion_tokens": 0,
156
+ "total_tokens": 0
157
+ },
158
+ "system_fingerprint": f"fp_{self.name}_001",
159
+ }
160
+
161
+ def create_openai_response_with_reasoning(
162
+ self,
163
+ chat_id: str,
164
+ model: str,
165
+ content: str,
166
+ reasoning_content: str = None,
167
+ usage: Optional[Dict[str, int]] = None
168
+ ) -> Dict[str, Any]:
169
+ """创建包含推理内容的OpenAI格式非流式响应"""
170
+ message = {
171
+ "role": "assistant",
172
+ "content": content
173
+ }
174
+
175
+ # 只有当推理内容存在且不为空时才添加
176
+ if reasoning_content and reasoning_content.strip():
177
+ message["reasoning_content"] = reasoning_content
178
+
179
+ return {
180
+ "id": chat_id,
181
+ "object": "chat.completion",
182
+ "created": int(time.time()),
183
+ "model": model,
184
+ "choices": [{
185
+ "index": 0,
186
+ "message": message,
187
+ "finish_reason": "stop",
188
+ "logprobs": None,
189
+ }],
190
+ "usage": usage or {
191
+ "prompt_tokens": 0,
192
+ "completion_tokens": 0,
193
+ "total_tokens": 0
194
+ },
195
+ "system_fingerprint": f"fp_{self.name}_001",
196
+ }
197
+
198
+ async def format_sse_chunk(self, chunk: Dict[str, Any]) -> str:
199
+ """格式化SSE响应块"""
200
+ return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
201
+
202
+ async def format_sse_done(self) -> str:
203
+ """格式化SSE结束标记"""
204
+ return "data: [DONE]\n\n"
205
+
206
+ def log_request(self, request: OpenAIRequest):
207
+ """记录请求日志"""
208
+ self.logger.info(f"🔄 {self.name} 处理请求: {request.model}")
209
+ self.logger.debug(f" 消息数量: {len(request.messages)}")
210
+ self.logger.debug(f" 流式模式: {request.stream}")
211
+
212
+ def log_response(self, success: bool, error: Optional[str] = None):
213
+ """记录响应日志"""
214
+ if success:
215
+ self.logger.info(f"✅ {self.name} 响应成功")
216
+ else:
217
+ self.logger.error(f"❌ {self.name} 响应失败: {error}")
218
+
219
+ def handle_error(self, error: Exception, context: str = "") -> Dict[str, Any]:
220
+ """统一错误处理"""
221
+ error_msg = f"{self.name} {context} 错误: {str(error)}"
222
+ self.logger.error(error_msg)
223
+
224
+ return {
225
+ "error": {
226
+ "message": error_msg,
227
+ "type": "provider_error",
228
+ "code": "internal_error"
229
+ }
230
+ }
231
+
232
+
233
+ class ProviderRegistry:
234
+ """提供商注册表"""
235
+
236
+ def __init__(self):
237
+ self._providers: Dict[str, BaseProvider] = {}
238
+ self._model_mapping: Dict[str, str] = {}
239
+
240
+ def register(self, provider: BaseProvider, models: List[str]):
241
+ """注册提供商"""
242
+ self._providers[provider.name] = provider
243
+ for model in models:
244
+ self._model_mapping[model] = provider.name
245
+ logger.info(f"📝 注册提供商: {provider.name}, 模型: {models}")
246
+
247
+ def get_provider(self, model: str) -> Optional[BaseProvider]:
248
+ """根据模型获取提供商"""
249
+ provider_name = self._model_mapping.get(model)
250
+ if provider_name:
251
+ return self._providers.get(provider_name)
252
+ return None
253
+
254
+ def get_provider_by_name(self, name: str) -> Optional[BaseProvider]:
255
+ """根据名称获取提供商"""
256
+ return self._providers.get(name)
257
+
258
+ def list_models(self) -> List[str]:
259
+ """列出所有支持的模型"""
260
+ return list(self._model_mapping.keys())
261
+
262
+ def list_providers(self) -> List[str]:
263
+ """列出所有提供商"""
264
+ return list(self._providers.keys())
265
+
266
+
267
+ # 全局提供商注册表
268
+ provider_registry = ProviderRegistry()
app/providers/k2think_provider.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ K2Think 提供商适配器
6
+ """
7
+
8
+ import json
9
+ import re
10
+ import time
11
+ import uuid
12
+ import httpx
13
+ from typing import Dict, List, Any, Optional, AsyncGenerator, Union
14
+
15
+ from app.providers.base import BaseProvider, ProviderConfig
16
+ from app.models.schemas import OpenAIRequest, Message
17
+ from app.utils.logger import get_logger
18
+
19
+ logger = get_logger()
20
+
21
+
22
+ class K2ThinkProvider(BaseProvider):
23
+ """K2Think 提供商"""
24
+
25
+ def __init__(self):
26
+ config = ProviderConfig(
27
+ name="k2think",
28
+ api_endpoint="https://www.k2think.ai/api/guest/chat/completions",
29
+ timeout=30,
30
+ headers={
31
+ 'Accept': 'text/event-stream',
32
+ 'Accept-Encoding': 'gzip, deflate, br, zstd',
33
+ 'Accept-Language': 'en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7',
34
+ 'Content-Type': 'application/json',
35
+ 'Origin': 'https://www.k2think.ai',
36
+ 'Pragma': 'no-cache',
37
+ 'Referer': 'https://www.k2think.ai/guest',
38
+ 'Sec-Ch-Ua': '"Chromium";v="124", "Google Chrome";v="124", "Not-A.Brand";v="99"',
39
+ 'Sec-Ch-Ua-Mobile': '?0',
40
+ 'Sec-Ch-Ua-Platform': '"macOS"',
41
+ 'Sec-Fetch-Dest': 'empty',
42
+ 'Sec-Fetch-Mode': 'cors',
43
+ 'Sec-Fetch-Site': 'same-origin',
44
+ 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
45
+ }
46
+ )
47
+ super().__init__(config)
48
+
49
+ # K2Think 特定配置
50
+ self.handshake_url = "https://www.k2think.ai/guest"
51
+ self.new_chat_url = "https://www.k2think.ai/api/v1/chats/guest/new"
52
+
53
+ # 内容解析正则表达式 - 使用DOTALL标志确保.匹配换行符
54
+ self.reasoning_pattern = re.compile(r'<details type="reasoning"[^>]*>.*?<summary>.*?</summary>(.*?)</details>', re.DOTALL)
55
+ self.answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
56
+
57
+ def get_supported_models(self) -> List[str]:
58
+ """获取支持的模型列表"""
59
+ return ["MBZUAI-IFM/K2-Think"]
60
+
61
+ def parse_cookies(self, headers) -> str:
62
+ """解析Cookie"""
63
+ cookies = []
64
+ for key, value in headers.items():
65
+ if key.lower() == 'set-cookie':
66
+ cookies.append(value.split(';')[0])
67
+ return '; '.join(cookies)
68
+
69
+ def extract_reasoning_and_answer(self, content: str) -> tuple[str, str]:
70
+ """提取推理内容和答案内容"""
71
+ if not content:
72
+ return "", ""
73
+
74
+ try:
75
+ reasoning_match = self.reasoning_pattern.search(content)
76
+ reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
77
+
78
+ answer_match = self.answer_pattern.search(content)
79
+ answer = answer_match.group(1).strip() if answer_match else ""
80
+
81
+ return reasoning, answer
82
+ except Exception as e:
83
+ self.logger.error(f"提取K2内容错误: {e}")
84
+ return "", ""
85
+
86
+ def calculate_delta(self, previous: str, current: str) -> str:
87
+ """计算内容增量"""
88
+ if not previous:
89
+ return current
90
+ if not current or len(current) < len(previous):
91
+ return ""
92
+ return current[len(previous):]
93
+
94
+ def parse_api_response(self, obj: Any) -> tuple[str, bool]:
95
+ """解析API响应"""
96
+ if not obj or not isinstance(obj, dict):
97
+ return "", False
98
+
99
+ if obj.get("done") is True:
100
+ return "", True
101
+
102
+ choices = obj.get("choices", [])
103
+ if choices and len(choices) > 0:
104
+ delta = choices[0].get("delta", {})
105
+ return delta.get("content", ""), False
106
+
107
+ content = obj.get("content")
108
+ if isinstance(content, str):
109
+ return content, False
110
+
111
+ return "", False
112
+
113
+ async def get_k2_auth_data(self, request: OpenAIRequest) -> Dict[str, Any]:
114
+ """获取K2Think认证数据"""
115
+ # 1. 握手请求 - 使用更简单的Accept-Encoding来避免Brotli问题
116
+ headers_for_handshake = {**self.config.headers}
117
+ headers_for_handshake['Accept-Encoding'] = 'gzip, deflate' # 移除br和zstd
118
+
119
+ async with httpx.AsyncClient() as client:
120
+ handshake_response = await client.get(
121
+ self.handshake_url,
122
+ headers=headers_for_handshake,
123
+ follow_redirects=True
124
+ )
125
+ if not handshake_response.is_success:
126
+ try:
127
+ # 使用httpx的text属性,它会自动处理解压缩和编码
128
+ error_text = handshake_response.text
129
+ raise Exception(f"K2 握手失败: {handshake_response.status_code} {error_text[:200]}")
130
+ except Exception as e:
131
+ raise Exception(f"K2 握手失败: {handshake_response.status_code}")
132
+
133
+ initial_cookies = self.parse_cookies(handshake_response.headers)
134
+
135
+ # 2. 准备消息
136
+ prepared_messages = self.prepare_k2_messages(request.messages)
137
+ first_user_message = next((m for m in prepared_messages if m["role"] == "user"), None)
138
+ if not first_user_message:
139
+ raise Exception("没有找到用户消息来初始化对话")
140
+
141
+ # 3. 创建新对话
142
+ message_id = str(uuid.uuid4())
143
+ now = int(time.time() * 1000)
144
+ model_id = request.model or "MBZUAI-IFM/K2-Think"
145
+
146
+ new_chat_payload = {
147
+ "chat": {
148
+ "id": "",
149
+ "title": "Guest Chat",
150
+ "models": [model_id],
151
+ "params": {},
152
+ "history": {
153
+ "messages": {
154
+ message_id: {
155
+ "id": message_id,
156
+ "parentId": None,
157
+ "childrenIds": [],
158
+ "role": "user",
159
+ "content": first_user_message["content"],
160
+ "timestamp": now // 1000,
161
+ "models": [model_id]
162
+ }
163
+ },
164
+ "currentId": message_id
165
+ },
166
+ "messages": [{
167
+ "id": message_id,
168
+ "parentId": None,
169
+ "childrenIds": [],
170
+ "role": "user",
171
+ "content": first_user_message["content"],
172
+ "timestamp": now // 1000,
173
+ "models": [model_id]
174
+ }],
175
+ "tags": [],
176
+ "timestamp": now
177
+ }
178
+ }
179
+
180
+ headers_with_cookies = {**self.config.headers, 'Cookie': initial_cookies}
181
+ headers_with_cookies['Accept-Encoding'] = 'gzip, deflate' # 移除br和zstd
182
+
183
+ async with httpx.AsyncClient() as client:
184
+ new_chat_response = await client.post(
185
+ self.new_chat_url,
186
+ headers=headers_with_cookies,
187
+ json=new_chat_payload,
188
+ follow_redirects=True
189
+ )
190
+ if not new_chat_response.is_success:
191
+ try:
192
+ # 使用httpx的text属性,它会自动处理解压缩和编码
193
+ error_text = new_chat_response.text
194
+ except Exception:
195
+ error_text = f"Status: {new_chat_response.status_code}"
196
+ raise Exception(f"K2 新对话创建失败: {new_chat_response.status_code} {error_text[:200]}")
197
+
198
+ try:
199
+ new_chat_data = new_chat_response.json()
200
+ except Exception as e:
201
+ # 如果JSON解析失败,尝试获取原始内容
202
+ try:
203
+ # 使用httpx的text属性,它会自动处理解压缩和编码
204
+ content_str = new_chat_response.text
205
+ self.logger.debug(f"K2 响应原始内容: {content_str[:500]}")
206
+ raise Exception(f"K2 响应JSON解析失败: {e}, 原始内容: {content_str[:200]}")
207
+ except Exception as decode_error:
208
+ # 如果text也失败,尝试手动处理
209
+ try:
210
+ raw_bytes = new_chat_response.content
211
+ content_str = raw_bytes.decode('utf-8', errors='replace')
212
+ raise Exception(f"K2 响应解析失败: {e}, 手动解码内容: {content_str[:200]}")
213
+ except Exception:
214
+ raise Exception(f"K2 响应解析完全失败: {e}, 解码错误: {decode_error}")
215
+ conversation_id = new_chat_data.get("id")
216
+ if not conversation_id:
217
+ raise Exception("无法从K2 /new端点获取conversation_id")
218
+
219
+ chat_specific_cookies = self.parse_cookies(new_chat_response.headers)
220
+
221
+ # 4. 组合最终Cookie
222
+ base_cookies = [initial_cookies, chat_specific_cookies]
223
+ base_cookies = [c for c in base_cookies if c]
224
+ final_cookie = '; '.join(base_cookies) + '; guest_conversation_count=1'
225
+
226
+ # 5. 构建最终请求载荷
227
+ final_payload = {
228
+ "stream": True,
229
+ "model": model_id,
230
+ "messages": prepared_messages,
231
+ "conversation_id": conversation_id,
232
+ "params": {}
233
+ }
234
+
235
+ # 添加可选参数
236
+ if request.temperature is not None:
237
+ final_payload["params"]["temperature"] = request.temperature
238
+ if request.max_tokens is not None:
239
+ final_payload["params"]["max_tokens"] = request.max_tokens
240
+
241
+ final_headers = {**self.config.headers, 'Cookie': final_cookie}
242
+
243
+ return {
244
+ "payload": final_payload,
245
+ "headers": final_headers
246
+ }
247
+
248
def prepare_k2_messages(self, messages: "List[Message]") -> "List[Dict[str, Any]]":
    """Convert OpenAI-style messages into K2Think's expected format.

    K2Think has no dedicated system role, so system messages are collected
    and merged into the first user message (or prepended as a new user
    message when no user message exists).
    """
    converted = []
    system_text = ""

    for message in messages:
        if message.role == "system":
            # Accumulate system prompts, separated by blank lines.
            system_text = f"{system_text}\n\n{message.content}" if system_text else message.content
            continue

        body = message.content
        if isinstance(body, list):
            # Multimodal payload: keep only the textual parts.
            body = "\n".join(
                part.text for part in body if hasattr(part, "text") and part.text
            )

        converted.append({"role": message.role, "content": body})

    if system_text:
        user_idx = next((i for i, m in enumerate(converted) if m["role"] == "user"), -1)
        if user_idx >= 0:
            converted[user_idx]["content"] = f"{system_text}\n\n{converted[user_idx]['content']}"
        else:
            converted.insert(0, {"role": "user", "content": system_text})

    return converted
277
+
278
async def _handle_stream_request(
    self,
    transformed: Dict[str, Any],
    request: OpenAIRequest
) -> AsyncGenerator[str, None]:
    """Stream a K2Think completion, re-emitting it as OpenAI SSE chunks.

    The upstream endpoint sends cumulative snapshots, so each event is
    diffed against the previous one to produce incremental deltas, first
    for the reasoning phase and then for the answer phase.
    """
    stream_id = self.create_chat_id()
    model_name = transformed["model"]

    # Restrict compression to codecs httpx can always decode.
    request_headers = dict(transformed["headers"])
    request_headers['Accept-Encoding'] = 'gzip, deflate'

    self.logger.info(f"🌊 开始K2Think流式请求")

    async with httpx.AsyncClient(timeout=30.0) as client:
        async with client.stream(
            "POST",
            transformed["url"],
            headers=request_headers,
            json=transformed["payload"]
        ) as response:
            if not response.is_success:
                error_msg = f"K2Think API 错误: {response.status_code}"
                self.log_response(False, error_msg)
                # Streaming callers receive the failure as an SSE error event.
                yield await self.format_sse_chunk({
                    "error": {
                        "message": error_msg,
                        "type": "provider_error",
                        "code": "api_error"
                    }
                })
                return

            # Announce the assistant role before any content.
            yield await self.format_sse_chunk(
                self.create_openai_chunk(stream_id, model_name, {"role": "assistant"})
            )

            snapshot = ""
            last_reasoning = ""
            last_answer = ""
            in_reasoning = True
            events_seen = 0

            try:
                async for line in response.aiter_lines():
                    events_seen += 1
                    self.logger.debug(f"📦 收到数据块 #{events_seen}: {line[:100]}...")

                    if not line.startswith("data:"):
                        continue

                    data_str = line[5:].strip()
                    if self._is_end_marker(data_str):
                        self.logger.debug(f"🏁 检测到结束标记: {data_str}")
                        continue

                    content = self._parse_data_string(data_str)
                    if not content:
                        continue

                    snapshot = content
                    cur_reasoning, cur_answer = self.extract_reasoning_and_answer(snapshot)

                    # Reasoning phase: emit the new suffix of the reasoning text.
                    if in_reasoning and cur_reasoning:
                        delta = self.calculate_delta(last_reasoning, cur_reasoning)
                        if delta.strip():
                            self.logger.debug(f"🧠 推理增量: {delta[:50]}...")
                            yield await self.format_sse_chunk(
                                self.create_openai_chunk(stream_id, model_name, {"reasoning_content": delta})
                            )
                        last_reasoning = cur_reasoning

                    # First answer text flips us out of the reasoning phase.
                    if cur_answer and in_reasoning:
                        in_reasoning = False
                        self.logger.debug("🔄 切换到答案阶段")
                        # Flush any reasoning tail not yet emitted.
                        tail = self.calculate_delta(last_reasoning, cur_reasoning)
                        if tail.strip():
                            yield await self.format_sse_chunk(
                                self.create_openai_chunk(stream_id, model_name, {"reasoning_content": tail})
                            )

                    # Answer phase: emit the new suffix of the answer text.
                    if not in_reasoning and cur_answer:
                        delta = self.calculate_delta(last_answer, cur_answer)
                        if delta.strip():
                            self.logger.debug(f"💬 答案增量: {delta[:50]}...")
                            yield await self.format_sse_chunk(
                                self.create_openai_chunk(stream_id, model_name, {"content": delta})
                            )
                        last_answer = cur_answer

            except Exception as e:
                self.logger.error(f"流式响应处理错误: {e}")
                yield await self.format_sse_chunk({
                    "error": {
                        "message": f"流式处理错误: {str(e)}",
                        "type": "stream_error",
                        "code": "processing_error"
                    }
                })
                return

            # Terminate the stream with a stop chunk and the SSE DONE marker.
            self.logger.info(f"✅ K2Think流式响应完成,共处理 {events_seen} 个数据块")
            yield await self.format_sse_chunk(
                self.create_openai_chunk(stream_id, model_name, {}, "stop")
            )
            yield await self.format_sse_done()
393
+
394
async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
    """Translate an OpenAI-format request into a K2Think request descriptor.

    Returns a dict with the target url, ready-to-send headers, the K2Think
    payload, and the (unchanged) public model name.
    """
    self.logger.info(f"🔄 转换 OpenAI 请求到 K2Think 格式: {request.model}")

    auth = await self.get_k2_auth_data(request)

    return {
        "url": self.config.api_endpoint,
        "headers": auth["headers"],
        "payload": auth["payload"],
        "model": request.model,
    }
406
+
407
async def chat_completion(
    self,
    request: OpenAIRequest
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Entry point: dispatch a chat request to K2Think.

    Streaming requests return an async generator of SSE strings; other
    requests return a single OpenAI-format response dict.
    """
    self.log_request(request)

    try:
        transformed = await self.transform_request(request)

        # Use widely supported codecs only (drop br and zstd).
        outbound_headers = {**transformed["headers"], 'Accept-Encoding': 'gzip, deflate'}

        if request.stream:
            # Streamed: hand back the generator; it manages its own client.
            return self._handle_stream_request(transformed, request)

        # Non-streamed: plain POST, then convert the buffered response.
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                transformed["url"],
                headers=outbound_headers,
                json=transformed["payload"]
            )

            if not response.is_success:
                error_msg = f"K2Think API 错误: {response.status_code}"
                self.log_response(False, error_msg)
                return self.handle_error(Exception(error_msg))

            return await self.transform_response(response, request, transformed)

    except Exception as e:
        self.log_response(False, str(e))
        return self.handle_error(e, "请求处理")
445
+
446
async def transform_response(
    self,
    response: httpx.Response,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> Dict[str, Any]:
    """Convert a non-streaming K2Think response into OpenAI format.

    Streaming requests never reach here; they are served directly by
    ``_handle_stream_request``.
    """
    return await self._handle_non_stream_response(
        response,
        self.create_chat_id(),
        transformed["model"],
    )
459
+
460
def _is_end_marker(self, data: str) -> bool:
    """Return True when *data* is empty or one of K2Think's end-of-stream sentinels."""
    return not data or data in ("-1", "[DONE]", "DONE", "done")
463
+
464
def _parse_data_string(self, data_str: str) -> str:
    """Parse one SSE data payload into its text content.

    Returns the extracted content, an empty string when the payload signals
    completion, or the raw string when it cannot be interpreted.
    """
    try:
        payload = json.loads(data_str)
    except (json.JSONDecodeError, TypeError, ValueError):
        # Not JSON at all -- pass the raw string through unchanged.
        return data_str
    try:
        content, is_done = self.parse_api_response(payload)
    except Exception:
        # BUGFIX: the original used a bare `except:` around everything, which
        # also traps KeyboardInterrupt/SystemExit. Keep the fallback-to-raw
        # behaviour, but only for ordinary exceptions.
        return data_str
    return "" if is_done else content
472
+
473
async def _handle_non_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str
) -> Dict[str, Any]:
    """Aggregate a buffered K2Think SSE body into one OpenAI response."""
    # The upstream sends cumulative snapshots; keep only the latest one.
    # httpx's aiter_lines() transparently handles decompression/decoding.
    latest = ""

    try:
        async for line in response.aiter_lines():
            if not line.startswith("data:"):
                continue

            payload = line[5:].strip()
            if self._is_end_marker(payload):
                continue

            text = self._parse_data_string(payload)
            if text:
                latest = text

    except Exception as e:
        self.logger.error(f"非流式响应处理错误: {e}")
        raise

    reasoning, answer = self.extract_reasoning_and_answer(latest)

    # Normalise escaped newlines; fall back to the raw snapshot when no
    # answer section was found.
    reasoning = reasoning.replace("\\n", "\n") if reasoning else ""
    answer = answer.replace("\\n", "\n") if answer else latest

    return self.create_openai_response_with_reasoning(chat_id, model, answer, reasoning)
app/providers/longcat_provider.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ LongCat 提供商适配器
6
+ """
7
+
8
+ import json
9
+ import time
10
+ import httpx
11
+ import random
12
+ import asyncio
13
+ from typing import Dict, List, Any, Optional, AsyncGenerator, Union
14
+
15
+ from app.providers.base import BaseProvider, ProviderConfig
16
+ from app.models.schemas import OpenAIRequest, Message
17
+ from app.utils.logger import get_logger
18
+ from app.utils.user_agent import get_dynamic_headers
19
+ from app.core.config import settings
20
+
21
+ logger = get_logger()
22
+
23
+
24
+ class LongCatProvider(BaseProvider):
25
+ """LongCat 提供商"""
26
+
27
def __init__(self):
    """Set up provider config and LongCat endpoint URLs.

    The User-Agent is intentionally absent from the static headers; it is
    generated per request.
    """
    base = "https://longcat.chat"
    static_headers = {
        'accept': 'text/event-stream,application/json',
        'content-type': 'application/json',
        'origin': 'https://longcat.chat',
        'referer': 'https://longcat.chat/t',
    }
    super().__init__(ProviderConfig(
        name="longcat",
        api_endpoint=f"{base}/api/v1/chat-completion",
        timeout=30,
        headers=static_headers,
    ))

    self.base_url = base
    self.session_create_url = f"{base}/api/v1/session-create"
    self.session_delete_url = f"{base}/api/v1/session-delete"
44
+
45
def get_supported_models(self) -> "List[str]":
    """Return the model identifiers served by the LongCat provider."""
    supported = ("LongCat-Flash", "LongCat", "LongCat-Search")
    return list(supported)
48
+
49
def get_passport_token(self) -> "Optional[str]":
    """Pick a LongCat passport token.

    Preference order: the single token from the environment, then a random
    entry from the configured token file, else ``None``.
    """
    env_token = settings.LONGCAT_PASSPORT_TOKEN
    if env_token:
        return env_token

    pool = settings.longcat_token_list
    return random.choice(pool) if pool else None
61
+
62
def create_headers_with_auth(self, token: str, user_agent: str, referer: str = None) -> "Dict[str, str]":
    """Build request headers carrying the passport token cookie.

    When *referer* is falsy the site root is used instead.
    """
    effective_referer = referer if referer else f"{self.base_url}/"
    return {
        "User-Agent": user_agent,
        "Content-Type": "application/json",
        "x-requested-with": "XMLHttpRequest",
        "X-Client-Language": "zh",
        "Cookie": f"passport_token_key={token}",
        "Accept": "text/event-stream,application/json",
        "Origin": "https://longcat.chat",
        "Referer": effective_referer,
    }
78
+
79
async def create_session(self, token: str, user_agent: str) -> str:
    """Open a new LongCat conversation and return its id.

    Raises on a non-200 HTTP status or a non-zero LongCat result code.
    """
    headers = self.create_headers_with_auth(token, user_agent)
    body = {"model": "", "agentId": ""}

    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(self.session_create_url, headers=headers, json=body)

        if resp.status_code != 200:
            raise Exception(f"会话创建失败: {resp.status_code}")

        payload = resp.json()
        if payload.get("code") != 0:
            raise Exception(f"会话创建错误: {payload.get('message')}")

        return payload["data"]["conversationId"]
99
+
100
async def delete_session(self, conversation_id: str, token: str, user_agent: str) -> None:
    """Best-effort deletion of a LongCat conversation; never raises."""
    try:
        headers = self.create_headers_with_auth(
            token,
            user_agent,
            f"{self.base_url}/c/{conversation_id}"
        )

        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(
                f"{self.session_delete_url}?conversationId={conversation_id}",
                headers=headers
            )

            if resp.status_code == 200:
                self.logger.debug(f"成功删除会话 {conversation_id}")
            else:
                self.logger.warning(f"删除会话失败: {resp.status_code}")
    except Exception as e:
        self.logger.error(f"删除会话出错: {e}")
119
+
120
def schedule_session_deletion(self, conversation_id: str, token: str, user_agent: str):
    """Fire-and-forget session cleanup without awaiting the result.

    NOTE(review): asyncio.create_task needs a running event loop -- callers
    are expected to invoke this from async context.
    """
    cleanup = self.delete_session(conversation_id, token, user_agent)
    asyncio.create_task(cleanup)
123
+
124
def format_messages_for_longcat(self, messages: "List[Message]") -> str:
    """Flatten a message list into LongCat's 'role:content;role:content' string."""
    rendered = []
    for message in messages:
        body = message.content
        if isinstance(body, list):
            # Multimodal payload: keep only the textual fragments.
            body = "\n".join(
                part.text for part in body if hasattr(part, "text") and part.text
            )
        rendered.append(f"{message.role}:{body}")
    return ";".join(rendered)
138
+
139
async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
    """Build the LongCat request descriptor for an OpenAI-format request.

    Creates a fresh conversation first; the caller is responsible for
    scheduling its deletion once the exchange finishes.
    """
    token = self.get_passport_token()
    if not token:
        raise Exception("未配置 LongCat passport token,请设置 LONGCAT_PASSPORT_TOKEN 环境变量或 LONGCAT_TOKENS_FILE")

    # Per-request dynamic User-Agent with a conservative fallback.
    ua = get_dynamic_headers().get(
        "User-Agent",
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    )

    conversation_id = await self.create_session(token, ua)

    payload = {
        "conversationId": conversation_id,
        "content": self.format_messages_for_longcat(request.messages),
        "reasonEnabled": 0,
        # "Search" model variants enable LongCat's web search.
        "searchEnabled": 1 if "search" in request.model.lower() else 0,
        "parentMessageId": 0
    }

    headers = self.create_headers_with_auth(token, ua, f"{self.base_url}/c/{conversation_id}")

    return {
        "url": self.config.api_endpoint,
        "headers": headers,
        "payload": payload,
        "model": request.model,
        "conversation_id": conversation_id,
        "passport_token": token,
        "user_agent": ua
    }
181
+
182
async def chat_completion(
    self,
    request: OpenAIRequest,
    **kwargs
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Entry point: dispatch a chat request to LongCat.

    Returns an async generator of SSE strings for streaming requests,
    otherwise a single OpenAI-format response dict.
    """
    self.log_request(request)

    try:
        transformed = await self.transform_request(request)

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                transformed["url"],
                headers=transformed["headers"],
                json=transformed["payload"]
            )

            if not response.is_success:
                error_msg = f"LongCat API 错误: {response.status_code}"
                try:
                    # BUGFIX: httpx.Response has no atext() method, so the
                    # original `await response.atext()` always raised and the
                    # detail was never logged. The body is already read by
                    # client.post(), so .text is available synchronously.
                    self.logger.error(f"❌ API 错误详情: {response.text}")
                except Exception:
                    pass
                self.log_response(False, error_msg)
                return self.handle_error(Exception(error_msg))

            return await self.transform_response(response, request, transformed)

    except Exception as e:
        self.logger.error(f"❌ LongCat 请求处理异常: {e}")
        self.log_response(False, str(e))
        return self.handle_error(e, "请求处理")
219
+
220
async def transform_response(
    self,
    response: httpx.Response,
    request: OpenAIRequest,
    transformed: Dict[str, Any]
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Route a LongCat response to the stream or non-stream converter."""
    chat_id = self.create_chat_id()
    model = transformed["model"]
    session = (
        transformed["conversation_id"],
        transformed["passport_token"],
        transformed["user_agent"],
    )

    if request.stream:
        return self._handle_stream_response(response, chat_id, model, *session)
    return await self._handle_non_stream_response(response, chat_id, model, *session)
241
+
242
async def _handle_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str,
    conversation_id: str,
    passport_token: str,
    user_agent: str
) -> AsyncGenerator[str, None]:
    """Re-emit a LongCat SSE stream as OpenAI-format SSE chunks.

    Detects LongCat's bare-JSON error payloads (which are not SSE framed)
    and guarantees the backing conversation is scheduled for deletion on
    every exit path.
    """
    session_deleted = False
    # BUGFIX: initialise before the try block. The original assigned this
    # only after the first yield inside `try`, so an early failure made the
    # outer except handler raise NameError on `stream_finished`.
    stream_finished = False

    try:
        # Announce the assistant role first.
        yield await self.format_sse_chunk(
            self.create_openai_chunk(chat_id, model, {"role": "assistant"})
        )

        async for raw_line in response.aiter_lines():
            line = raw_line.strip()

            if not line.startswith('data:'):
                # A bare JSON object here is LongCat reporting an error.
                try:
                    error_data = json.loads(line)
                    if isinstance(error_data, dict) and 'code' in error_data and 'message' in error_data:
                        self.logger.error(f"❌ LongCat API 返回错误: {error_data}")
                        error_message = error_data.get('message', '未知错误')
                        error_code = error_data.get('code', 'unknown')

                        # Reuse the unified error formatter.
                        error_exception = Exception(f"LongCat API 错误 ({error_code}): {error_message}")
                        error_response = self.handle_error(error_exception, "API响应")

                        yield await self.format_sse_chunk(error_response)
                        yield await self.format_sse_done()

                        if not session_deleted:
                            self.schedule_session_deletion(conversation_id, passport_token, user_agent)
                            session_deleted = True
                        return
                except json.JSONDecodeError:
                    pass
                continue

            data_str = line[5:].strip()
            if data_str == '[DONE]':
                # Emit the terminating chunks if we have not already.
                if not stream_finished:
                    yield await self.format_sse_chunk(
                        self.create_openai_chunk(chat_id, model, {}, "stop")
                    )
                    yield await self.format_sse_done()

                if not session_deleted:
                    self.schedule_session_deletion(conversation_id, passport_token, user_agent)
                    session_deleted = True
                break

            try:
                longcat_data = json.loads(data_str)

                choices = longcat_data.get("choices", [])
                if not choices:
                    continue

                delta = choices[0].get("delta", {})
                content = delta.get("content")
                finish_reason = choices[0].get("finishReason")

                # Forward non-empty content deltas.
                if content is not None and content != "":
                    yield await self.format_sse_chunk(
                        self.create_openai_chunk(chat_id, model, {"content": content})
                    )

                # LongCat flags the final chunk with lastOne=true.
                if longcat_data.get("lastOne") and not stream_finished:
                    yield await self.format_sse_chunk(
                        self.create_openai_chunk(chat_id, model, {}, "stop")
                    )
                    yield await self.format_sse_done()
                    stream_finished = True

                    if not session_deleted:
                        self.schedule_session_deletion(conversation_id, passport_token, user_agent)
                        session_deleted = True
                    break

                # Fallback end detection when lastOne is absent.
                elif finish_reason == "stop" and longcat_data.get("contentStatus") == "FINISHED" and not stream_finished:
                    yield await self.format_sse_chunk(
                        self.create_openai_chunk(chat_id, model, {}, "stop")
                    )
                    yield await self.format_sse_done()
                    stream_finished = True

                    if not session_deleted:
                        self.schedule_session_deletion(conversation_id, passport_token, user_agent)
                        session_deleted = True
                    break

            except json.JSONDecodeError as e:
                self.logger.error(f"❌ 解析LongCat流数据错误: {e}")
                continue
            except Exception as e:
                self.logger.error(f"❌ 处理LongCat流数据错误: {e}")
                continue

    except Exception as e:
        self.logger.error(f"❌ LongCat流处理错误: {e}")
        # Close the stream cleanly if we had not finished yet.
        if not stream_finished:
            yield await self.format_sse_chunk(
                self.create_openai_chunk(chat_id, model, {}, "stop")
            )
            yield await self.format_sse_done()
    finally:
        # Guarantee session cleanup whatever path we exited on.
        if not session_deleted:
            self.schedule_session_deletion(conversation_id, passport_token, user_agent)
380
+
381
async def _handle_non_stream_response(
    self,
    response: httpx.Response,
    chat_id: str,
    model: str,
    conversation_id: str,
    passport_token: str,
    user_agent: str
) -> Dict[str, Any]:
    """Aggregate a LongCat SSE body into a single OpenAI-format response."""
    collected = ""
    usage_info = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0
    }

    try:
        async for raw_line in response.aiter_lines():
            line = raw_line.strip()

            if not line.startswith('data:'):
                # A bare JSON object here is LongCat reporting an error.
                try:
                    error_data = json.loads(line)
                    if isinstance(error_data, dict) and 'code' in error_data and 'message' in error_data:
                        self.logger.error(f"❌ LongCat API 返回错误: {error_data}")
                        error_message = error_data.get('message', '未知错误')
                        error_code = error_data.get('code', 'unknown')

                        error_exception = Exception(f"LongCat API 错误 ({error_code}): {error_message}")

                        self.schedule_session_deletion(conversation_id, passport_token, user_agent)
                        return self.handle_error(error_exception, "API响应")
                except json.JSONDecodeError:
                    pass
                continue

            data_str = line[5:].strip()
            if data_str == '[DONE]':
                break

            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                continue

            # Append non-empty content deltas.
            choices = chunk.get("choices", [])
            if choices:
                piece = choices[0].get("delta", {}).get("content")
                if piece is not None and piece != "":
                    collected += piece

            # Token accounting usually arrives on the final chunk.
            token_info = chunk.get("tokenInfo")
            if token_info:
                usage_info = {
                    "prompt_tokens": token_info.get("promptTokens", 0),
                    "completion_tokens": token_info.get("completionTokens", 0),
                    "total_tokens": token_info.get("totalTokens", 0)
                }

            # The lastOne flag lets us stop early.
            if chunk.get("lastOne"):
                break

    except Exception as e:
        self.logger.error(f"❌ 处理LongCat非流式响应错误: {e}")
        collected = "处理响应时发生错误"
    finally:
        # Always schedule conversation cleanup.
        self.schedule_session_deletion(conversation_id, passport_token, user_agent)

    return self.create_openai_response(chat_id, model, collected.strip(), usage_info)
app/providers/provider_factory.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 提供商工厂和路由机制
6
+ 负责根据模型名称自动选择合适的提供商
7
+ """
8
+
9
+ import time
10
+ from typing import Dict, List, Optional, Union, AsyncGenerator, Any
11
+ from app.providers.base import BaseProvider, provider_registry
12
+ from app.providers.zai_provider import ZAIProvider
13
+ from app.providers.k2think_provider import K2ThinkProvider
14
+ from app.providers.longcat_provider import LongCatProvider
15
+ from app.models.schemas import OpenAIRequest
16
+ from app.core.config import settings
17
+ from app.utils.logger import get_logger
18
+
19
+ logger = get_logger()
20
+
21
+
22
class ProviderFactory:
    """Creates and registers all providers, and resolves models to providers."""

    def __init__(self):
        self._initialized = False
        self._default_provider = "zai"

    def initialize(self):
        """Register every known provider with the global registry; idempotent."""
        if self._initialized:
            return

        try:
            for provider_cls in (ZAIProvider, K2ThinkProvider, LongCatProvider):
                instance = provider_cls()
                provider_registry.register(instance, instance.get_supported_models())
            self._initialized = True
        except Exception as e:
            logger.error(f"❌ 提供商工厂初始化失败: {e}")
            raise

    def get_provider_for_model(self, model: str) -> Optional[BaseProvider]:
        """Resolve the provider responsible for *model*.

        Lookup order: explicit settings mapping, then the registry, then the
        default provider; returns ``None`` when nothing matches at all.
        """
        if not self._initialized:
            self.initialize()

        # 1) Explicit model -> provider mapping from configuration.
        mapped_name = settings.provider_model_mapping.get(model)
        if mapped_name:
            mapped = provider_registry.get_provider_by_name(mapped_name)
            if mapped:
                logger.debug(f"🎯 模型 {model} 映射到提供商 {mapped_name}")
                return mapped

        # 2) Direct registry lookup by model name.
        registered = provider_registry.get_provider(model)
        if registered:
            logger.debug(f"🎯 模型 {model} 找到提供商 {registered.name}")
            return registered

        # 3) Fall back to the default provider.
        fallback = provider_registry.get_provider_by_name(self._default_provider)
        if fallback:
            logger.warning(f"⚠️ 模型 {model} 未找到专用提供商,使用默认提供商 {self._default_provider}")
            return fallback

        logger.error(f"❌ 无法为模型 {model} 找到任何提供商")
        return None

    def list_supported_models(self) -> List[str]:
        """All model ids across every registered provider."""
        if not self._initialized:
            self.initialize()
        return provider_registry.list_models()

    def list_providers(self) -> List[str]:
        """Names of all registered providers."""
        if not self._initialized:
            self.initialize()
        return provider_registry.list_providers()

    def get_models_for_provider(self, provider_name: str) -> List[str]:
        """Model ids supported by one named provider ([] when unknown)."""
        if not self._initialized:
            self.initialize()

        provider = provider_registry.get_provider_by_name(provider_name)
        return provider.get_supported_models() if provider else []
113
+
114
+
115
class ProviderRouter:
    """Dispatches OpenAI-format requests to the matching provider."""

    def __init__(self):
        self.factory = ProviderFactory()

    async def route_request(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Find a provider for the request's model and delegate to it.

        Unknown models yield an OpenAI-style error dict instead of raising.
        """
        logger.info(f"🚦 路由请求: 模型={request.model}, 流式={request.stream}")

        provider = self.factory.get_provider_for_model(request.model)
        if provider is None:
            error_msg = f"不支持的模型: {request.model}"
            logger.error(f"❌ {error_msg}")
            return {
                "error": {
                    "message": error_msg,
                    "type": "invalid_request_error",
                    "code": "model_not_found"
                }
            }

        logger.info(f"✅ 使用提供商: {provider.name}")

        try:
            result = await provider.chat_completion(request, **kwargs)
            logger.info(f"🎉 请求处理��成: {provider.name}")
            return result
        except Exception as e:
            error_msg = f"提供商 {provider.name} 处理请求失败: {str(e)}"
            logger.error(f"❌ {error_msg}")
            return provider.handle_error(e, "路由处理")

    def get_models_list(self) -> Dict[str, Any]:
        """OpenAI-style /v1/models listing across every provider."""
        now = int(time.time())
        catalog = [
            {
                "id": model_id,
                "object": "model",
                "created": now,
                "owned_by": provider_name,
            }
            for provider_name in self.factory.list_providers()
            for model_id in self.factory.get_models_for_provider(provider_name)
        ]
        return {
            "object": "list",
            "data": catalog
        }
175
+
176
+
177
+ # 全局路由器实例
178
+ _router: Optional[ProviderRouter] = None
179
+
180
+
181
def get_provider_router() -> ProviderRouter:
    """Return the process-wide router, creating it lazily on first use."""
    global _router
    if _router is None:
        _router = ProviderRouter()
        # Eagerly register providers so the first request pays no setup cost.
        _router.factory.initialize()
    return _router
189
+
190
+
191
def initialize_providers():
    """Bootstrap the provider system and return the shared router."""
    logger.info("🚀 初始化提供商系统...")
    router = get_provider_router()
    logger.info("✅ 提供商系统初始化完成")
    return router
app/providers/zai_provider.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Z.AI 提供商适配器
6
+ """
7
+
8
+ import json
9
+ import time
10
+ import uuid
11
+ import httpx
12
+ import asyncio
13
+ from datetime import datetime
14
+ from typing import Dict, List, Any, Optional, AsyncGenerator, Union
15
+
16
+ from app.providers.base import BaseProvider, ProviderConfig
17
+ from app.models.schemas import OpenAIRequest, Message
18
+ from app.core.config import settings
19
+ from app.utils.logger import get_logger
20
+ from app.utils.token_pool import get_token_pool
21
+ from app.core.zai_transformer import generate_uuid, get_zai_dynamic_headers
22
+ from app.utils.sse_tool_handler import SSEToolHandler
23
+
24
+ logger = get_logger()
25
+
26
+
27
+ class ZAIProvider(BaseProvider):
28
+ """Z.AI 提供商"""
29
+
30
def __init__(self):
    """Configure the Z.AI provider and its public-name -> upstream-id map."""
    super().__init__(ProviderConfig(
        name="zai",
        api_endpoint=settings.API_ENDPOINT,
        timeout=30,
        headers=get_zai_dynamic_headers()
    ))

    # Z.AI endpoints.
    self.base_url = "https://chat.z.ai"
    self.auth_url = f"{self.base_url}/api/v1/auths/"

    # Public model name -> upstream API model id.
    glm45 = "0727-360B-API"
    glm46 = "GLM-4-6-API-V1"
    self.model_mapping = {
        settings.PRIMARY_MODEL: glm45,         # GLM-4.5
        settings.THINKING_MODEL: glm45,        # GLM-4.5-Thinking
        settings.SEARCH_MODEL: glm45,          # GLM-4.5-Search
        settings.AIR_MODEL: "0727-106B-API",   # GLM-4.5-Air
        settings.GLM46_MODEL: glm46,           # GLM-4.6
        settings.GLM46_THINKING_MODEL: glm46,  # GLM-4.6-Thinking
        settings.GLM46_SEARCH_MODEL: glm46,    # GLM-4.6-Search
    }
53
+
54
+ def get_supported_models(self) -> List[str]:
55
+ """获取支持的模型列表"""
56
+ return [
57
+ settings.PRIMARY_MODEL,
58
+ settings.THINKING_MODEL,
59
+ settings.SEARCH_MODEL,
60
+ settings.AIR_MODEL,
61
+ settings.GLM46_MODEL,
62
+ settings.GLM46_THINKING_MODEL,
63
+ settings.GLM46_SEARCH_MODEL,
64
+ ]
65
+
66
    async def get_token(self) -> str:
        """Resolve an authentication token for upstream requests.

        In anonymous mode only a guest token fetched from Z.AI is used; on
        failure an empty string is returned.  Otherwise the shared token
        pool is consulted first, then the configured AUTH_TOKEN (ignoring
        the placeholder value shipped in .env.example).
        """
        # Anonymous mode: only try to obtain a temporary visitor token.
        if settings.ANONYMOUS_MODE:
            try:
                headers = get_zai_dynamic_headers()
                async with httpx.AsyncClient() as client:
                    response = await client.get(self.auth_url, headers=headers, timeout=10.0)
                    if response.status_code == 200:
                        data = response.json()
                        token = data.get("token", "")
                        if token:
                            self.logger.debug(f"获取访客令牌成功: {token[:20]}...")
                            return token
            except Exception as e:
                self.logger.warning(f"异步获取访客令牌失败: {e}")

            # In anonymous mode a failed guest-token fetch is terminal.
            self.logger.error("❌ 匿名模式下获取访客令牌失败")
            return ""

        # Non-anonymous mode: prefer a token from the shared pool.
        token_pool = get_token_pool()
        if token_pool:
            token = token_pool.get_next_token()
            if token:
                self.logger.debug(f"从token池获取令牌: {token[:20]}...")
                return token

        # Pool empty or exhausted: fall back to the configured AUTH_TOKEN
        # (skip the sample placeholder from .env.example).
        if settings.AUTH_TOKEN and settings.AUTH_TOKEN != "sk-your-api-key":
            self.logger.debug("使用配置的AUTH_TOKEN")
            return settings.AUTH_TOKEN

        self.logger.error("❌ 无法获取有效的认证令牌")
        return ""
102
+
103
+ def mark_token_failure(self, token: str, error: Exception = None):
104
+ """标记token使用失败"""
105
+ token_pool = get_token_pool()
106
+ if token_pool:
107
+ token_pool.mark_token_failure(token, error)
108
+
109
    async def transform_request(self, request: OpenAIRequest) -> Dict[str, Any]:
        """Translate an OpenAI-style request into the Z.AI upstream payload.

        Returns a dict bundling the upstream URL, headers (with bearer token
        when available), the request body, the token used, a fresh chat id
        and the requested model name for the senders to consume.
        """
        self.logger.info(f"🔄 转换 OpenAI 请求到 Z.AI 格式: {request.model}")

        # Acquire an auth token (guest / pool / configured — see get_token()).
        token = await self.get_token()

        # Normalise message content: plain strings pass through unchanged.
        # NOTE(review): for list content only parts exposing both `type` and
        # `text` attributes are kept — image/url parts are silently dropped;
        # confirm this is intended.
        messages = []
        for msg in request.messages:
            if isinstance(msg.content, str):
                messages.append({
                    "role": msg.role,
                    "content": msg.content
                })
            elif isinstance(msg.content, list):
                # Multi-modal content: collect text-like parts only.
                content_parts = []
                for part in msg.content:
                    if hasattr(part, 'type') and hasattr(part, 'text'):
                        content_parts.append({
                            "type": part.type,
                            "text": part.text
                        })
                messages.append({
                    "role": msg.role,
                    "content": content_parts
                })

        # Infer feature flags from the requested model-name suffix.
        requested_model = request.model
        is_thinking = "-thinking" in requested_model.casefold()
        is_search = "-search" in requested_model.casefold()
        is_air = "-air" in requested_model.casefold()  # NOTE(review): currently unused below

        # Map the public model name to the upstream model id (GLM-4.5 default).
        upstream_model_id = self.model_mapping.get(requested_model, "0727-360B-API")

        # MCP servers: only the 4.5 search variant enables deep web search.
        mcp_servers = []
        if is_search and "-4.5" in requested_model:
            mcp_servers.append("deep-web-search")
            self.logger.info("🔍 检测到搜索模型,添加 deep-web-search MCP 服务器")

        # Build the upstream request body.
        chat_id = generate_uuid()

        body = {
            "stream": True,  # upstream is always consumed as an SSE stream
            "model": upstream_model_id,
            "messages": messages,
            "params": {},
            "features": {
                "image_generation": False,
                "web_search": is_search,
                "auto_web_search": is_search,
                "preview_mode": False,
                "flags": [],
                # Built-in Z.AI feature toggles, all kept hidden from the UI.
                "features": [
                    {
                        "type": "mcp",
                        "server": "vibe-coding",
                        "status": "hidden"
                    },
                    {
                        "type": "mcp",
                        "server": "ppt-maker",
                        "status": "hidden"
                    },
                    {
                        "type": "mcp",
                        "server": "image-search",
                        "status": "hidden"
                    },
                    {
                        "type": "mcp",
                        "server": "deep-research",
                        "status": "hidden"
                    },
                    {
                        "type": "tool_selector",
                        "server": "tool_selector",
                        "status": "hidden"
                    },
                    {
                        "type": "mcp",
                        "server": "advanced-search",
                        "status": "hidden"
                    }
                ],
                "enable_thinking": is_thinking,
            },
            # Disable server-side housekeeping tasks for API usage.
            "background_tasks": {
                "title_generation": False,
                "tags_generation": False,
            },
            "mcp_servers": mcp_servers,
            # Template variables substituted by the upstream prompt.
            "variables": {
                "{{USER_NAME}}": "Guest",
                "{{USER_LOCATION}}": "Unknown",
                "{{CURRENT_DATETIME}}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "{{CURRENT_DATE}}": datetime.now().strftime("%Y-%m-%d"),
                "{{CURRENT_TIME}}": datetime.now().strftime("%H:%M:%S"),
                "{{CURRENT_WEEKDAY}}": datetime.now().strftime("%A"),
                "{{CURRENT_TIMEZONE}}": "Asia/Shanghai",
                "{{USER_LANGUAGE}}": "zh-CN",
            },
            "model_item": {
                "id": upstream_model_id,
                "name": requested_model,
                "owned_by": "z.ai"
            },
            "chat_id": chat_id,
            "id": generate_uuid(),
        }

        # Tool (function-call) support: disabled for thinking variants.
        if settings.TOOL_SUPPORT and not is_thinking and request.tools:
            body["tools"] = request.tools
            self.logger.info(f"启用工具支持: {len(request.tools)} 个工具")
        else:
            body["tools"] = None

        # Optional sampling parameters.
        if request.temperature is not None:
            body["params"]["temperature"] = request.temperature
        if request.max_tokens is not None:
            body["params"]["max_tokens"] = request.max_tokens

        # Build request headers, attaching the bearer token when present.
        headers = get_zai_dynamic_headers(chat_id)
        if token:
            headers["Authorization"] = f"Bearer {token}"

        # Remember the token so error paths can mark it failed.
        self._current_token = token

        return {
            "url": self.config.api_endpoint,
            "headers": headers,
            "body": body,
            "token": token,
            "chat_id": chat_id,
            "model": requested_model
        }
254
+
255
    async def chat_completion(
        self,
        request: OpenAIRequest,
        **kwargs
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Entry point for chat completions.

        Streaming requests return an async generator of SSE lines (with
        retry on upstream 400s); non-streaming requests post once and
        aggregate the upstream SSE into a single OpenAI response dict.
        """
        self.log_request(request)

        try:
            # Build the upstream payload (also resolves the auth token).
            transformed = await self.transform_request(request)

            if request.stream:
                # Streaming: hand back the retrying SSE generator.
                return self._create_stream_response_with_retry(request, transformed)
            else:
                # Non-streaming: single POST; the upstream still answers as
                # SSE (body forces stream=True), aggregated downstream.
                async with httpx.AsyncClient(timeout=30.0) as client:
                    response = await client.post(
                        transformed["url"],
                        headers=transformed["headers"],
                        json=transformed["body"]
                    )

                    if not response.is_success:
                        error_msg = f"Z.AI API 错误: {response.status_code}"
                        self.log_response(False, error_msg)
                        return self.handle_error(Exception(error_msg))

                    return await self.transform_response(response, request, transformed)

        except Exception as e:
            self.log_response(False, str(e))
            return self.handle_error(e, "请求处理")
290
+
291
    async def _create_stream_response_with_retry(
        self,
        request: OpenAIRequest,
        transformed: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Streaming generator with retry on upstream 400s and exceptions.

        Retries up to settings.MAX_RETRIES times, re-fetching a token before
        each retry (marking the failed one in the pool when not anonymous).
        Non-400 upstream errors are forwarded immediately without retrying.
        Yields OpenAI-format SSE lines and always terminates the stream.
        """
        retry_count = 0
        last_error = None
        current_token = transformed.get("token", "")

        while retry_count <= settings.MAX_RETRIES:
            try:
                # On a retry: back off, blame the previous token, get a new one.
                if retry_count > 0:
                    delay = settings.RETRY_DELAY
                    self.logger.warning(f"重试请求 ({retry_count}/{settings.MAX_RETRIES}) - 等待 {delay:.1f}s")
                    await asyncio.sleep(delay)

                    # Mark the previous token as failed (skip in anonymous mode).
                    if current_token and not settings.ANONYMOUS_MODE:
                        self.mark_token_failure(current_token, Exception(f"Retry {retry_count}: {last_error}"))

                    # Fetch a fresh token for the retry attempt.
                    self.logger.info("🔑 重新获取令牌用于重试...")
                    new_token = await self.get_token()
                    if not new_token:
                        self.logger.error("❌ 重试时无法获取有效的认证令牌")
                        raise Exception("重试时无法获取有效的认证令牌")
                    transformed["headers"]["Authorization"] = f"Bearer {new_token}"
                    current_token = new_token

                async with httpx.AsyncClient(timeout=60.0) as client:
                    # Open the streaming request to the upstream.
                    self.logger.info(f"🎯 发送请求到 Z.AI: {transformed['url']}")
                    async with client.stream(
                        "POST",
                        transformed["url"],
                        json=transformed["body"],
                        headers=transformed["headers"],
                    ) as response:
                        if response.status_code == 400:
                            # 400: retryable (often a bad/expired token).
                            error_text = await response.aread()
                            error_msg = error_text.decode('utf-8', errors='ignore')
                            self.logger.warning(f"❌ 上游返回 400 错误 (尝试 {retry_count + 1}/{settings.MAX_RETRIES + 1})")

                            retry_count += 1
                            last_error = f"400 Bad Request: {error_msg}"

                            if retry_count <= settings.MAX_RETRIES:
                                continue
                            else:
                                # Retries exhausted: emit an error event and end.
                                self.logger.error(f"❌ 达到最大重试次数 ({settings.MAX_RETRIES}),请求失败")
                                error_response = {
                                    "error": {
                                        "message": f"Request failed after {settings.MAX_RETRIES} retries: {last_error}",
                                        "type": "upstream_error",
                                        "code": 400
                                    }
                                }
                                yield f"data: {json.dumps(error_response)}\n\n"
                                yield "data: [DONE]\n\n"
                                return

                        elif response.status_code != 200:
                            # Any other error status: report immediately, no retry.
                            self.logger.error(f"❌ 上游返回错误: {response.status_code}")
                            error_text = await response.aread()
                            error_msg = error_text.decode('utf-8', errors='ignore')
                            self.logger.error(f"❌ 错误详情: {error_msg}")

                            error_response = {
                                "error": {
                                    "message": f"Upstream error: {response.status_code}",
                                    "type": "upstream_error",
                                    "code": response.status_code
                                }
                            }
                            yield f"data: {json.dumps(error_response)}\n\n"
                            yield "data: [DONE]\n\n"
                            return

                        # 200 OK: stream the body through the SSE translator.
                        if retry_count > 0:
                            self.logger.info(f"✨ 第 {retry_count} 次重试成功")

                        # Credit the token in the pool (skip in anonymous mode).
                        if current_token and not settings.ANONYMOUS_MODE:
                            token_pool = get_token_pool()
                            if token_pool:
                                token_pool.mark_token_success(current_token)

                        chat_id = transformed["chat_id"]
                        model = transformed["model"]
                        async for chunk in self._handle_stream_response(response, chat_id, model, request, transformed):
                            yield chunk
                        return

            except Exception as e:
                self.logger.error(f"❌ 流处理错误: {e}")
                import traceback
                self.logger.error(traceback.format_exc())

                # Blame the current token (skip in anonymous mode).
                if current_token and not settings.ANONYMOUS_MODE:
                    self.mark_token_failure(current_token, e)

                retry_count += 1
                last_error = str(e)

                if retry_count > settings.MAX_RETRIES:
                    # Retries exhausted: emit an error event and end the stream.
                    self.logger.error(f"❌ 达到最大重试次数 ({settings.MAX_RETRIES}),流处理失败")
                    error_response = {
                        "error": {
                            "message": f"Stream processing failed after {settings.MAX_RETRIES} retries: {last_error}",
                            "type": "stream_error"
                        }
                    }
                    yield f"data: {json.dumps(error_response)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
418
+
419
+ async def transform_response(
420
+ self,
421
+ response: httpx.Response,
422
+ request: OpenAIRequest,
423
+ transformed: Dict[str, Any]
424
+ ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
425
+ """转换Z.AI响应为OpenAI格式"""
426
+ chat_id = transformed["chat_id"]
427
+ model = transformed["model"]
428
+
429
+ if request.stream:
430
+ return self._handle_stream_response(response, chat_id, model, request, transformed)
431
+ else:
432
+ return await self._handle_non_stream_response(response, chat_id, model)
433
+
434
    async def _handle_stream_response(
        self,
        response: httpx.Response,
        chat_id: str,
        model: str,
        request: OpenAIRequest,
        transformed: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Translate the Z.AI SSE stream into OpenAI-format SSE lines.

        When the request carried tools, all phases are delegated to an
        SSEToolHandler; otherwise thinking/answer phases are converted to
        OpenAI delta chunks here.  Always terminates with [DONE] in the
        non-tool path (the tool handler emits its own terminator).
        """
        self.logger.info(f"✅ Z.AI 响应成功,开始处理 SSE 流")

        # Initialise the tool-call handler when the request carried tools.
        has_tools = transformed["body"].get("tools") is not None
        tool_handler = None

        if has_tools:
            tool_handler = SSEToolHandler(model, stream=True)
            self.logger.info(f"🔧 初始化工具处理器: {len(transformed['body'].get('tools', []))} 个工具")

        # Stream state: whether a thinking phase (and thus the role chunk)
        # has been emitted, plus the signature sent when thinking ends.
        has_thinking = False
        thinking_signature = None

        # Line accumulation buffer for the SSE feed.
        buffer = ""
        line_count = 0
        self.logger.debug("📡 开始接收 SSE 流数据...")

        try:
            async for line in response.aiter_lines():
                line_count += 1
                if not line:
                    continue

                # Re-buffer each line so complete "data:" rows are processed.
                buffer += line + "\n"

                while "\n" in buffer:
                    current_line, buffer = buffer.split("\n", 1)
                    if not current_line.strip():
                        continue

                    if current_line.startswith("data:"):
                        chunk_str = current_line[5:].strip()
                        if not chunk_str or chunk_str == "[DONE]":
                            if chunk_str == "[DONE]":
                                yield "data: [DONE]\n\n"
                            continue

                        self.logger.debug(f"📦 解析数据块: {chunk_str[:1000]}..." if len(chunk_str) > 1000 else f"📦 解析数据块: {chunk_str}")

                        try:
                            chunk = json.loads(chunk_str)

                            if chunk.get("type") == "chat:completion":
                                data = chunk.get("data", {})
                                phase = data.get("phase")

                                # Log the phase only when it changes.
                                if phase and phase != getattr(self, '_last_phase', None):
                                    self.logger.info(f"📈 SSE 阶段: {phase}")
                                    self._last_phase = phase

                                # Tool mode: the handler owns every phase.
                                if tool_handler:
                                    sse_chunk = {
                                        "phase": phase,
                                        "edit_content": data.get("edit_content", ""),
                                        "delta_content": data.get("delta_content", ""),
                                        "edit_index": data.get("edit_index"),
                                        "usage": data.get("usage", {})
                                    }

                                    for output in tool_handler.process_sse_chunk(sse_chunk):
                                        yield output

                                # Non-tool mode — thinking phase.
                                elif phase == "thinking":
                                    if not has_thinking:
                                        has_thinking = True
                                        # First chunk of the reply: emit the role.
                                        role_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {"role": "assistant"}
                                        )
                                        yield await self.format_sse_chunk(role_chunk)

                                    delta_content = data.get("delta_content", "")
                                    if delta_content:
                                        # Strip the <details><summary> wrapper header.
                                        if delta_content.startswith("<details"):
                                            content = (
                                                delta_content.split("</summary>\n>")[-1].strip()
                                                if "</summary>\n>" in delta_content
                                                else delta_content
                                            )
                                        else:
                                            content = delta_content

                                        thinking_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {
                                                "role": "assistant",
                                                "thinking": {"content": content}
                                            }
                                        )
                                        yield await self.format_sse_chunk(thinking_chunk)

                                # Non-tool mode — answer phase.
                                elif phase == "answer":
                                    edit_content = data.get("edit_content", "")
                                    delta_content = data.get("delta_content", "")

                                    # Transition chunk: thinking just closed,
                                    # answer text may follow in edit_content.
                                    if edit_content and "</details>\n" in edit_content:
                                        if has_thinking:
                                            # Close thinking with a signature chunk.
                                            thinking_signature = str(int(time.time() * 1000))
                                            sig_chunk = self.create_openai_chunk(
                                                chat_id,
                                                model,
                                                {
                                                    "role": "assistant",
                                                    "thinking": {
                                                        "content": "",
                                                        "signature": thinking_signature,
                                                    }
                                                }
                                            )
                                            yield await self.format_sse_chunk(sig_chunk)

                                        # Emit the answer text after the marker.
                                        content_after = edit_content.split("</details>\n")[-1]
                                        if content_after:
                                            content_chunk = self.create_openai_chunk(
                                                chat_id,
                                                model,
                                                {
                                                    "role": "assistant",
                                                    "content": content_after
                                                }
                                            )
                                            yield await self.format_sse_chunk(content_chunk)

                                    # Plain incremental answer content.
                                    elif delta_content:
                                        # NOTE(review): when no thinking phase ever
                                        # occurred, this emits a role chunk on every
                                        # delta (has_thinking never flips) — confirm
                                        # clients tolerate repeated role chunks.
                                        if not has_thinking:
                                            role_chunk = self.create_openai_chunk(
                                                chat_id,
                                                model,
                                                {"role": "assistant"}
                                            )
                                            yield await self.format_sse_chunk(role_chunk)

                                        content_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {
                                                "role": "assistant",
                                                "content": delta_content
                                            }
                                        )
                                        output_data = await self.format_sse_chunk(content_chunk)
                                        self.logger.debug(f"➡️ 输出内容块到客户端: {output_data}")
                                        yield output_data

                                # Completion: usage stats signal the end.
                                if data.get("usage"):
                                    self.logger.info(f"📦 完成响应 - 使用统计: {json.dumps(data['usage'])}")

                                    # Only the non-tool path sends the plain
                                    # finish chunk; the tool handler finishes itself.
                                    if not tool_handler:
                                        finish_chunk = self.create_openai_chunk(
                                            chat_id,
                                            model,
                                            {"role": "assistant", "content": ""},
                                            "stop"
                                        )
                                        finish_chunk["usage"] = data["usage"]

                                        finish_output = await self.format_sse_chunk(finish_chunk)
                                        self.logger.debug(f"➡️ 发送完成信号: {finish_output[:1000]}...")
                                        yield finish_output
                                        self.logger.debug("➡️ 发送 [DONE]")
                                        yield "data: [DONE]\n\n"

                        except json.JSONDecodeError as e:
                            self.logger.debug(f"❌ JSON解析错误: {e}, 内容: {chunk_str[:1000]}")
                        except Exception as e:
                            self.logger.error(f"❌ 处理chunk错误: {e}")

            # Final terminator for the non-tool path (the tool handler sends
            # its own; avoid duplicating it).
            if not tool_handler:
                self.logger.debug("📤 发送最终 [DONE] 信号")
                yield "data: [DONE]\n\n"

            self.logger.info(f"✅ SSE 流处理完成,共处理 {line_count} 行数据")

        except Exception as e:
            self.logger.error(f"❌ 流式响应处理错误: {e}")
            import traceback
            self.logger.error(traceback.format_exc())
            # Close the stream cleanly on failure: stop chunk plus [DONE].
            yield await self.format_sse_chunk(
                self.create_openai_chunk(chat_id, model, {}, "stop")
            )
            yield "data: [DONE]\n\n"
647
+
648
    async def _handle_non_stream_response(
        self,
        response: httpx.Response,
        chat_id: str,
        model: str
    ) -> Dict[str, Any]:
        """Aggregate an upstream SSE body into one OpenAI response.

        The upstream always replies as SSE (transform_request pins
        stream=True), so this walks aiter_lines(), collects usage stats,
        reasoning ("thinking") text and answer text, and returns a single
        OpenAI-format response dict.
        """
        final_content = ""
        reasoning_content = ""
        usage_info: Dict[str, int] = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

        try:
            async for line in response.aiter_lines():
                if not line:
                    continue

                line = line.strip()

                # Only "data:" rows are SSE payloads; anything else might be
                # an inline error document — detect and surface it.
                if not line.startswith("data:"):
                    try:
                        maybe_err = json.loads(line)
                        if isinstance(maybe_err, dict) and (
                            "error" in maybe_err or "code" in maybe_err or "message" in maybe_err
                        ):
                            # Normalise the error message shape before reporting.
                            msg = (
                                (maybe_err.get("error") or {}).get("message")
                                if isinstance(maybe_err.get("error"), dict)
                                else maybe_err.get("message")
                            ) or "上游返回错误"
                            return self.handle_error(Exception(msg), "API响应")
                    except Exception:
                        pass
                    continue

                data_str = line[5:].strip()
                if not data_str or data_str in ("[DONE]", "DONE", "done"):
                    continue

                # Parse the SSE payload; skip malformed chunks.
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue

                if chunk.get("type") != "chat:completion":
                    continue

                data = chunk.get("data", {})
                phase = data.get("phase")
                delta_content = data.get("delta_content", "")
                edit_content = data.get("edit_content", "")

                # Usage usually arrives in the last chunk; keep the latest.
                if data.get("usage"):
                    try:
                        usage_info = data["usage"]
                    except Exception:
                        pass

                # Thinking phase: accumulate, stripping the
                # <details><summary> wrapper header when present.
                if phase == "thinking":
                    if delta_content:
                        if delta_content.startswith("<details"):
                            cleaned = (
                                delta_content.split("</summary>\n>")[-1].strip()
                                if "</summary>\n>" in delta_content
                                else delta_content
                            )
                        else:
                            cleaned = delta_content
                        reasoning_content += cleaned

                # Answer phase: accumulate answer text.
                elif phase == "answer":
                    # A chunk carrying the thinking-close marker may also
                    # carry the first answer text after it.
                    if edit_content and "</details>\n" in edit_content:
                        content_after = edit_content.split("</details>\n")[-1]
                        if content_after:
                            final_content += content_after
                    elif delta_content:
                        final_content += delta_content

        except Exception as e:
            self.logger.error(f"❌ 非流式响应处理错误: {e}")
            import traceback
            self.logger.error(traceback.format_exc())
            # Surface a unified error response.
            return self.handle_error(e, "非流式聚合")

        # Trim accumulated text.
        final_content = (final_content or "").strip()
        reasoning_content = (reasoning_content or "").strip()

        # Fallback: if no answer was produced, return the reasoning text.
        if not final_content and reasoning_content:
            final_content = reasoning_content

        # Standard response; reasoning is attached only when present.
        return self.create_openai_response_with_reasoning(
            chat_id,
            model,
            final_content,
            reasoning_content,
            usage_info,
        )
app/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Package initializer: eagerly import the utility submodules and declare
# the public API of app.utils.
from app.utils import sse_tool_handler, reload_config, logger

__all__ = ["sse_tool_handler", "reload_config", "logger"]
app/utils/logger.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import sys
5
+ from pathlib import Path
6
+ from loguru import logger
7
+
8
# Global logger instance shared via get_logger(); None until configured.
app_logger = None


def setup_logger(log_dir, log_retention_days=7, log_rotation="1 day", debug_mode=False):
    """
    Configure and return the global loguru logger.

    Installs a colorized stderr sink (DEBUG level and verbose format when
    *debug_mode* is set), and in debug mode additionally a rotating,
    zip-compressed file sink under *log_dir*.

    Parameters:
        log_dir (str): directory for log files (debug mode only)
        log_retention_days (int): how many days of log files to keep
        log_rotation (str): rotation interval for the file sink
        debug_mode (bool): enable DEBUG level and the file sink

    Raises:
        Exception: re-raised after falling back to a minimal ERROR sink
        when configuration fails.
    """
    global app_logger

    try:
        # Drop any previously installed sinks before reconfiguring.
        logger.remove()

        log_level = "DEBUG" if debug_mode else "INFO"

        # Compact console format normally; verbose (with source location)
        # in debug mode.
        console_format = (
            "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>"
            if not debug_mode
            else "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | "
            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | <level>{message}</level>"
        )

        logger.add(sys.stderr, level=log_level, format=console_format, colorize=True)

        # File sink only in debug mode: daily-named, rotated, zipped,
        # written through a queue (enqueue=True) for thread/process safety.
        if debug_mode:
            log_path = Path(log_dir)
            log_path.mkdir(parents=True, exist_ok=True)

            log_file = log_path / "{time:YYYY-MM-DD}.log"
            file_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} | {message}"

            logger.add(
                str(log_file),
                level=log_level,
                format=file_format,
                rotation=log_rotation,
                retention=f"{log_retention_days} days",
                encoding="utf-8",
                compression="zip",
                enqueue=True,
                catch=True,
            )

        app_logger = logger

        return logger

    except Exception as e:
        # Configuration failed: fall back to a bare ERROR sink, report, re-raise.
        logger.remove()
        logger.add(sys.stderr, level="ERROR")
        logger.error(f"日志系统配置失败: {e}")
        raise
+
67
+
68
+ def get_logger():
69
+ """Get the logger instance"""
70
+ global app_logger
71
+ if app_logger is None:
72
+ # 如果没有设置过logger,使用默认配置
73
+ logger.remove() # 移除所有现有处理器
74
+ logger.add(sys.stderr, level="INFO", format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | <level>{message}</level>")
75
+ app_logger = logger
76
+ return app_logger
77
+
78
+
79
+ if __name__ == "__main__":
80
+ """Test the logger"""
81
+ import tempfile
82
+
83
+ with tempfile.TemporaryDirectory() as temp_dir:
84
+ try:
85
+ setup_logger(temp_dir, debug_mode=True)
86
+
87
+ logger.debug("这是一条调试日志")
88
+ logger.info("这是一条信息日志")
89
+ logger.warning("这是一条警告日志")
90
+ logger.error("这是一条错误日志")
91
+ logger.critical("这是一条严重日志")
92
+
93
+ try:
94
+ 1 / 0
95
+ except ZeroDivisionError:
96
+ logger.exception("发生了除零异常")
97
+
98
+ print("✅ 日志测试完成")
99
+
100
+ logger.remove()
101
+
102
+ except Exception as e:
103
+ print(f"❌ 日志测试失败: {e}")
104
+ logger.remove()
105
+ raise
app/utils/reload_config.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Hot-reload configuration module.

Defines the directories and file patterns the Granian server should
ignore while watching for changes, plus the paths it should watch.
"""

# Directories excluded from the file watcher.
RELOAD_IGNORE_DIRS = [
    "logs",                  # log output
    "storage",               # runtime storage
    "__pycache__",           # Python bytecode cache
    ".git",                  # version control
    "node_modules",          # JS dependencies
    "migrations",            # database migrations
    ".pytest_cache",         # pytest cache
    ".venv",                 # virtual environment
    "venv",                  # virtual environment
    "env",                   # environment directory
    ".mypy_cache",           # mypy cache
    ".ruff_cache",           # ruff cache
    "dist",                  # build distribution output
    "build",                 # build output
    ".coverage",             # coverage data
    "htmlcov",               # coverage HTML report
    "tests",                 # test suite
    "z-ai2api-server.pid",   # service PID file
]

# File patterns (regular expressions) excluded from the watcher.
RELOAD_IGNORE_PATTERNS = [
    # Log files
    r".*\.log$",
    r".*\.log\.\d+$",
    # Database files
    r".*\.sqlite3.*",
    r".*\.db$",
    r".*\.db-.*$",
    # Python artifacts
    r".*\.pyc$",
    r".*\.pyo$",
    r".*\.pyd$",
    # Temporary files
    r".*\.tmp$",
    r".*\.temp$",
    r".*\.swp$",
    r".*\.swo$",
    r".*~$",
    # OS metadata files
    r".*\.DS_Store$",
    r".*Thumbs\.db$",
    r".*\.directory$",
    # Editor/IDE files
    r".*\.vscode.*",
    r".*\.idea.*",
    # Test and coverage artifacts
    r".*\.coverage$",
    r".*\.pytest_cache.*",
    # Build artifacts
    r".*\.egg-info.*",
    r".*\.wheel$",
    r".*\.whl$",
    # Version-control files
    r".*\.git.*",
    r".*\.gitignore$",
    r".*\.gitkeep$",
    # Config backups
    r".*\.bak$",
    r".*\.backup$",
    r".*\.orig$",
    # Lock/PID files
    r".*\.lock$",
    r".*\.pid$",
]

# Paths the watcher actually monitors (application code only).
RELOAD_WATCH_PATHS = [
    "app",      # application package
    "main.py",  # entry point
]

# Aggregated hot-reload configuration consumed by the server bootstrap.
RELOAD_CONFIG = {
    "reload_ignore_dirs": RELOAD_IGNORE_DIRS,
    "reload_ignore_patterns": RELOAD_IGNORE_PATTERNS,
    "reload_paths": RELOAD_WATCH_PATHS,
    "reload_tick": 100,  # watch polling interval (milliseconds)
}
+ }
app/utils/sse_tool_handler.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ SSE Tool Handler
6
+
7
+ 处理 Z.AI SSE 流数据并转换为 OpenAI 兼容格式的工具调用处理器。
8
+
9
+ 主要功能:
10
+ - 解析 glm_block 格式的工具调用
11
+ - 从 metadata.arguments 提取完整参数
12
+ - 支持多阶段处理:thinking → tool_call → other → answer
13
+ - 输出符合 OpenAI API 规范的流式响应
14
+ """
15
+
16
+ import json
17
+ import time
18
+ from typing import Dict, Any, Generator
19
+ from enum import Enum
20
+
21
+ from app.utils.logger import get_logger
22
+
23
+ logger = get_logger()
24
+
25
+
26
class SSEPhase(Enum):
    """Phases emitted by the Z.AI SSE stream, in typical order."""
    THINKING = "thinking"    # model reasoning text
    TOOL_CALL = "tool_call"  # glm_block tool-call payloads
    OTHER = "other"          # auxiliary chunks (may carry usage)
    ANSWER = "answer"        # final answer text
    DONE = "done"            # stream termination
33
+
34
+
35
+ class SSEToolHandler:
36
+ """SSE 工具调用处理器"""
37
+
38
    def __init__(self, model: str, stream: bool = True):
        # Model name echoed into the OpenAI-format chunks; stream toggles
        # whether intermediate content is emitted.
        self.model = model
        self.stream = stream

        # Phase tracking.
        self.current_phase = None
        self.has_tool_call = False

        # State of the tool call currently being assembled.
        self.tool_id = ""
        self.tool_name = ""
        self.tool_args = ""         # accumulated JSON-argument fragments
        self.tool_call_usage = {}
        self.content_index = 0      # index of the tool call in the response

        # Content buffering to batch small deltas before flushing.
        self.content_buffer = ""
        self.buffer_size = 0
        self.last_flush_time = time.time()
        self.flush_interval = 0.05  # flush at most every 50 ms
        self.max_buffer_size = 100  # flush once 100 chars are buffered

        logger.debug(f"🔧 初始化工具处理器: model={model}, stream={stream}")
61
+
62
    def process_sse_chunk(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
        """
        Process one Z.AI SSE chunk and yield OpenAI-format SSE lines.

        Dispatches on the chunk's phase (thinking / tool_call / other /
        answer / done).  Errors are logged and swallowed so a single bad
        chunk never aborts the stream.

        Args:
            chunk_data: decoded Z.AI SSE data block

        Yields:
            str: OpenAI-format SSE response lines
        """
        try:
            phase = chunk_data.get("phase")
            edit_content = chunk_data.get("edit_content", "")
            delta_content = chunk_data.get("delta_content", "")
            edit_index = chunk_data.get("edit_index")
            usage = chunk_data.get("usage", {})

            # A chunk without a phase cannot be routed.
            if not phase:
                logger.warning("⚠️ 收到无效的 SSE 块:缺少 phase 字段")
                return

            # Detect phase transitions; flush buffered content first so
            # output ordering is preserved across phases.
            if phase != self.current_phase:
                if hasattr(self, 'content_buffer') and self.content_buffer:
                    yield from self._flush_content_buffer()

                logger.info(f"📈 SSE 阶段变化: {self.current_phase} → {phase}")
                content_preview = edit_content or delta_content
                if content_preview:
                    logger.debug(f" 📝 内容预览: {content_preview[:1000]}{'...' if len(content_preview) > 1000 else ''}")
                if edit_index is not None:
                    logger.debug(f" 📍 edit_index: {edit_index}")
                self.current_phase = phase

            # Route by phase.
            if phase == SSEPhase.THINKING.value:
                yield from self._process_thinking_phase(delta_content)

            elif phase == SSEPhase.TOOL_CALL.value:
                yield from self._process_tool_call_phase(edit_content)

            elif phase == SSEPhase.OTHER.value:
                yield from self._process_other_phase(usage, edit_content)

            elif phase == SSEPhase.ANSWER.value:
                yield from self._process_answer_phase(delta_content)

            elif phase == SSEPhase.DONE.value:
                yield from self._process_done_phase(chunk_data)
            else:
                logger.warning(f"⚠️ 未知的 SSE 阶段: {phase}")

        except Exception as e:
            logger.error(f"❌ 处理 SSE 块时发生错误: {e}")
            logger.debug(f" 📦 错误块数据: {chunk_data}")
            # Do not interrupt the stream; keep processing later chunks.
120
+
121
+ def _process_thinking_phase(self, delta_content: str) -> Generator[str, None, None]:
122
+ """处理思考阶段"""
123
+ if not delta_content:
124
+ return
125
+
126
+ logger.debug(f"🤔 思考内容: +{len(delta_content)} 字符")
127
+
128
+ # 在流模式下输出思考内容
129
+ if self.stream:
130
+ chunk = self._create_content_chunk(delta_content)
131
+ yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
132
+
133
    def _process_tool_call_phase(self, edit_content: str) -> Generator[str, None, None]:
        """Handle a tool_call-phase chunk.

        Content carrying '<glm_block ' markers is parsed as one or more
        tool-call blocks; otherwise the text is treated as a continuation
        of the current tool call's JSON arguments and accumulated.
        """
        if not edit_content:
            return

        logger.debug(f"🔧 进入工具调用阶段,内容长度: {len(edit_content)}")

        # glm_block markers introduce new tool-call metadata blocks.
        if "<glm_block " in edit_content:
            yield from self._handle_glm_blocks(edit_content)
        else:
            # No marker: content continues the current call's arguments.
            if self.has_tool_call:
                # Accumulate only the argument part: everything before the
                # first '", "result"' terminator.
                result_pos = edit_content.find('", "result"')
                if result_pos > 0:
                    param_fragment = edit_content[:result_pos]
                    self.tool_args += param_fragment
                    logger.debug(f"📦 累积参数片段: {param_fragment}")
                else:
                    # No terminator yet: this is a middle fragment, keep it all.
                    self.tool_args += edit_content
                    logger.debug(f"📦 累积参数片段: {edit_content[:100]}...")
156
+
157
+ def _handle_glm_blocks(self, edit_content: str) -> Generator[str, None, None]:
158
+ """处理 glm_block 标记的内容"""
159
+ blocks = edit_content.split('<glm_block ')
160
+ logger.debug(f"📦 分割得到 {len(blocks)} 个块")
161
+
162
+ for index, block in enumerate(blocks):
163
+ if not block.strip():
164
+ continue
165
+
166
+ if index == 0:
167
+ # 第一个块:提取参数片段
168
+ if self.has_tool_call:
169
+ logger.debug(f"📦 从第一个块提取参数片段")
170
+ # 找到 "result" 的位置,提取之前的参数片段
171
+ result_pos = edit_content.find('"result"')
172
+ if result_pos > 0:
173
+ # 往前退3个字符去掉 ", "
174
+ param_fragment = edit_content[:result_pos - 3]
175
+ self.tool_args += param_fragment
176
+ logger.debug(f"📦 累积参数片段: {param_fragment}")
177
+ else:
178
+ # 没有活跃工具调用,跳过第一个块
179
+ continue
180
+ else:
181
+ # 后续块:处理新工具调用
182
+ if "</glm_block>" not in block:
183
+ continue
184
+
185
+ # 如果有活跃的工具调用,先完成它
186
+ if self.has_tool_call:
187
+ # 补全参数并完成工具调用
188
+ self.tool_args += '"' # 补全最后的引号
189
+ yield from self._finish_current_tool()
190
+
191
+ # 处理新工具调用
192
+ yield from self._process_metadata_block(block)
193
+
194
+ def _process_metadata_block(self, block: str) -> Generator[str, None, None]:
195
+ """处理包含工具元数据的块"""
196
+ try:
197
+ # 提取 JSON 内容
198
+ start_pos = block.find('>')
199
+ end_pos = block.rfind('</glm_block>')
200
+
201
+ if start_pos == -1 or end_pos == -1:
202
+ logger.warning(f"❌ 无法找到 JSON 内容边界: {block[:1000]}...")
203
+ return
204
+
205
+ json_content = block[start_pos + 1:end_pos]
206
+ logger.debug(f"📦 提取的 JSON 内容: {json_content[:1000]}...")
207
+
208
+ # 解析工具元数据
209
+ metadata_obj = json.loads(json_content)
210
+
211
+ if "data" in metadata_obj and "metadata" in metadata_obj["data"]:
212
+ metadata = metadata_obj["data"]["metadata"]
213
+
214
+ # 开始新的工具调用
215
+ self.tool_id = metadata.get("id", f"call_{int(time.time() * 1000000)}")
216
+ self.tool_name = metadata.get("name", "unknown")
217
+ self.has_tool_call = True
218
+
219
+ # 只有在这是第二个及以后的工具调用时才递增 index
220
+ # 第一个工具调用应该使用 index 0
221
+
222
+ # 从 metadata.arguments 获取参数起始部分
223
+ if "arguments" in metadata:
224
+ arguments_str = metadata["arguments"]
225
+ # 去掉最后一个字符
226
+ self.tool_args = arguments_str[:-1] if arguments_str.endswith('"') else arguments_str
227
+ logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 初始参数: {self.tool_args}")
228
+ else:
229
+ self.tool_args = "{}"
230
+ logger.debug(f"🎯 新工具调用: {self.tool_name}(id={self.tool_id}), 空参数")
231
+
232
+ except (json.JSONDecodeError, KeyError, AttributeError) as e:
233
+ logger.error(f"❌ 解析工具元数据失败: {e}, 块内容: {block[:1000]}...")
234
+
235
+ # 确保返回生成器(即使为空)
236
+ if False: # 永远不会执行,但确保函数是生成器
237
+ yield
238
+
239
+ def _process_other_phase(self, usage: Dict[str, Any], edit_content: str = "") -> Generator[str, None, None]:
240
+ """处理其他阶段"""
241
+ # 保存使用统计信息
242
+ if usage:
243
+ self.tool_call_usage = usage
244
+ logger.debug(f"📊 保存使用统计: {usage}")
245
+
246
+ # 工具调用完成判断:检测到 "null," 开头的 edit_content
247
+ if self.has_tool_call and edit_content and edit_content.startswith("null,"):
248
+ logger.info(f"🏁 检测到工具调用结束标记")
249
+
250
+ # 完成当前工具调用
251
+ yield from self._finish_current_tool()
252
+
253
+ # 发��流结束标记
254
+ if self.stream:
255
+ yield "data: [DONE]\n\n"
256
+
257
+ # 重置状态
258
+ self._reset_all_state()
259
+
260
+ def _process_answer_phase(self, delta_content: str) -> Generator[str, None, None]:
261
+ """处理回答阶段(优化版本)"""
262
+ if not delta_content:
263
+ return
264
+
265
+ logger.info(f"📝 工具处理器收到答案内容: {delta_content[:50]}...")
266
+
267
+ # 添加到缓冲区
268
+ self.content_buffer += delta_content
269
+ self.buffer_size += len(delta_content)
270
+
271
+ current_time = time.time()
272
+ time_since_last_flush = current_time - self.last_flush_time
273
+
274
+ # 检查是否需要刷新缓冲区
275
+ should_flush = (
276
+ self.buffer_size >= self.max_buffer_size or # 缓冲区满了
277
+ time_since_last_flush >= self.flush_interval or # 时间间隔到了
278
+ '\n' in delta_content or # 包含换行符
279
+ '。' in delta_content or '!' in delta_content or '?' in delta_content # 包含句子结束符
280
+ )
281
+
282
+ if should_flush and self.content_buffer:
283
+ yield from self._flush_content_buffer()
284
+
285
+ def _flush_content_buffer(self) -> Generator[str, None, None]:
286
+ """刷新内容缓冲区"""
287
+ if not self.content_buffer:
288
+ return
289
+
290
+ logger.info(f"💬 工具处理器刷新缓冲区: {self.buffer_size} 字符 - {self.content_buffer[:50]}...")
291
+
292
+ if self.stream:
293
+ chunk = self._create_content_chunk(self.content_buffer)
294
+ output_data = f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
295
+ logger.info(f"➡️ 工具处理器输出: {output_data[:100]}...")
296
+ yield output_data
297
+
298
+ # 清空缓冲区
299
+ self.content_buffer = ""
300
+ self.buffer_size = 0
301
+ self.last_flush_time = time.time()
302
+
303
+ def _process_done_phase(self, chunk_data: Dict[str, Any]) -> Generator[str, None, None]:
304
+ """处理完成阶段"""
305
+ logger.info("🏁 对话完成")
306
+
307
+ # 先刷新任何剩余的缓冲内容
308
+ if self.content_buffer:
309
+ yield from self._flush_content_buffer()
310
+
311
+ # 完成任何未完成的工具调用
312
+ if self.has_tool_call:
313
+ yield from self._finish_current_tool()
314
+
315
+ # 发送流结束标记
316
+ if self.stream:
317
+ # 创建最终的完成块
318
+ final_chunk = {
319
+ "id": f"chatcmpl-{int(time.time())}",
320
+ "object": "chat.completion.chunk",
321
+ "created": int(time.time()),
322
+ "model": self.model,
323
+ "choices": [{
324
+ "index": 0,
325
+ "delta": {},
326
+ "finish_reason": "stop"
327
+ }]
328
+ }
329
+
330
+ # 如果有 usage 信息,添加到最终块中
331
+ if "usage" in chunk_data:
332
+ final_chunk["usage"] = chunk_data["usage"]
333
+
334
+ yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
335
+ yield "data: [DONE]\n\n"
336
+
337
+ # 重置所有状态
338
+ self._reset_all_state()
339
+
340
+ def _finish_current_tool(self) -> Generator[str, None, None]:
341
+ """完成当前工具调用"""
342
+ if not self.has_tool_call:
343
+ return
344
+
345
+ # 修复参数格式
346
+ fixed_args = self._fix_tool_arguments(self.tool_args)
347
+ logger.debug(f"✅ 完成工具调用: {self.tool_name}, 参数: {fixed_args}")
348
+
349
+ # 输出工具调用(开始 + 参数 + 完成)
350
+ if self.stream:
351
+ # 发送工具开始块
352
+ start_chunk = self._create_tool_start_chunk()
353
+ yield f"data: {json.dumps(start_chunk, ensure_ascii=False)}\n\n"
354
+
355
+ # 发送参数块
356
+ args_chunk = self._create_tool_arguments_chunk(fixed_args)
357
+ yield f"data: {json.dumps(args_chunk, ensure_ascii=False)}\n\n"
358
+
359
+ # 发送完成块
360
+ finish_chunk = self._create_tool_finish_chunk()
361
+ yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
362
+
363
+ # 重置工具状态
364
+ self._reset_tool_state()
365
+
366
+ def _fix_tool_arguments(self, raw_args: str) -> str:
367
+ """使用 json-repair 库修复工具参数格式"""
368
+ if not raw_args or raw_args == "{}":
369
+ return "{}"
370
+
371
+ logger.debug(f"🔧 开始修复参数: {raw_args[:1000]}{'...' if len(raw_args) > 1000 else ''}")
372
+
373
+ # 统一的修复流程:预处理 -> json-repair -> 后处理
374
+ try:
375
+ # 1. 预处理:只处理 json-repair 无法处理的问题
376
+ processed_args = self._preprocess_json_string(raw_args.strip())
377
+
378
+ # 2. 使用 json-repair 进行主要修复
379
+ from json_repair import repair_json
380
+ repaired_json = repair_json(processed_args)
381
+ logger.debug(f"🔧 json-repair 修复结果: {repaired_json}")
382
+
383
+ # 3. 解析并后处理
384
+ args_obj = json.loads(repaired_json)
385
+ args_obj = self._post_process_args(args_obj)
386
+
387
+ # 4. 生成最终结果
388
+ fixed_result = json.dumps(args_obj, ensure_ascii=False)
389
+
390
+ return fixed_result
391
+
392
+ except Exception as e:
393
+ logger.error(f"❌ JSON 修复失败: {e}, 原始参数: {raw_args[:1000]}..., 使用空参数")
394
+ return "{}"
395
+
396
+ def _post_process_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
397
+ """统一的后处理方法"""
398
+ # 修复路径中的过度转义
399
+ args_obj = self._fix_path_escaping_in_args(args_obj)
400
+
401
+ # 修复命令中的多余引号
402
+ args_obj = self._fix_command_quotes(args_obj)
403
+
404
+ return args_obj
405
+
406
+ def _preprocess_json_string(self, text: str) -> str:
407
+ """预处理 JSON 字符串,只处理 json-repair 无法处理的问题"""
408
+ import re
409
+
410
+ # 只保留 json-repair 无法处理的预处理步骤
411
+
412
+ # 1. 修复缺少开始括号的情况(json-repair 无法处理)
413
+ if not text.startswith('{') and text.endswith('}'):
414
+ text = '{' + text
415
+ logger.debug(f"🔧 补全开始括号")
416
+
417
+ # 2. 修复末尾多余的反斜杠和引号(json-repair 可能处理不当)
418
+ # 匹配模式:字符串值末尾的 \" 后面跟着 } 或 ,
419
+ # 例如:{"url":"https://www.bilibili.com\"} -> {"url":"https://www.bilibili.com"}
420
+ # 例如:{"url":"https://www.bilibili.com\",} -> {"url":"https://www.bilibili.com",}
421
+ pattern = r'([^\\])\\"([}\s,])'
422
+ if re.search(pattern, text):
423
+ text = re.sub(pattern, r'\1"\2', text)
424
+ logger.debug(f"🔧 修复末尾多余的反斜杠")
425
+
426
+ return text
427
+
428
+ def _fix_path_escaping_in_args(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
429
+ """修复参数对象中路径的过度转义问题"""
430
+ import re
431
+
432
+ # 需要检查的路径字段
433
+ path_fields = ['file_path', 'path', 'directory', 'folder']
434
+
435
+ for field in path_fields:
436
+ if field in args_obj and isinstance(args_obj[field], str):
437
+ path_value = args_obj[field]
438
+
439
+ # 检查是否是Windows路径且包含过度转义
440
+ if path_value.startswith('C:') and '\\\\' in path_value:
441
+ logger.debug(f"🔍 检查路径字段 {field}: {repr(path_value)}")
442
+
443
+ # 分析路径结构:正常路径应该是 C:\Users\...
444
+ # 但过度转义的路径可能是 C:\Users\\Documents(多了一个反斜杠)
445
+ # 我们需要找到不正常的双反斜杠模式并修复
446
+
447
+ # 先检查是否有不正常的双反斜杠(不在路径开头)
448
+ # 正常:C:\Users\Documents
449
+ # 异常:C:\Users\\Documents 或 C:\Users\\\\Documents
450
+
451
+ # 使用更精确的模式:匹配路径分隔符后的额外反斜杠
452
+ # 但要保留正常的路径分隔符
453
+ fixed_path = path_value
454
+
455
+ # 检查是否有连续的多个反斜杠(超过正常的路径分隔符)
456
+ if '\\\\' in path_value:
457
+ # 计算反斜杠的数量,如果超过正常数量就修复
458
+ parts = path_value.split('\\')
459
+ # 重新组装路径,去除空的部分(由多余的反斜杠造成)
460
+ clean_parts = [part for part in parts if part]
461
+ if len(clean_parts) > 1:
462
+ fixed_path = '\\'.join(clean_parts)
463
+
464
+ logger.debug(f"🔍 修复后路径: {repr(fixed_path)}")
465
+
466
+ if fixed_path != path_value:
467
+ args_obj[field] = fixed_path
468
+ logger.debug(f"🔧 修复字段 {field} 的路径转义: {path_value} -> {fixed_path}")
469
+ else:
470
+ logger.debug(f"🔍 路径无需修复: {path_value}")
471
+
472
+ return args_obj
473
+
474
+ def _fix_command_quotes(self, args_obj: Dict[str, Any]) -> Dict[str, Any]:
475
+ """修复命令中的多余引号问题"""
476
+ import re
477
+
478
+ # 检查命令字段
479
+ if 'command' in args_obj and isinstance(args_obj['command'], str):
480
+ command = args_obj['command']
481
+
482
+ # 检查是否以双引号结尾(多余的引号)
483
+ if command.endswith('""'):
484
+ logger.debug(f"🔧 发现命令末尾多余引号: {command}")
485
+ # 移除最后一个多余的引号
486
+ fixed_command = command[:-1]
487
+ args_obj['command'] = fixed_command
488
+ logger.debug(f"🔧 修复命令引号: {command} -> {fixed_command}")
489
+
490
+ # 检查其他可能的引号问题
491
+ # 例如:路径末尾的 \"" 模式
492
+ elif re.search(r'\\""+$', command):
493
+ logger.debug(f"🔧 发现命令末尾引号模式问题: {command}")
494
+ # 修复路径末尾的引号问题
495
+ fixed_command = re.sub(r'\\""+$', '\\"', command)
496
+ args_obj['command'] = fixed_command
497
+ logger.debug(f"🔧 修复命令引号模式: {command} -> {fixed_command}")
498
+
499
+ return args_obj
500
+
501
+ def _create_content_chunk(self, content: str) -> Dict[str, Any]:
502
+ """创建内容块"""
503
+ return {
504
+ "id": f"chatcmpl-{int(time.time())}",
505
+ "object": "chat.completion.chunk",
506
+ "created": int(time.time()),
507
+ "model": self.model,
508
+ "choices": [{
509
+ "index": 0,
510
+ "delta": {
511
+ "role": "assistant",
512
+ "content": content
513
+ },
514
+ "finish_reason": None
515
+ }]
516
+ }
517
+
518
+ def _create_tool_start_chunk(self) -> Dict[str, Any]:
519
+ """创建工具开始块"""
520
+ return {
521
+ "id": f"chatcmpl-{int(time.time())}",
522
+ "object": "chat.completion.chunk",
523
+ "created": int(time.time()),
524
+ "model": self.model,
525
+ "choices": [{
526
+ "index": 0,
527
+ "delta": {
528
+ "role": "assistant",
529
+ "tool_calls": [{
530
+ "index": self.content_index,
531
+ "id": self.tool_id,
532
+ "type": "function",
533
+ "function": {
534
+ "name": self.tool_name,
535
+ "arguments": ""
536
+ }
537
+ }]
538
+ },
539
+ "finish_reason": None
540
+ }]
541
+ }
542
+
543
+ def _create_tool_arguments_chunk(self, arguments: str) -> Dict[str, Any]:
544
+ """创建工具参数块"""
545
+ return {
546
+ "id": f"chatcmpl-{int(time.time())}",
547
+ "object": "chat.completion.chunk",
548
+ "created": int(time.time()),
549
+ "model": self.model,
550
+ "choices": [{
551
+ "index": 0,
552
+ "delta": {
553
+ "tool_calls": [{
554
+ "index": self.content_index,
555
+ "id": self.tool_id,
556
+ "function": {
557
+ "arguments": arguments
558
+ }
559
+ }]
560
+ },
561
+ "finish_reason": None
562
+ }]
563
+ }
564
+
565
+ def _create_tool_finish_chunk(self) -> Dict[str, Any]:
566
+ """创建工具完成块"""
567
+ chunk = {
568
+ "id": f"chatcmpl-{int(time.time())}",
569
+ "object": "chat.completion.chunk",
570
+ "created": int(time.time()),
571
+ "model": self.model,
572
+ "choices": [{
573
+ "index": 0,
574
+ "delta": {
575
+ "tool_calls": []
576
+ },
577
+ "finish_reason": "tool_calls"
578
+ }]
579
+ }
580
+
581
+ # 添加使用统计(如果有)
582
+ if self.tool_call_usage:
583
+ chunk["usage"] = self.tool_call_usage
584
+
585
+ return chunk
586
+
587
+ def _reset_tool_state(self):
588
+ """重置工具状态"""
589
+ self.tool_id = ""
590
+ self.tool_name = ""
591
+ self.tool_args = ""
592
+ self.has_tool_call = False
593
+ # content_index 在单次对话中应该保持不变,只有在新的工具调用开始时才递增
594
+
595
+ def _reset_all_state(self):
596
+ """重置所有状态"""
597
+ # 先刷新任何剩余的缓冲内容
598
+ if hasattr(self, 'content_buffer') and self.content_buffer:
599
+ list(self._flush_content_buffer()) # 消费生成器
600
+
601
+ self._reset_tool_state()
602
+ self.current_phase = None
603
+ self.tool_call_usage = {}
604
+
605
+ # 重置缓冲区
606
+ self.content_buffer = ""
607
+ self.buffer_size = 0
608
+ self.last_flush_time = time.time()
609
+
610
+ # content_index 重置为 0,为下一轮对话做准备
611
+ self.content_index = 0
612
+ logger.debug("🔄 重置所有处理器状态")
app/utils/token_pool.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Token池管理器
6
+ 实现AUTH_TOKEN的轮询机制,提供负载均衡和容错功能
7
+ """
8
+
9
+ import asyncio
10
+ import time
11
+ from typing import Dict, List, Optional, Tuple
12
+ from dataclasses import dataclass, field
13
+ from threading import Lock
14
+ import httpx
15
+
16
+ from app.utils.logger import logger
17
+
18
+
19
@dataclass
class TokenStatus:
    """Bookkeeping for a single Z.ai auth token in the pool."""

    token: str
    is_available: bool = True
    failure_count: int = 0
    last_failure_time: float = 0.0
    last_success_time: float = 0.0
    total_requests: int = 0
    successful_requests: int = 0
    token_type: str = "unknown"  # one of "user", "guest", "unknown"

    @property
    def success_rate(self) -> float:
        """Fraction of successful requests; a token with no traffic counts as 1.0."""
        total = self.total_requests
        return self.successful_requests / total if total else 1.0

    @property
    def is_healthy(self) -> bool:
        """Whether this token should be considered usable.

        Healthy means: an authenticated user token ("user"), currently
        available, and either brand new (<= 3 requests) with zero recorded
        failures, or holding a success rate of at least 50%. Guest and
        unknown-type tokens are never healthy (guests do not belong in
        AUTH_TOKENS at all).
        """
        # Covers both "guest" and "unknown" in one check.
        if self.token_type != "user" or not self.is_available:
            return False
        # Fresh tokens get leniency: healthy as long as nothing failed yet.
        if self.total_requests <= 3:
            return self.failure_count == 0
        return self.success_rate >= 0.5
69
+
70
+
71
class TokenPool:
    """Round-robin pool of Z.ai auth tokens with failure tracking and recovery.

    All state mutation happens under ``self._lock``, so one instance can be
    shared across request-handling threads. Only tokens classified as
    authenticated user tokens (``token_type == "user"``) are ever handed out.
    """

    def __init__(self, tokens: List[str], failure_threshold: int = 3, recovery_timeout: int = 1800):
        """
        Initialize the token pool.

        Args:
            tokens: candidate token strings; empty and duplicate entries are dropped.
            failure_threshold: consecutive failures after which a token is disabled.
            recovery_timeout: seconds a disabled token waits before being retried.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self._lock = Lock()
        self._current_index = 0  # round-robin cursor over the available-token list

        # Per-token bookkeeping, keyed by the raw token string.
        self.token_statuses: Dict[str, TokenStatus] = {}
        original_count = len(tokens)
        unique_tokens = []

        # Deduplicate while preserving order; skip empty tokens.
        for token in tokens:
            if token and token not in self.token_statuses:
                # Configured tokens are presumed to be authenticated user tokens;
                # a later health check may reclassify them.
                self.token_statuses[token] = TokenStatus(token=token, token_type="user")
                unique_tokens.append(token)

        duplicate_count = original_count - len(unique_tokens)
        if duplicate_count > 0:
            logger.warning(f"⚠️ 检测到 {duplicate_count} 个重复token,已自动去重")

        if not self.token_statuses:
            logger.warning("⚠️ Token池为空,将依赖匿名模式")

    def get_next_token(self) -> Optional[str]:
        """
        Return the next available token (round-robin).

        When no token is available, timed-out failed tokens are given one
        recovery attempt before giving up.

        Returns:
            A usable token string, or None when the pool has none.
        """
        with self._lock:
            if not self.token_statuses:
                return None

            available_tokens = self._get_available_tokens()
            if not available_tokens:
                # Give failed tokens whose recovery timeout elapsed another chance.
                self._try_recover_failed_tokens()
                available_tokens = self._get_available_tokens()

            if not available_tokens:
                logger.warning("⚠️ 没有可用的token")
                return None

            # Round-robin selection over the currently-available subset.
            token = available_tokens[self._current_index % len(available_tokens)]
            self._current_index = (self._current_index + 1) % len(available_tokens)

            return token

    def _get_available_tokens(self) -> List[str]:
        """
        Return the currently usable authenticated-user tokens.

        A token qualifies when ``is_available`` is True AND ``token_type`` is
        "user" — so the round-robin never selects guest/unknown tokens.
        Caller must hold ``self._lock``.
        """
        available_user_tokens = [
            status.token for status in self.token_statuses.values()
            if status.is_available and status.token_type == "user"
        ]

        # Warn (once per call path) when only guest tokens remain.
        if not available_user_tokens and self.token_statuses:
            guest_tokens = [
                status.token for status in self.token_statuses.values()
                if status.token_type == "guest"
            ]
            if guest_tokens:
                logger.warning(f"⚠️ 检测到 {len(guest_tokens)} 个匿名用户token,轮询机制将跳过这些token")

        return available_user_tokens

    def _try_recover_failed_tokens(self):
        """Re-enable disabled tokens whose recovery timeout has elapsed.

        Caller must hold ``self._lock``.
        """
        current_time = time.time()
        recovered_count = 0

        for status in self.token_statuses.values():
            if (not status.is_available and
                    current_time - status.last_failure_time > self.recovery_timeout):
                status.is_available = True
                status.failure_count = 0
                recovered_count += 1
                logger.info(f"🔄 恢复失败token: {status.token[:20]}...")

        if recovered_count > 0:
            logger.info(f"✅ 恢复了 {recovered_count} 个失败的token")

    def mark_token_success(self, token: str):
        """Record a successful request for *token* and re-enable it if disabled."""
        with self._lock:
            if token in self.token_statuses:
                status = self.token_statuses[token]
                status.total_requests += 1
                status.successful_requests += 1
                status.last_success_time = time.time()
                status.failure_count = 0  # a success resets the failure streak

                if not status.is_available:
                    status.is_available = True
                    logger.info(f"✅ Token恢复可用: {token[:20]}...")

    def mark_token_failure(self, token: str, error: Optional[Exception] = None):
        """Record a failed request for *token*; disable it once failures reach the threshold.

        Args:
            token: the token that failed.
            error: the triggering exception (currently unused beyond the signature).
        """
        with self._lock:
            if token in self.token_statuses:
                status = self.token_statuses[token]
                status.total_requests += 1
                status.failure_count += 1
                status.last_failure_time = time.time()

                if status.failure_count >= self.failure_threshold:
                    status.is_available = False
                    logger.warning(f"🚫 Token已禁用: {token[:20]}... (失败 {status.failure_count} 次)")

    def get_pool_status(self) -> Dict:
        """Return a snapshot of pool statistics plus per-token details (tokens masked)."""
        with self._lock:
            available_count = len(self._get_available_tokens())
            total_count = len(self.token_statuses)

            # Healthy = user-type, available, and passing the success-rate rule.
            healthy_count = sum(1 for status in self.token_statuses.values() if status.is_healthy)

            status_info = {
                "total_tokens": total_count,
                "available_tokens": available_count,
                "unavailable_tokens": total_count - available_count,
                "healthy_tokens": healthy_count,
                "unhealthy_tokens": total_count - healthy_count,
                "current_index": self._current_index,
                "tokens": []
            }

            for token, status in self.token_statuses.items():
                status_info["tokens"].append({
                    # Mask the middle of the token so logs/status pages don't leak it.
                    "token": f"{token[:10]}...{token[-10:]}",
                    "token_type": status.token_type,
                    "is_available": status.is_available,
                    "failure_count": status.failure_count,
                    "success_count": status.successful_requests,
                    "success_rate": f"{status.success_rate:.2%}",
                    "total_requests": status.total_requests,
                    "is_healthy": status.is_healthy,
                    "last_failure_time": status.last_failure_time,
                    "last_success_time": status.last_success_time
                })

            return status_info

    def update_tokens(self, new_tokens: List[str]):
        """Replace the token list, keeping accumulated stats for tokens that remain."""
        with self._lock:
            # Preserve existing per-token state across the swap.
            old_statuses = self.token_statuses.copy()
            self.token_statuses.clear()

            original_count = len(new_tokens)
            unique_tokens = []

            # Deduplicate; carry over known tokens, create fresh state for new ones.
            for token in new_tokens:
                if token and token not in self.token_statuses:
                    if token in old_statuses:
                        self.token_statuses[token] = old_statuses[token]
                    else:
                        # New configured tokens are presumed authenticated user tokens.
                        self.token_statuses[token] = TokenStatus(token=token, token_type="user")
                    unique_tokens.append(token)

            duplicate_count = original_count - len(unique_tokens)
            if duplicate_count > 0:
                logger.warning(f"⚠️ 更新时检测到 {duplicate_count} 个重复token,已自动去重")

            # Restart the round-robin from the beginning.
            self._current_index = 0

            logger.info(f"🔄 更新Token池,共 {len(self.token_statuses)} 个token")

    async def health_check_token(self, token: str, auth_url: str = "https://chat.z.ai/api/v1/auths/") -> bool:
        """
        Asynchronously health-check one token against the Z.AI auth endpoint.

        Updates the token's recorded type and success/failure stats based on
        the response.

        Args:
            token: the token to check.
            auth_url: auth endpoint URL.

        Returns:
            True when the token is a healthy authenticated user token.
        """
        try:
            # Browser-like headers so the request matches normal client traffic.
            headers = {
                "Accept": "*/*",
                "Accept-Language": "zh-CN,zh;q=0.9",
                "Authorization": f"Bearer {token}",
                "Connection": "keep-alive",
                "Content-Type": "application/json",
                "DNT": "1",
                "Referer": "https://chat.z.ai/",
                "Sec-Fetch-Dest": "empty",
                "Sec-Fetch-Mode": "cors",
                "Sec-Fetch-Site": "same-origin",
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36",
                "sec-ch-ua": '"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"',
                "sec-ch-ua-mobile": "?0",
                "sec-ch-ua-platform": "Windows"
            }

            async with httpx.AsyncClient(timeout=15.0) as client:
                response = await client.get(auth_url, headers=headers)

            # Classify the token from the response body.
            token_type, is_healthy = self._validate_token_response(response)

            # Record the (possibly downgraded) token type.
            if token in self.token_statuses:
                self.token_statuses[token].token_type = token_type

            if is_healthy:
                self.mark_token_success(token)
            else:
                # Keep the error reason coarse: guest token / HTTP error / auth failure.
                if token_type == "guest":
                    error_msg = "匿名用户token"
                elif response.status_code != 200:
                    error_msg = f"HTTP {response.status_code}"
                else:
                    error_msg = "认证失败"

                self.mark_token_failure(token, Exception(error_msg))

            return is_healthy

        # NOTE(review): Exception already subsumes the two httpx types; the
        # explicit names only serve as documentation of the expected failures.
        except (httpx.TimeoutException, httpx.ConnectError, Exception) as e:
            self.mark_token_failure(token, e)
            return False

    def _validate_token_response(self, response: httpx.Response) -> Tuple[str, bool]:
        """
        Classify a Z.AI auth response by its ``role`` field.

        Rules:
        - role "user"  -> authenticated user token (valid for AUTH_TOKENS)
        - role "guest" -> anonymous token (must not be in AUTH_TOKENS)
        - no/other role, non-200 status, or error payload -> invalid

        Args:
            response: HTTP response from the auth endpoint.

        Returns:
            ``(token_type, is_healthy)`` where token_type is one of
            "user", "guest" or "unknown".
        """
        # Reject anything that is not an HTTP 200.
        if response.status_code != 200:
            return ("unknown", False)

        try:
            response_data = response.json()

            if not isinstance(response_data, dict):
                return ("unknown", False)

            # Explicit error payloads are treated as invalid tokens.
            if "error" in response_data:
                return ("unknown", False)

            if "message" in response_data and "error" in response_data.get("message", "").lower():
                return ("unknown", False)

            # Core check: the account role reported for this token.
            role = response_data.get("role")

            if role == "user":
                return ("user", True)
            elif role == "guest":

                # Warn only once per pool instance about guest tokens.
                if not hasattr(self, '_guest_token_warned'):
                    logger.warning("⚠️ 检测到匿名用户token,建议仅在AUTH_TOKENS中配置认证用户token")
                    self._guest_token_warned = True
                return ("guest", False)
            else:
                return ("unknown", False)

        # NOTE(review): Exception subsumes ValueError here as well.
        except (ValueError, Exception):
            return ("unknown", False)

    async def health_check_all(self, auth_url: str = "https://chat.z.ai/api/v1/auths/"):
        """Concurrently health-check every token in the pool and log a summary."""
        if not self.token_statuses:
            logger.warning("⚠️ Token池为空,跳过健康检查")
            return

        total_tokens = len(self.token_statuses)
        logger.info(f"🔍 开始Token池健康检查... (共 {total_tokens} 个token)")

        # One concurrent check task per token.
        tasks = []
        token_list = list(self.token_statuses.keys())

        for token in token_list:
            task = self.health_check_token(token, auth_url)
            tasks.append(task)

        # Gather all results, keeping exceptions instead of raising.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Tally healthy / failed / exceptional outcomes.
        healthy_count = 0
        failed_count = 0
        exception_count = 0

        for i, result in enumerate(results):
            if result is True:
                healthy_count += 1
            elif result is False:
                failed_count += 1
            else:
                # gather() returned an exception object for this token.
                exception_count += 1
                token = token_list[i]
                logger.error(f"💥 Token {token[:20]}... 健康检查异常: {result}")

        health_rate = (healthy_count / total_tokens) * 100 if total_tokens > 0 else 0

        if healthy_count == 0 and total_tokens > 0:
            logger.warning(f"⚠️ 健康检查完成: 0/{total_tokens} 个token健康 - 请检查token配置")
        elif failed_count > 0:
            logger.warning(f"⚠️ 健康检查完成: {healthy_count}/{total_tokens} 个token健康 ({health_rate:.1f}%)")
        else:
            logger.info(f"✅ 健康检查完成: {healthy_count}/{total_tokens} 个token健康")

        if exception_count > 0:
            logger.error(f"💥 {exception_count} 个token检查异常")
428
+
429
+
430
# Process-wide TokenPool singleton; guarded by _pool_lock for safe (re)creation.
_token_pool: Optional[TokenPool] = None
_pool_lock = Lock()
433
+
434
+
435
def get_token_pool() -> Optional[TokenPool]:
    """Return the process-wide TokenPool singleton, or None before initialization."""
    return _token_pool
438
+
439
+
440
def initialize_token_pool(tokens: List[str], failure_threshold: int = 3, recovery_timeout: int = 1800) -> TokenPool:
    """Create a new global TokenPool (replacing any existing one) and return it.

    Thread-safe via the module-level ``_pool_lock``.
    """
    global _token_pool
    with _pool_lock:
        _token_pool = TokenPool(tokens, failure_threshold, recovery_timeout)
        return _token_pool
446
+
447
+
448
def update_token_pool(tokens: List[str]):
    """Swap the global pool's tokens, creating the pool on first use."""
    global _token_pool
    with _pool_lock:
        if _token_pool is None:
            _token_pool = TokenPool(tokens)
        else:
            _token_pool.update_tokens(tokens)
app/utils/user_agent.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 用户代理工具模块
6
+ 提供动态随机用户代理生成功能
7
+ """
8
+
9
+ import random
10
+ from typing import Dict, Optional
11
+ from fake_useragent import UserAgent
12
+
13
# Lazily-created process-wide UserAgent instance (singleton).
_user_agent_instance: Optional[UserAgent] = None
15
+
16
+
17
def get_user_agent_instance() -> UserAgent:
    """Return the shared UserAgent, creating it on first use (lazy singleton)."""
    global _user_agent_instance
    instance = _user_agent_instance
    if instance is None:
        instance = UserAgent()
        _user_agent_instance = instance
    return instance
23
+
24
+
25
def get_random_user_agent(browser_type: Optional[str] = None) -> str:
    """Return a random User-Agent string, optionally pinned to one browser family.

    Args:
        browser_type: 'chrome', 'firefox', 'safari' or 'edge'. When None a
            family is drawn at random, weighted toward Chrome (3/7) and
            Edge (2/7) over Firefox and Safari (1/7 each).

    Returns:
        str: a User-Agent string for the chosen family (any unrecognized
        value falls back to a fully random UA).
    """
    ua = get_user_agent_instance()

    if browser_type is None:
        # Weighted draw: chrome x3, edge x2, firefox x1, safari x1.
        browser_type = random.choice(
            ["chrome", "chrome", "chrome", "edge", "edge", "firefox", "safari"]
        )

    if browser_type in ("chrome", "edge", "firefox", "safari"):
        return getattr(ua, browser_type)
    return ua.random
56
+
57
+
58
+ # 通用 UserAgent headers 生成函数
59
def get_dynamic_headers(
    referer: Optional[str] = None,
    origin: Optional[str] = None,
    browser_type: Optional[str] = None,
    additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
    """Build browser-like HTTP headers around a random User-Agent.

    Args:
        referer: value for the ``Referer`` header, if any.
        origin: value for the ``Origin`` header, if any.
        browser_type: forwarded to :func:`get_random_user_agent`.
        additional_headers: extra headers merged in last (may override the
            defaults).

    Returns:
        Dict[str, str]: request headers; Chromium-style ``sec-ch-ua`` client
        hints are included when the generated UA is Chrome or Edge.
    """
    user_agent = get_random_user_agent(browser_type)

    headers = {
        "User-Agent": user_agent,
        "Accept": "application/json, text/event-stream",
        "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
        "Accept-Encoding": "gzip, deflate, br",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Pragma": "no-cache",
    }

    if referer:
        headers["Referer"] = referer
    if origin:
        headers["Origin"] = origin

    # Chromium-based UAs additionally carry sec-ch-ua client hints.
    if "Chrome/" in user_agent or "Edg/" in user_agent:
        # Fallback versions used when the UA string cannot be parsed.
        chrome_version = "139"
        edge_version = "139"

        # Narrowed from bare `except:` — only slicing/parsing errors are expected.
        try:
            if "Chrome/" in user_agent:
                chrome_version = user_agent.split("Chrome/")[1].split(".")[0]
        except (IndexError, ValueError):
            pass

        try:
            if "Edg/" in user_agent:
                edge_version = user_agent.split("Edg/")[1].split(".")[0]
                sec_ch_ua = f'"Microsoft Edge";v="{edge_version}", "Chromium";v="{chrome_version}", "Not_A Brand";v="24"'
            else:
                sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'
        except (IndexError, ValueError):
            sec_ch_ua = f'"Not_A Brand";v="8", "Chromium";v="{chrome_version}", "Google Chrome";v="{chrome_version}"'

        headers.update({
            "sec-ch-ua": sec_ch_ua,
            "sec-ch-ua-mobile": "?0",
            "sec-ch-ua-platform": '"Windows"',
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
        })

    # Caller-supplied headers win over the defaults.
    if additional_headers:
        headers.update(additional_headers)

    return headers
132
+
133
+
longcat_tokens.txt.example ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LongCat Passport Token 配置文件(可选)
2
+ #
3
+ # 说明:
4
+ # 1. 此文件是可选的,如果不需要多个token可以删除此文件
5
+ # 2. 支持两种格式:每行一个token 或 逗号分隔的token
6
+ # 3. 只包含有效的 passport_token_key 值
7
+ # 4. 系统会自动去重和验证token有效性
8
+ # 5. 自动跳过空格、换行符和空token
9
+ # 6. 当设置了 LONGCAT_PASSPORT_TOKEN 环境变量时,优先使用环境变量中的token
10
+ #
11
+ # 格式1:纯换行分隔
12
+ # token1
13
+ # token2
14
+ # token3
15
+
16
+ # 格式2:纯逗号分隔
17
+ # token1,token2,token3
18
+
19
+ # 格式3:混合格式
20
+ # token1,token2
21
+ # token3
22
+ # token4,token5,token6
23
+ # token7
24
+
25
+ # 请在下方添加您的 LongCat passport token(使用任一格式):
26
+
main.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import psutil
7
+ from contextlib import asynccontextmanager
8
+ from fastapi import FastAPI, Response
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+
11
+ from app.core.config import settings
12
+ from app.core import openai
13
+ from app.utils.reload_config import RELOAD_CONFIG
14
+ from app.utils.logger import setup_logger
15
+ from app.utils.token_pool import initialize_token_pool
16
+ from app.providers import initialize_providers
17
+
18
+ from granian import Granian
19
+
20
+
21
+ # Setup logger
22
+ logger = setup_logger(log_dir="logs", debug_mode=settings.DEBUG_LOGGING)
23
+
24
+
25
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    On startup: register the upstream providers and, when auth tokens are
    configured, build the shared token pool. On shutdown: log a notice.
    """
    # Providers must be registered before any request is served.
    initialize_providers()

    # Only build the token pool when at least one token is configured.
    configured_tokens = settings.auth_token_list
    if configured_tokens:
        initialize_token_pool(
            tokens=configured_tokens,
            failure_threshold=settings.TOKEN_FAILURE_THRESHOLD,
            recovery_timeout=settings.TOKEN_RECOVERY_TIMEOUT,
        )

    yield

    logger.info("🔄 应用正在关闭...")
42
+
43
+
44
# Create the FastAPI application, wiring in the startup/shutdown lifespan hook.
app = FastAPI(lifespan=lifespan)

# Permissive CORS so browser-based OpenAI clients can call this proxy.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers per the CORS spec — confirm whether credentials are
# actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)

# Mount the OpenAI-compatible API routes.
app.include_router(openai.router)
58
+
59
+
60
@app.options("/")
async def handle_options():
    """Answer CORS preflight requests on the root path with an empty 200."""
    return Response(status_code=200)
64
+
65
+
66
@app.get("/")
async def root():
    """Landing endpoint: identifies this service to callers."""
    return {"message": "OpenAI Compatible API Server"}
70
+
71
+
72
def run_server():
    """Start the Granian ASGI server hosting this application.

    Logs the effective configuration, then blocks in ``serve()``.
    A clean Ctrl-C is logged; any other startup failure exits with status 1.
    """
    service_name = settings.SERVICE_NAME

    logger.info(f"🚀 启动 {service_name} 服务...")
    logger.info(f"📡 监听地址: 0.0.0.0:{settings.LISTEN_PORT}")
    logger.info(f"🔧 调试模式: {'开启' if settings.DEBUG_LOGGING else '关闭'}")
    logger.info(f"🔐 匿名模式: {'开启' if settings.ANONYMOUS_MODE else '关闭'}")

    try:
        server = Granian(
            "main:app",
            interface="asgi",
            address="0.0.0.0",
            port=settings.LISTEN_PORT,
            reload=False,  # keep hot reload disabled in production
            process_name=service_name,  # used for process-uniqueness checks
            **RELOAD_CONFIG,  # reload-related options from app.utils.reload_config
        )
        server.serve()
    except KeyboardInterrupt:
        logger.info("🛑 收到中断信号,正在关闭服务...")
    except Exception as e:
        logger.error(f"❌ 服务启动失败: {e}")
        sys.exit(1)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ run_server()
pyproject.toml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "z-ai2api-python"
7
+ version = "0.1.0"
8
+ description = "一个为 Z.ai 提供 OpenAI 兼容接口的 Python 代理服务"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9,<3.13"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "Contributors" }]
13
+ classifiers = [
14
+ "Development Status :: 4 - Beta",
15
+ "Intended Audience :: Developers",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ "Programming Language :: Python :: 3",
19
+ "Programming Language :: Python :: 3.9",
20
+ "Programming Language :: Python :: 3.10",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Topic :: Internet :: WWW/HTTP :: HTTP Servers",
24
+ "Topic :: Software Development :: Libraries :: Python Modules",
25
+ ]
26
+ dependencies = [
27
+ "fastapi==0.116.1",
28
+ "granian[reload,pname]==2.5.2",
29
+ "httpx==0.28.1",
30
+ "pydantic==2.11.7",
31
+ "pydantic-settings==2.10.1",
32
+ "pydantic-core==2.33.2",
33
+ "typing-inspection==0.4.1",
34
+ "fake-useragent==2.2.0",
35
+ "loguru==0.7.3",
36
+ "psutil>=7.0.0",
37
+ "json-repair==0.44.1"
38
+ ]
39
+
40
+ [project.scripts]
41
+ z-ai2api = "main:app"
42
+
43
+ [tool.hatch.build.targets.wheel]
44
+ packages = ["."]
45
+
46
+ [tool.uv]
47
+ dev-dependencies = [
48
+ "pytest>=7.0.0",
49
+ "pytest-asyncio>=0.21.0",
50
+ "requests>=2.30.0",
51
+ "ruff>=0.1.0",
52
+ ]
53
+
54
+ [tool.ruff]
55
+ line-length = 88
56
+ target-version = "py39"
57
+ select = ["E", "F", "I", "B"]
58
+ ignore = []
59
+
60
+ [tool.ruff.isort]
61
+ known-first-party = []
62
+
63
+ [tool.pytest.ini_options]
64
+ asyncio_mode = "auto"
65
+ testpaths = ["tests"]
66
+ python_files = ["test_*.py"]
67
+ python_functions = ["test_*"]
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.116.1
2
+ granian[reload,pname]==2.5.2
3
+ httpx==0.28.1
4
+ pydantic==2.11.7
5
+ pydantic-settings==2.10.1
6
+ pydantic-core==2.33.2
7
+ typing-inspection==0.4.1
8
+ fake-useragent==2.2.0
9
+ loguru==0.7.3
10
+ psutil>=7.0.0
11
+ json-repair==0.44.1
tests/test_comprehensive_fix.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
全面测试 ZAI Provider 修复效果
验证流式输出、工具调用、思考模式、重试机制等功能
"""

import asyncio
import json
import sys
import os

# Make the project root importable when this file is run directly.
# This file lives in tests/, so the root is the *parent* directory —
# inserting only dirname(__file__) (as before) pointed at tests/ itself
# and broke the `app.*` imports below. This now matches the other tests.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from app.providers.zai_provider import ZAIProvider
from app.models.schemas import OpenAIRequest, Message
from app.core.config import settings
20
+
21
+
22
async def test_basic_stream():
    """测试基本流式输出

    Sends a simple streaming chat request and verifies that the provider
    returns an async generator yielding OpenAI-style SSE chunks with content.
    Returns True when at least one content delta was collected.
    """
    print("🧪 测试基本流式输出...")

    provider = ZAIProvider()

    request = OpenAIRequest(
        model=settings.PRIMARY_MODEL,
        messages=[
            Message(role="user", content="你好,请简单介绍一下自己")
        ],
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if hasattr(response, '__aiter__'):
            print("✅ 返回了异步生成器")
            chunk_count = 0
            content_chunks = []

            async for chunk in response:
                chunk_count += 1
                # Parse only data frames, skipping the terminal [DONE] marker.
                if chunk.startswith("data: ") and not chunk.strip().endswith("[DONE]"):
                    try:
                        chunk_data = json.loads(chunk[6:].strip())
                        if "choices" in chunk_data and chunk_data["choices"]:
                            choice = chunk_data["choices"][0]
                            if "delta" in choice and "content" in choice["delta"]:
                                content = choice["delta"]["content"]
                                if content:
                                    content_chunks.append(content)
                    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
                        # Ignore malformed/partial SSE payloads only; the
                        # previous bare `except:` also swallowed e.g.
                        # KeyboardInterrupt, which hid real failures.
                        pass

                if chunk_count >= 10:  # 限制测试长度
                    break

            full_content = "".join(content_chunks)
            print(f"✅ 成功处理了 {chunk_count} 个数据块")
            print(f"📝 内容预览: {full_content[:100]}...")
            return len(content_chunks) > 0
        else:
            print("❌ 返回的不是流式响应")
            return False

    except Exception as e:
        print(f"❌ 基本流式测试失败: {e}")
        return False
72
+
73
+
74
async def test_thinking_mode():
    """测试思考模式

    Streams a request against the thinking model and reports whether
    thinking deltas and regular answer content both appear in the stream.
    """
    print("\n🧪 测试思考模式...")

    provider = ZAIProvider()

    request = OpenAIRequest(
        model=settings.THINKING_MODEL,
        messages=[
            Message(role="user", content="请解释一下量子计算的基本原理")
        ],
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if not hasattr(response, '__aiter__'):
            print("❌ 返回的不是流式响应")
            return False

        print("✅ 返回了异步生成器")
        processed = 0
        thinking_seen = False
        answer_seen = False

        async for frame in response:
            processed += 1

            # A thinking delta is signalled by the literal key in the frame.
            if 'thinking' in frame:
                thinking_seen = True
                print("✅ 检测到思考内容")

            # Plain answer content: has "content" but no "thinking" field.
            if '"content"' in frame and '"thinking"' not in frame:
                answer_seen = True
                print("✅ 检测到答案内容")

            if processed >= 20:  # cap the stream length for this test
                break

        print(f"✅ 成功处理了 {processed} 个数据块")
        print(f"🤔 思考模式: {'正常' if thinking_seen else '未检测到'}")
        print(f"💬 答案内容: {'正常' if answer_seen else '未检测到'}")
        return True

    except Exception as e:
        print(f"❌ 思考模式测试失败: {e}")
        return False
124
+
125
+
126
async def test_tool_support():
    """测试工具调用支持

    Skips (returning True) when tool support is disabled in settings;
    otherwise streams a request carrying one function tool and reports
    whether a tool_calls frame shows up.
    """
    print("\n🧪 测试工具调用支持...")

    if not settings.TOOL_SUPPORT:
        print("⚠️ 工具支持已禁用,跳过测试")
        return True

    provider = ZAIProvider()

    # Minimal single-function tool definition (OpenAI tools schema).
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "获取指定城市的天气信息",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "城市名称"
                        }
                    },
                    "required": ["city"]
                }
            }
        }
    ]

    request = OpenAIRequest(
        model=settings.PRIMARY_MODEL,
        messages=[
            Message(role="user", content="请帮我查询北京的天气")
        ],
        tools=tools,
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if hasattr(response, '__aiter__'):
            print("✅ 返回了异步生成器")
            chunk_count = 0
            has_tool_call = False

            async for chunk in response:
                chunk_count += 1

                # A tool-call frame carries the literal "tool_calls" key.
                if 'tool_calls' in chunk:
                    has_tool_call = True
                    print("✅ 检测到工具调用")

                if chunk_count >= 30:  # cap the stream length for this test
                    break

            print(f"✅ 成功处理了 {chunk_count} 个数据块")
            print(f"🔧 工具调用: {'正常' if has_tool_call else '未检测到'}")
            return True
        else:
            print("❌ 返回的不是流式响应")
            return False

    except Exception as e:
        print(f"❌ 工具调用测试失败: {e}")
        return False
195
+
196
+
197
async def test_error_handling():
    """测试错误处理

    Deliberately requests an invalid model; any of the three outcomes
    (error frame in the stream, non-stream error response, or a raised
    exception) counts as correct handling, so this always returns True.
    """
    print("\n🧪 测试错误处理...")

    provider = ZAIProvider()

    # An invalid model name should trigger the provider's error path.
    request = OpenAIRequest(
        model="invalid-model",
        messages=[
            Message(role="user", content="测试错误处理")
        ],
        stream=True
    )

    try:
        response = await provider.chat_completion(request)

        if not hasattr(response, '__aiter__'):
            print("✅ 返回了错误响应(非流式)")
            return True

        frames_seen = 0
        error_detected = False

        async for frame in response:
            frames_seen += 1

            # Error frames carry the literal "error" key.
            if 'error' in frame:
                error_detected = True
                print("✅ 检测到错误处理")

            if frames_seen >= 5:  # cap the stream length for this test
                break

        print(f"✅ 错误处理测试完成,处理了 {frames_seen} 个数据块")
        return True

    except Exception as e:
        print(f"✅ 正确捕获了异常: {type(e).__name__}")
        return True
239
+
240
+
241
async def main():
    """主测试函数

    Prints the effective configuration, runs every test case in order,
    and summarizes how many passed.
    """
    print("🚀 开始全面测试 ZAI Provider 修复效果\n")

    # Show the settings that influence the test runs.
    print("📋 当前配置:")
    print(f" - 匿名模式: {settings.ANONYMOUS_MODE}")
    print(f" - 工具支持: {settings.TOOL_SUPPORT}")
    print(f" - 最大重试: {settings.MAX_RETRIES}")
    print(f" - 重试延迟: {settings.RETRY_DELAY}s")
    print()

    suite = [
        ("基本流式输出", test_basic_stream),
        ("思考模式", test_thinking_mode),
        ("工具调用支持", test_tool_support),
        ("错误处理", test_error_handling),
    ]

    ok_count = 0
    total = len(suite)

    for label, case in suite:
        try:
            print(f"{'='*50}")
            if await case():
                ok_count += 1
                print(f"✅ {label} 测试通过")
            else:
                print(f"❌ {label} 测试失败")
        except Exception as e:
            print(f"❌ {label} 测试异常: {e}")

        print()

    print(f"{'='*50}")
    print(f"📊 测试结果: {ok_count}/{total} 通过")

    if ok_count == total:
        print("🎉 所有测试都通过了!ZAI Provider 修复成功")
    elif ok_count >= total * 0.75:
        print("✅ 大部分测试通过,ZAI Provider 基本修复成功")
    else:
        print("⚠️ 多个测试失败,需要进一步检查")
286
+
287
+
288
+ if __name__ == "__main__":
289
+ asyncio.run(main())
tests/test_done_phase.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 测试 done 阶段处理
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+
10
+ from app.utils.sse_tool_handler import SSEToolHandler
11
+ import json
12
+
13
def test_done_phase_handling():
    """测试 done 阶段的处理

    Feeds an answer chunk followed by a done chunk (with usage) through
    SSEToolHandler and asserts the output contains the answer content,
    a finish_reason="stop" chunk, and the [DONE] marker.
    """

    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试 done 阶段处理\n")

    # Simulated upstream SSE sequence: answer phase, then completion.
    test_chunks = [
        # answer phase
        {
            "phase": "answer",
            "delta_content": "这是回答内容",
            "edit_content": ""
        },
        # done phase, carrying token usage
        {
            "phase": "done",
            "done": True,
            "delta_content": "",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }
    ]

    output_chunks = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: phase={chunk['phase']}")

        results = list(handler.process_sse_chunk(chunk))
        output_chunks.extend(results)

        print(f" 输出数量: {len(results)}")
        for j, result in enumerate(results):
            if result.strip() == "data: [DONE]":
                print(f" 输出 {j+1}: [DONE] 标记")
            else:
                print(f" 输出 {j+1}: {result[:80]}{'...' if len(result) > 80 else ''}")
        print()

    print(f"📊 测试结果:")
    print(f" 总输出块数量: {len(output_chunks)}")

    # Scan the emitted SSE frames for the pieces a well-formed OpenAI
    # stream must contain.
    has_content = False
    has_final_chunk = False
    has_done_marker = False
    has_usage = False

    for output in output_chunks:
        if output.startswith("data: "):
            json_str = output[6:].strip()
            if json_str == "[DONE]":
                has_done_marker = True
                print(" ✅ 找到 [DONE] 标记")
            elif json_str:
                try:
                    data = json.loads(json_str)
                    if "choices" in data and data["choices"]:
                        delta = data["choices"][0].get("delta", {})
                        content = delta.get("content", "")
                        finish_reason = data["choices"][0].get("finish_reason")

                        if content:
                            has_content = True
                            print(f" ✅ 找到内容: '{content}'")

                        if finish_reason == "stop":
                            has_final_chunk = True
                            print(" ✅ 找到最终完成块")

                        if "usage" in data:
                            has_usage = True
                            print(f" ✅ 找到 usage 信息: {data['usage']}")

                except json.JSONDecodeError as e:
                    print(f" ❌ JSON 解析错误: {e}")

    # usage is reported but not required for the test to pass.
    success = has_content and has_final_chunk and has_done_marker

    print(f"\n📋 验证结果:")
    print(f" 包含回答内容: {'✅' if has_content else '❌'}")
    print(f" 包含最终完成块: {'✅' if has_final_chunk else '❌'}")
    print(f" 包含 [DONE] 标记: {'✅' if has_done_marker else '❌'}")
    print(f" 包含 usage 信息: {'✅' if has_usage else '❌'}")

    if success:
        print("\n✅ done 阶段处理测试通过!")
        return True
    else:
        print("\n❌ done 阶段处理测试失败!")
        return False
110
+
111
def test_done_phase_with_tool_call():
    """测试带工具调用的 done 阶段处理

    Runs a mixed sequence (tool call -> answer -> done) through the
    handler and checks that tool_calls, the answer text, and the [DONE]
    marker all appear in the emitted output.
    """

    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试带工具调用的 done 阶段处理\n")

    # Simulated flow: tool call, tool-call completion, answer, done.
    test_chunks = [
        # tool call begins (glm_block payload as sent by upstream)
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_test", "name": "test_tool", "arguments": "{}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 100
        },
        # tool call ends
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 200
        },
        # answer phase
        {
            "phase": "answer",
            "delta_content": "工具调用完成,这是回答。",
            "edit_content": ""
        },
        # done phase
        {
            "phase": "done",
            "done": True,
            "delta_content": ""
        }
    ]

    output_chunks = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: phase={chunk['phase']}")

        results = list(handler.process_sse_chunk(chunk))
        output_chunks.extend(results)

        print(f" 输出数量: {len(results)}")
        print()

    # Verify tool call, answer content, and completion marker all appeared.
    has_tool_call = any("tool_calls" in output for output in output_chunks)
    has_answer_content = any("工具调用完成" in output for output in output_chunks)
    has_done_marker = any(output.strip() == "data: [DONE]" for output in output_chunks)

    print(f"📊 混合流程测试结果:")
    print(f" 包含工具调用: {'✅' if has_tool_call else '❌'}")
    print(f" 包含回答内容: {'✅' if has_answer_content else '❌'}")
    print(f" 包含 [DONE] 标记: {'✅' if has_done_marker else '❌'}")

    success = has_tool_call and has_answer_content and has_done_marker

    if success:
        print("\n✅ 混合流程 done 阶段测试通过!")
        return True
    else:
        print("\n❌ 混合流程 done 阶段测试失败!")
        return False
175
+
176
def test_done_phase_warning_fix():
    """测试 done 阶段不再产生警告

    Processes a lone done-phase chunk; passing means no exception is
    raised (the old "unknown SSE phase: done" warning path).
    """

    handler = SSEToolHandler("test-model", stream=True)

    print("🧪 测试 done 阶段警告修复\n")

    # A bare completion chunk with no content.
    done_chunk = {
        "phase": "done",
        "done": True,
        "delta_content": ""
    }

    print("处理 done 阶段块...")

    # We mainly care that processing completes without raising.
    try:
        emitted = list(handler.process_sse_chunk(done_chunk))
        print(f" 成功处理,输出 {len(emitted)} 个块")

        done_seen = any(frame.strip() == "data: [DONE]" for frame in emitted)
        print(f" 包含 [DONE] 标记: {'✅' if done_seen else '❌'}")

        print("\n✅ done 阶段不再产生警告!")
        return True

    except Exception as e:
        print(f"\n❌ 处理 done 阶段时出错: {e}")
        return False
207
+
208
+ if __name__ == "__main__":
209
+ print("🔧 测试 done 阶段处理\n")
210
+
211
+ test1_success = test_done_phase_handling()
212
+ print("\n" + "="*50 + "\n")
213
+ test2_success = test_done_phase_with_tool_call()
214
+ print("\n" + "="*50 + "\n")
215
+ test3_success = test_done_phase_warning_fix()
216
+
217
+ print("\n" + "="*50)
218
+ print("🎯 总结:")
219
+ print(f" done 阶段基本处理: {'✅ 通过' if test1_success else '❌ 失败'}")
220
+ print(f" done 阶段混合流程: {'✅ 通过' if test2_success else '❌ 失败'}")
221
+ print(f" done 阶段警告修复: {'✅ 通过' if test3_success else '❌ 失败'}")
222
+
223
+ if test1_success and test2_success and test3_success:
224
+ print("\n🎉 所有测试通过!done 阶段处理完善!")
225
+ print("\n💡 修复效果:")
226
+ print(" - 不再显示 '未知的 SSE 阶段: done' 警告")
227
+ print(" - 正确处理对话完成流程")
228
+ print(" - 自动刷新缓冲区和完成工具调用")
229
+ print(" - 发送标准的 OpenAI 完成标记")
230
+ else:
231
+ print("\n❌ 部分测试失败,需要进一步调试")
tests/test_longcat_connection.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ 测试 LongCat API 连接性
6
+ """
7
+
8
+ import asyncio
9
+ import httpx
10
+ import json
11
+
12
+ # LongCat API 端点
13
+ LONGCAT_API_ENDPOINT = "https://longcat.chat/api/v1/chat-completion-oversea"
14
+
15
+ async def test_longcat_api():
16
+ """测试 LongCat API 连接"""
17
+ print(f"🧪 测试 LongCat API 连接...")
18
+ print(f"📡 API 端点: {LONGCAT_API_ENDPOINT}")
19
+
20
+ headers = {
21
+ 'accept': 'text/event-stream,application/json',
22
+ 'content-type': 'application/json',
23
+ 'origin': 'https://longcat.chat',
24
+ 'referer': 'https://longcat.chat/t',
25
+ 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36'
26
+ }
27
+
28
+ payload = {
29
+ "stream": True,
30
+ "temperature": 0.7,
31
+ "content": "Hello",
32
+ "messages": [
33
+ {
34
+ "role": "user",
35
+ "content": "Hello"
36
+ }
37
+ ]
38
+ }
39
+
40
+ print(f"📤 发送请求...")
41
+ print(f"📋 Headers: {json.dumps(headers, indent=2)}")
42
+ print(f"📋 Payload: {json.dumps(payload, indent=2)}")
43
+
44
+ try:
45
+ async with httpx.AsyncClient(timeout=30.0) as client:
46
+ response = await client.post(
47
+ LONGCAT_API_ENDPOINT,
48
+ headers=headers,
49
+ json=payload
50
+ )
51
+
52
+ print(f"📡 响应状态码: {response.status_code}")
53
+ print(f"📋 响应头: {dict(response.headers)}")
54
+
55
+ if not response.is_success:
56
+ error_text = await response.atext()
57
+ print(f"❌ API 错误: {error_text}")
58
+ return False
59
+
60
+ print(f"✅ 连接成功,开始读取流数据...")
61
+
62
+ line_count = 0
63
+ async for line in response.aiter_lines():
64
+ line_count += 1
65
+ line = line.strip()
66
+ print(f"📥 第 {line_count} 行: {line}")
67
+
68
+ if line_count > 10: # 只读取前10行
69
+ print(f"⏹️ 停止读取(已读取 {line_count} 行)")
70
+ break
71
+
72
+ if line.startswith('data:'):
73
+ data_str = line[5:].strip()
74
+ if data_str == '[DONE]':
75
+ print(f"🏁 收到结束标记")
76
+ break
77
+
78
+ try:
79
+ data = json.loads(data_str)
80
+ print(f"📦 解析成功: {json.dumps(data, ensure_ascii=False, indent=2)}")
81
+ except json.JSONDecodeError as e:
82
+ print(f"❌ JSON 解析失败: {e}")
83
+
84
+ print(f"✅ 测试完成,共读取 {line_count} 行")
85
+ return True
86
+
87
+ except httpx.TimeoutException:
88
+ print(f"❌ 请求超时")
89
+ return False
90
+ except httpx.ConnectError as e:
91
+ print(f"❌ 连接错误: {e}")
92
+ return False
93
+ except Exception as e:
94
+ print(f"❌ 未知错误: {e}")
95
+ import traceback
96
+ print(f"❌ 错误堆栈: {traceback.format_exc()}")
97
+ return False
98
+
99
+ async def test_simple_request():
100
+ """测试简单的非流式请求"""
101
+ print(f"\n🧪 测试简单的非流式请求...")
102
+
103
+ headers = {
104
+ 'accept': 'application/json',
105
+ 'content-type': 'application/json',
106
+ 'origin': 'https://longcat.chat',
107
+ 'referer': 'https://longcat.chat/t',
108
+ 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
109
+ }
110
+
111
+ payload = {
112
+ "stream": False,
113
+ "temperature": 0.7,
114
+ "content": "Hello",
115
+ "messages": [
116
+ {
117
+ "role": "user",
118
+ "content": "Hello"
119
+ }
120
+ ]
121
+ }
122
+
123
+ try:
124
+ async with httpx.AsyncClient(timeout=30.0) as client:
125
+ response = await client.post(
126
+ LONGCAT_API_ENDPOINT,
127
+ headers=headers,
128
+ json=payload
129
+ )
130
+
131
+ print(f"📡 响应状态码: {response.status_code}")
132
+
133
+ if response.is_success:
134
+ content = await response.atext()
135
+ print(f"✅ 响应内容: {content[:500]}...")
136
+ return True
137
+ else:
138
+ error_text = await response.atext()
139
+ print(f"❌ 错误响应: {error_text}")
140
+ return False
141
+
142
+ except Exception as e:
143
+ print(f"❌ 请求失败: {e}")
144
+ return False
145
+
146
+ async def main():
147
+ """运行所有测试"""
148
+ print("🚀 开始 LongCat API 连接测试...\n")
149
+
150
+ # 测试流式请求
151
+ stream_result = await test_longcat_api()
152
+
153
+ # 测试非流式请求
154
+ simple_result = await test_simple_request()
155
+
156
+ print(f"\n📊 测试结果:")
157
+ print(f" 流式请求: {'✅ 成功' if stream_result else '❌ 失败'}")
158
+ print(f" 非流式请求: {'✅ 成功' if simple_result else '❌ 失败'}")
159
+
160
+ if stream_result and simple_result:
161
+ print(f"🎉 所有测试通过!")
162
+ else:
163
+ print(f"⚠️ 部分测试失败,请检查网络连接和 API 端点")
164
+
165
+ if __name__ == "__main__":
166
+ asyncio.run(main())
tests/test_multiple_tools.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 测试多个工具调用的处理逻辑
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+
10
+ from app.utils.sse_tool_handler import SSEToolHandler
11
+
12
def test_multiple_tool_calls():
    """测试多个工具调用的处理

    Replays a real upstream sequence (from logs) of three tool calls —
    navigate, snapshot, navigate again — and verifies the handler
    completes each one, in order. A repeated navigate call is expected:
    it came from upstream, not from this proxy.
    """

    handler = SSEToolHandler("test-model", stream=False)

    print("🧪 测试多个工具调用处理\n")

    # Simulated multi-tool-call SSE sequence (captured from real logs).
    test_chunks = [
        # first tool call begins
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_5y5gir0mygx", "name": "mcp__playwright__browser_navigate", "arguments": "{\\"url\\":\\"https://www.bil", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 24
        },
        # first tool call: remainder of the arguments
        {
            "phase": "tool_call",
            "edit_content": 'ibili.com\\"}',
            "edit_index": 194
        },
        # first tool call ends
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 219
        },
        # second tool call begins
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_j8r24x6xtg", "name": "mcp__playwright__browser_snapshot", "arguments": "{}", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 406
        },
        # second tool call ends
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 566
        },
        # third tool call begins (a repeated navigate)
        {
            "phase": "tool_call",
            "edit_content": '<glm_block view="">{"type": "mcp", "data": {"metadata": {"id": "call_scvwo0xaoil", "name": "mcp__playwright__browser_navigate", "arguments": "{\\"url\\":\\"https://www.bil", "result": "", "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 753
        },
        # third tool call: remainder of the arguments
        {
            "phase": "tool_call",
            "edit_content": 'ibili.com\\"}',
            "edit_index": 925
        },
        # third tool call ends
        {
            "phase": "other",
            "edit_content": 'null, "display_result": "", "duration": "...", "status": "completed", "is_error": false, "mcp_server": {"name": "mcp-server"}}, "thought": null, "ppt": null, "browser": null}}</glm_block>',
            "edit_index": 950
        }
    ]

    tool_calls_completed = []

    for i, chunk in enumerate(test_chunks, 1):
        print(f"处理块 {i}: edit_index={chunk['edit_index']}, phase={chunk['phase']}")

        # Snapshot handler state before processing so transitions can be
        # detected afterwards.
        old_tool_id = handler.tool_id
        old_tool_name = handler.tool_name
        old_has_tool_call = handler.has_tool_call

        results = list(handler.process_sse_chunk(chunk))

        # tool_id changed to a non-empty value => a new tool call started.
        if handler.tool_id != old_tool_id and handler.tool_id:
            print(f" 🎯 新工具调用开始: {handler.tool_name} (id: {handler.tool_id})")

        # has_tool_call dropped from True to False => a call completed.
        if old_has_tool_call and not handler.has_tool_call:
            tool_calls_completed.append({
                "name": old_tool_name or "unknown",
                "id": old_tool_id
            })
            print(f" ✅ 工具调用完成: {old_tool_name or 'unknown'}")

        print(f" 当前状态: has_tool_call={handler.has_tool_call}, tool_id={handler.tool_id}")
        print()

    print(f"📊 测试结果:")
    print(f" 完成的工具调用数量: {len(tool_calls_completed)}")
    for i, tool in enumerate(tool_calls_completed, 1):
        print(f" {i}. {tool['name']} (id: {tool['id']})")

    # The expected order mirrors the upstream sequence above.
    expected_tools = [
        "mcp__playwright__browser_navigate",
        "mcp__playwright__browser_snapshot",
        "mcp__playwright__browser_navigate"
    ]

    completed_tool_names = [tool['name'] for tool in tool_calls_completed]

    if completed_tool_names == expected_tools:
        print("\n✅ 测试通过!正确处理了所有工具调用")
        print("📝 结论:重复的工具调用是上游发送的,我们的处理逻辑是正确的")
        return True
    else:
        print(f"\n❌ 测试失败!")
        print(f" 期望: {expected_tools}")
        print(f" 实际: {completed_tool_names}")
        return False
+
123
+ if __name__ == "__main__":
124
+ success = test_multiple_tool_calls()
125
+
126
+ if success:
127
+ print("\n🎯 总结:")
128
+ print("1. 我们的 API 代理正确处理了每个不同的工具调用")
129
+ print("2. 重复的工具调用是上游 Z.AI 模型发送的,不是我们的问题")
130
+ print("3. 每个工具调用都有不同的 ID,说明这是模型的有意行为")
131
+ print("4. 可能的原因:模型重试、验证操作、或处理复杂任务的策略")
132
+ else:
133
+ print("\n❌ 需要进一步调试处理逻辑")
tests/test_simple_performance.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 简化的性能测试,避免过多日志输出
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ import time
9
+ import json
10
+ import logging
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+ # 临时禁用日志以避免性能测试中的噪音
14
+ logging.getLogger().setLevel(logging.CRITICAL)
15
+
16
+ from app.utils.sse_tool_handler import SSEToolHandler
17
+
18
def test_optimized_performance():
    """测试优化后的性能

    Micro-benchmarks ``SSEToolHandler._fix_tool_arguments`` over several
    malformed-JSON scenarios, printing per-case timing/throughput and an
    overall summary, and validating that each repaired result parses.
    """

    print("🧪 测试优化后的 JSON 修复性能\n")

    # Benchmark scenarios: valid JSON, heavy escaping, missing opening
    # brace, and Windows path escapes.
    test_cases = [
        {
            "name": "简单JSON",
            "input": '{"command":"echo hello","description":"test"}',
            "iterations": 100
        },
        {
            "name": "复杂命令行参数",
            "input": '{"command":"echo \\"添加更多内容\\uff1a$(date)\\\\\\" >> \\\\\\"C:\\\\\\\\Users\\\\\\\\test\\\\\\\\1.txt\\\\\\"\\"","description":"test"}',
            "iterations": 50
        },
        {
            "name": "缺少开始括号",
            "input": '"command":"echo hello","description":"test"}',
            "iterations": 50
        },
        {
            "name": "Windows路径问题",
            "input": '{"path":"C:\\\\\\\\Users\\\\\\\\Documents","command":"dir"}',
            "iterations": 50
        }
    ]

    handler = SSEToolHandler("test-model", stream=False)

    total_time = 0
    total_iterations = 0

    for test_case in test_cases:
        print(f"测试: {test_case['name']}")
        print(f" 输入长度: {len(test_case['input'])} 字符")
        print(f" 迭代次数: {test_case['iterations']}")

        # Warm-up runs so caches/first-call costs don't skew the timing.
        for _ in range(5):
            handler._fix_tool_arguments(test_case['input'])

        # Timed section.
        start_time = time.time()
        for _ in range(test_case['iterations']):
            result = handler._fix_tool_arguments(test_case['input'])
        end_time = time.time()

        duration = end_time - start_time
        # Guard against a zero-duration measurement on coarse clocks.
        if duration > 0:
            avg_time = duration / test_case['iterations'] * 1000  # ms per op
            throughput = test_case['iterations'] / duration
        else:
            avg_time = 0
            throughput = float('inf')

        print(f" 总时间: {duration:.4f}s")
        print(f" 平均时间: {avg_time:.4f}ms")
        print(f" 吞吐量: {throughput:.1f} ops/s")

        total_time += duration
        total_iterations += test_case['iterations']

        # Sanity-check that the repaired output is valid JSON.
        try:
            parsed = json.loads(result)
            print(f" ✅ 结果有效")
        except:
            print(f" ❌ 结果无效")

        print()

    print(f"📊 总体性能:")
    print(f" 总时间: {total_time:.4f}s")
    print(f" 总迭代: {total_iterations}")
    if total_time > 0:
        print(f" 平均性能: {total_iterations/total_time:.1f} ops/s")
        print(f" 平均延迟: {total_time/total_iterations*1000:.4f}ms")
    else:
        print(f" 平均性能: ∞ ops/s")
        print(f" 平均延迟: 0.0000ms")
100
+
101
def test_code_simplification_benefits():
    """测试代码简化的好处

    Runs the argument-repair routine over four malformed-JSON shapes and
    reports per-scenario latency and whether the repaired text parses.
    """

    print("\n🧪 测试代码简化的好处\n")

    # Scenarios of increasing brokenness.
    scenarios = [
        '{"command":"echo hello"}',  # 简单
        '{"command":"echo \\"hello\\"","description":"test"}',  # 转义引号
        '"command":"echo hello","description":"test"}',  # 缺少开始括号
        '{"command":"echo hello > file.txt\\"","description":"test"}',  # 多余引号
    ]

    handler = SSEToolHandler("test-model", stream=False)

    print("测试各种JSON修复场景:")
    for idx, raw in enumerate(scenarios, 1):
        print(f"\n场景 {idx}: {raw[:50]}{'...' if len(raw) > 50 else ''}")

        started = time.time()
        repaired = handler._fix_tool_arguments(raw)
        elapsed_ms = (time.time() - started) * 1000

        try:
            json.loads(repaired)
            verdict = "✅ 成功"
        except:
            verdict = "❌ 失败"

        print(f" 处理时间: {elapsed_ms:.4f}ms")
        print(f" 修复状态: {verdict}")
        print(f" 结果长度: {len(repaired)} 字符")
135
+
136
def test_memory_efficiency():
    """测试内存效率

    Measures RSS growth of the current process while repairing the same
    small JSON payload 100 times. Skips gracefully when psutil is absent.
    """

    print("\n🧪 测试内存效率\n")

    try:
        import psutil
        process = psutil.Process()

        # Baseline before any handler work, in MB.
        baseline_memory = process.memory_info().rss / 1024 / 1024
        print(f"基线内存: {baseline_memory:.2f} MB")

        handler = SSEToolHandler("test-model", stream=False)

        # Repeatedly repair one small payload.
        test_data = '{"command":"echo test","description":"test"}'

        start_memory = process.memory_info().rss / 1024 / 1024

        for _ in range(100):
            result = handler._fix_tool_arguments(test_data)

        end_memory = process.memory_info().rss / 1024 / 1024

        print(f"处理100次后内存: {end_memory:.2f} MB")
        print(f"内存增长: {end_memory - baseline_memory:.2f} MB")
        print(f"平均每次处理: {(end_memory - start_memory) / 100 * 1024:.2f} KB")

    except ImportError:
        print("psutil 未安装,跳过内存测试")
167
+
168
+ if __name__ == "__main__":
169
+ test_optimized_performance()
170
+ test_code_simplification_benefits()
171
+ test_memory_efficiency()
172
+
173
+ print("\n🎯 优化总结:")
174
+ print("✅ 简化了预处理逻辑")
175
+ print("✅ 统一了修复流程")
176
+ print("✅ 减少了代码复杂度")
177
+ print("✅ 保持了修复质量")
178
+ print("✅ 提高了可维护性")
tokens.txt.example ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 认证Token配置文件(可选)
2
+ #
3
+ # 说明:
4
+ # 1. 此文件是可选的,如果不需要备用token可以删除此文件
5
+ # 2. 支持两种格式:每行一个token 或 逗号分隔的token
6
+ # 3. 只包含认证用户token (role: "user"),不要添加匿名用户token (role: "guest")
7
+ # 4. 系统会自动去重和验证token有效性
8
+ # 5. 自动跳过空格、换行符和空token
9
+ # 6. 当匿名模式正常工作时,此文件中的token不会被使用
10
+ #
11
+ # 格式1:纯换行分隔
12
+ # token1
13
+ # token2
14
+ # token3
15
+
16
+ # 格式2:纯逗号分隔
17
+ # token1,token2,token3
18
+
19
+ # 格式3:混合格式
20
+ # token1,token2
21
+ # token3
22
+ # token4,token5,token6
23
+ # token7
24
+
25
+ # 请在下方添加您的认证用户token(使用任一格式):
26
+