youbiaokachi committed on
Commit
ea1a2cc
·
verified ·
1 Parent(s): 8ee393d

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. k2think_proxy.py +1166 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Copy the dependency manifest first so the install layer can be cached
COPY requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY k2think_proxy.py .

# Expose the service port
EXPOSE 8001

# Health check. Uses only the Python standard library so it does not depend on
# an extra HTTP package being installed, and urlopen() raises on HTTP error
# statuses, so an erroring service is reported as unhealthy. The inner timeout
# is shorter than the HEALTHCHECK timeout so the probe fails cleanly.
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8001/health', timeout=5)" || exit 1

# Start the application
CMD ["python", "k2think_proxy.py"]
k2think_proxy.py ADDED
@@ -0,0 +1,1166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request, Response
2
+ from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ from typing import List, Dict, Optional, Union, AsyncGenerator
6
+ import httpx
7
+ import json
8
+ import asyncio
9
+ import time
10
+ import os
11
+ import logging
12
+ import re
13
+ from contextlib import asynccontextmanager
14
+ from dotenv import load_dotenv
15
+
16
# Load environment variables (python-dotenv reads a local .env file if present)
load_dotenv()

# Core configuration
VALID_API_KEY = os.getenv("VALID_API_KEY")
if not VALID_API_KEY:
    # Refuse to start without a key: every chat request is checked against it.
    raise ValueError("错误:VALID_API_KEY 环境变量未设置。请在 .env 文件中提供一个安全的API密钥。")
K2THINK_API_URL = os.getenv("K2THINK_API_URL", "https://www.k2think.ai/api/chat/completions")
K2THINK_TOKEN = os.getenv("K2THINK_TOKEN")
OUTPUT_THINKING = os.getenv("OUTPUT_THINKING", "true").lower() == "true"
TOOL_SUPPORT = os.getenv("TOOL_SUPPORT", "true").lower() == "true"
SCAN_LIMIT = int(os.getenv("SCAN_LIMIT", "200000"))
# NOTE: "LENTH" is a typo for "LENGTH"; the name is kept because it is both the
# environment variable name and the constant referenced elsewhere in the file.
SYSTEM_MESSAGE_LENTH = int(os.getenv("SYSTEM_MESSAGE_LENTH", "200000"))

# Advanced configuration
REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "60"))
MAX_KEEPALIVE_CONNECTIONS = int(os.getenv("MAX_KEEPALIVE_CONNECTIONS", "20"))
MAX_CONNECTIONS = int(os.getenv("MAX_CONNECTIONS", "100"))
DEBUG_LOGGING = os.getenv("DEBUG_LOGGING", "false").lower() == "true"
STREAM_DELAY = float(os.getenv("STREAM_DELAY", "0.05"))
STREAM_CHUNK_SIZE = int(os.getenv("STREAM_CHUNK_SIZE", "50"))
MAX_STREAM_TIME = float(os.getenv("MAX_STREAM_TIME", "10.0"))  # maximum simulated streaming time (seconds)
ENABLE_ACCESS_LOG = os.getenv("ENABLE_ACCESS_LOG", "true").lower() == "true"
# Comma-separated origin list. Entries are stripped so "a.com, b.com" works;
# an unset variable or a plain "*" yields ["*"].
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
]

# Logging setup: only DEBUG, WARNING and ERROR are recognised; any other value
# falls back to INFO.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
_LOG_LEVELS = {
    "DEBUG": logging.DEBUG,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
}
logging.basicConfig(level=_LOG_LEVELS.get(LOG_LEVEL, logging.INFO))

logger = logging.getLogger(__name__)
53
+
54
# Data models
class ContentPart(BaseModel):
    """One element of a list-form message content (OpenAI's newer content format)."""

    # Kind of part as sent by the client, e.g. "text"
    type: str
    # Text payload; stays None when the client does not provide one
    text: Optional[str] = None
59
+
60
class Message(BaseModel):
    """A single chat message in an incoming chat-completion request."""

    # Message role as supplied by the client
    role: str
    # Either a plain string or a list of content parts; may be absent
    content: Optional[Union[str, List[ContentPart]]] = None
    # Tool calls attached to the message, if any
    tool_calls: Optional[List[Dict]] = None
64
+
65
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completions endpoint."""

    # Model identifier; defaults to the K2-Think model
    model: str = "MBZUAI-IFM/K2-Think"
    # Conversation history
    messages: List[Message]
    # Whether the client asked for a streamed response
    stream: bool = False
    # Sampling parameters
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, List[str]]] = None
    # Tool definitions and the tool selection policy
    tools: Optional[List[Dict]] = None
    tool_choice: Optional[Union[str, Dict]] = None
77
+
78
class ModelInfo(BaseModel):
    """Description of one model as returned by the model-listing endpoint."""

    # Model identifier
    id: str
    # Object type tag
    object: str = "model"
    # Creation timestamp (integer)
    created: int
    # Owner of the model
    owned_by: str
    # Permission entries; empty by default
    permission: List[Dict] = []
    # Root model identifier
    root: str
    # Parent model identifier, if any
    parent: Optional[str] = None
86
+
87
class ModelsResponse(BaseModel):
    """Envelope returned by the model-listing endpoint."""

    # Object type tag
    object: str = "list"
    # The available models
    data: List[ModelInfo]
90
+
91
# HTTP client factory
def create_http_client() -> httpx.AsyncClient:
    """Build a new ``httpx.AsyncClient`` configured from the module settings.

    The client has a 10 second connect timeout and no other default timeout,
    uses the connection limits from ``MAX_KEEPALIVE_CONNECTIONS`` and
    ``MAX_CONNECTIONS``, and follows redirects.

    Raises:
        Exception: whatever the client constructor raises, after logging it.
    """
    timeout_config = httpx.Timeout(timeout=None, connect=10.0)
    pool_limits = httpx.Limits(
        max_keepalive_connections=MAX_KEEPALIVE_CONNECTIONS,
        max_connections=MAX_CONNECTIONS
    )

    try:
        return httpx.AsyncClient(
            timeout=timeout_config,
            limits=pool_limits,
            follow_redirects=True,
        )
    except Exception as e:
        logger.error(f"创建客户端失败: {e}")
        raise e
108
+
109
# Global HTTP client management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook; currently performs no startup or shutdown work."""
    yield
113
+
114
# Create the FastAPI application
app = FastAPI(title="K2Think API Proxy", lifespan=lifespan)

# CORS configuration.
# Credentialed CORS must not be combined with a wildcard origin, so credentials
# are only enabled when an explicit origin list has been configured. Requests
# to this service authenticate with an Authorization header, which does not
# require credentialed CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials="*" not in CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
125
+
126
+
127
def validate_api_key(authorization: str) -> bool:
    """Validate the API key carried in an ``Authorization`` header value.

    Args:
        authorization: Raw header value, expected as ``"Bearer <key>"``.

    Returns:
        True when the header has the ``Bearer `` prefix and the key equals
        ``VALID_API_KEY``; False otherwise (including empty/None input).
    """
    import hmac

    bearer_prefix = "Bearer "
    if not authorization or not authorization.startswith(bearer_prefix):
        return False
    api_key = authorization[len(bearer_prefix):]  # strip the "Bearer " prefix
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key is correct. Encoded to bytes because compare_digest only
    # accepts ASCII-only str arguments.
    return hmac.compare_digest(api_key.encode("utf-8"), VALID_API_KEY.encode("utf-8"))
133
+
134
def generate_session_id() -> str:
    """Return a new session ID: a random (version 4) UUID in string form."""
    from uuid import uuid4

    session_uuid = uuid4()
    return str(session_uuid)
138
+
139
def generate_chat_id() -> str:
    """Return a new chat ID: a random (version 4) UUID in string form."""
    from uuid import uuid4

    chat_uuid = uuid4()
    return str(chat_uuid)
143
+
144
def get_current_datetime_info():
    """Return template-placeholder values describing the current moment.

    The result maps ``{{...}}`` placeholder names to strings. Date and time
    values are computed for the Asia/Shanghai timezone; the user name,
    location and language entries are fixed values.
    """
    from datetime import datetime
    import pytz

    # All time values are expressed in Shanghai local time
    shanghai_now = datetime.now(pytz.timezone('Asia/Shanghai'))

    date_text = shanghai_now.strftime("%Y-%m-%d")
    time_text = shanghai_now.strftime("%H:%M:%S")
    datetime_text = shanghai_now.strftime("%Y-%m-%d %H:%M:%S")
    weekday_text = shanghai_now.strftime("%A")

    placeholders = {}
    placeholders["{{USER_NAME}}"] = "User"
    placeholders["{{USER_LOCATION}}"] = "Unknown"
    placeholders["{{CURRENT_DATETIME}}"] = datetime_text
    placeholders["{{CURRENT_DATE}}"] = date_text
    placeholders["{{CURRENT_TIME}}"] = time_text
    placeholders["{{CURRENT_WEEKDAY}}"] = weekday_text
    placeholders["{{CURRENT_TIMEZONE}}"] = "Asia/Shanghai"
    placeholders["{{USER_LANGUAGE}}"] = "en-US"
    return placeholders
163
+
164
def extract_answer_content(full_content: str) -> str:
    """Strip the ``<answer>`` wrapper (and optionally the thinking) from model output.

    When ``OUTPUT_THINKING`` is true, only the first ``<answer>`` tag and the
    last ``</answer>`` tag are removed; everything else, including any
    ``<think>`` section, is kept. When it is false, the ``<think>...</think>``
    section is removed and, if an ``<answer>`` pair is present, only the text
    between the first ``<answer>`` and the last ``</answer>`` is returned.

    Args:
        full_content: Raw text returned by the upstream model.

    Returns:
        The cleaned text with surrounding whitespace stripped. Falsy input
        (``""`` or ``None``) is returned unchanged.
    """
    if not full_content:
        return full_content

    answer_open, answer_close = '<answer>', '</answer>'
    think_open, think_close = '<think>', '</think>'

    if OUTPUT_THINKING:
        # Remove the first <answer> tag
        answer_start = full_content.find(answer_open)
        if answer_start != -1:
            full_content = full_content[:answer_start] + full_content[answer_start + len(answer_open):]

        # Remove the last </answer> tag
        answer_end = full_content.rfind(answer_close)
        if answer_end != -1:
            full_content = full_content[:answer_end] + full_content[answer_end + len(answer_close):]

        return full_content.strip()

    # Remove the <think> section, tags included. The closing tag is searched
    # for only after the opening tag, so a stray "</think>" earlier in the text
    # cannot produce an inverted slice that duplicates content.
    think_start = full_content.find(think_open)
    if think_start != -1:
        think_end = full_content.find(think_close, think_start)
        if think_end != -1:
            full_content = full_content[:think_start] + full_content[think_end + len(think_close):]

    # Keep only what lies between the first <answer> and the last </answer>
    answer_start = full_content.find(answer_open)
    answer_end = full_content.rfind(answer_close)
    if answer_start != -1 and answer_end != -1:
        return full_content[answer_start + len(answer_open):answer_end].strip()

    return full_content.strip()
195
+
196
def calculate_dynamic_chunk_size(content_length: int) -> int:
    """Compute the chunk size for simulated streaming output.

    The size is chosen so that emitting the whole content, with one
    ``STREAM_DELAY`` pause per chunk, takes at most about ``MAX_STREAM_TIME``
    seconds:

        total_time = (content_length / chunk_size) * STREAM_DELAY
        =>  chunk_size = (content_length * STREAM_DELAY) / MAX_STREAM_TIME

    Args:
        content_length: Total length of the content to be emitted.

    Returns:
        The chunk size to use. It is at least 50 unless the content itself is
        shorter than that, and it is always at least 1 so callers can safely
        use it as a ``range()`` step.
    """
    # Guard against a zero/negative STREAM_CHUNK_SIZE setting.
    default_chunk_size = max(STREAM_CHUNK_SIZE, 1)
    if content_length <= 0:
        return default_chunk_size

    # A non-positive MAX_STREAM_TIME would divide by zero (or go negative);
    # treat it as "no time budget" and rely on the minimum below.
    if MAX_STREAM_TIME > 0:
        calculated_chunk_size = int((content_length * STREAM_DELAY) / MAX_STREAM_TIME)
    else:
        calculated_chunk_size = 0

    # Never go below the minimum chunk size of 50
    min_chunk_size = 50
    dynamic_chunk_size = max(calculated_chunk_size, min_chunk_size)

    # Content shorter than the chosen size: fall back to the default size,
    # capped at the content length
    if dynamic_chunk_size > content_length:
        dynamic_chunk_size = min(default_chunk_size, content_length)

    logger.debug(f"动态chunk_size计算: 内容长度={content_length}, 计算值={calculated_chunk_size}, 最终值={dynamic_chunk_size}")

    return dynamic_chunk_size
229
+
230
def content_to_string(content) -> str:
    """Convert message content in any supported shape to a plain string.

    Args:
        content: ``None``, a string, a list of parts, or any other object.
            List parts may be objects exposing a ``text`` attribute,
            ``{"type": "text", "text": ...}`` dicts, plain strings, or
            arbitrary values (which are stringified).

    Returns:
        ``""`` for ``None``; the string itself for ``str``; the part texts
        joined with single spaces for a list; ``str(content)`` otherwise.
        Parts whose text is ``None`` are skipped rather than breaking the join.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if hasattr(part, 'text'):
                # Object with a text attribute (e.g. a ContentPart). The text
                # may legitimately be None (non-text parts), so skip those.
                text = part.text
                if text is not None:
                    parts.append(text if isinstance(text, str) else str(text))
            elif isinstance(part, dict) and part.get("type") == "text":
                text = part.get("text", "")
                if text is not None:
                    parts.append(text if isinstance(text, str) else str(text))
            elif isinstance(part, str):
                parts.append(part)
            else:
                # Any other value: best-effort stringification; a value whose
                # __str__ fails is dropped instead of failing the request.
                try:
                    parts.append(str(part))
                except Exception:
                    continue
        return " ".join(parts)
    # Any other type
    try:
        return str(content)
    except Exception:
        return ""
261
+
262
def generate_tool_prompt(tools: List[Dict]) -> str:
    """Generate a concise prompt fragment describing the available tools.

    Only entries whose ``type`` is ``"function"`` are described. Each tool is
    rendered as ``name: description`` followed, when it has parameters, by
    ``Parameters: p1*: desc, p2: desc`` where ``*`` marks required parameters.

    Args:
        tools: OpenAI-style tool definitions.

    Returns:
        The prompt text (starting with two newlines), or ``""`` when there is
        no function tool to describe.
    """
    if not tools:
        return ""

    tool_definitions = []
    for tool in tools:
        if tool.get("type") != "function":
            continue

        function_spec = tool.get("function", {}) or {}
        function_name = function_spec.get("name", "unknown")
        # A null description must not show up as the text "None"
        function_description = function_spec.get("description", "") or ""
        parameters = function_spec.get("parameters", {}) or {}

        # Concise tool definition
        tool_info = f"{function_name}: {function_description}"

        # Simplified parameter info
        parameter_properties = parameters.get("properties", {}) or {}
        required_parameters = set(parameters.get("required", []) or [])

        if parameter_properties:
            param_list = []
            for param_name, param_details in parameter_properties.items():
                # A property schema may be a non-dict (JSON Schema permits
                # true/false); such schemas simply have no description.
                if isinstance(param_details, dict):
                    param_desc = param_details.get("description", "") or ""
                else:
                    param_desc = ""
                marker = '*' if param_name in required_parameters else ''
                param_list.append(f"{param_name}{marker}: {param_desc}")
            tool_info += f" Parameters: {', '.join(param_list)}"

        tool_definitions.append(tool_info)

    if not tool_definitions:
        return ""

    # Build the concise tool prompt
    prompt_template = (
        f"\n\nAvailable tools: {'; '.join(tool_definitions)}. "
        "To use a tool, respond with JSON: "
        '{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{\\"param\\":\\"value\\"}"}}]}'
    )

    return prompt_template
305
+
306
def process_messages_with_tools(messages: List[Dict], tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> List[Dict]:
    """Inject tool instructions into a conversation and flatten tool results.

    Steps, in order:
      1. If tools are absent, disabled via ``TOOL_SUPPORT`` or ``tool_choice``
         is ``"none"``, return shallow copies of the messages unchanged.
      2. Append the tool prompt to the first system message, or prepend a new
         system message when there is none.
      3. For ``tool_choice == "required"`` or a specific function choice,
         append a hint to the last message when it is a user message.
      4. Rewrite ``tool``/``function`` role messages as assistant messages
         carrying the tool result as text, and make every other message's
         content a plain string.

    Args:
        messages: Conversation messages as dicts.
        tools: OpenAI-style tool definitions.
        tool_choice: ``"none"``, ``"required"``, a function selector dict, or
            any other value (treated as automatic).

    Returns:
        A new list of message dicts; the input dicts are not modified.
    """
    if not tools or not TOOL_SUPPORT or (tool_choice == "none"):
        # No tools, or tools disabled: return the messages as they are
        return [dict(m) for m in messages]

    tools_prompt = generate_tool_prompt(tools)

    # Cap the tool prompt length so the upstream API does not reject it.
    # NOTE(review): the calling-convention text sits at the END of the prompt,
    # so this truncation can cut it off when many tools are defined — confirm
    # whether that is acceptable.
    if len(tools_prompt) > 1000:
        logger.warning(f"工具提示过长 ({len(tools_prompt)} 字符),将截断")
        tools_prompt = tools_prompt[:1000] + "..."

    processed = []
    has_system = any(m.get("role") == "system" for m in messages)

    if has_system:
        # Append the tool prompt to the first system message only
        for m in messages:
            mm = dict(m)
            if mm.get("role") == "system":
                new_content = content_to_string(mm.get("content", "")) + tools_prompt
                # Keep the system message within the configured length limit
                if len(new_content) > SYSTEM_MESSAGE_LENTH:
                    logger.warning(f"系统消息过长 ({len(new_content)} 字符),使用简化版本")
                    mm["content"] = "你是一个有用的助手。" + tools_prompt
                else:
                    mm["content"] = new_content
                # Later system messages get no tool prompt
                tools_prompt = ""
            processed.append(mm)
    else:
        # No system message: add one, but only if there is a prompt to carry
        if tools_prompt.strip():
            processed.append({"role": "system", "content": "你是一个有用的助手。" + tools_prompt})
        processed.extend(dict(m) for m in messages)

    # Simplified tool-choice hints, appended to a trailing user message
    hint = None
    if tool_choice == "required":
        hint = "\n请使用工具来处理这个请求。"
    elif isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
        fname = (tool_choice.get("function") or {}).get("name")
        if fname:
            hint = f"\n请使用 {fname} 工具。"
    if hint and processed and processed[-1].get("role") == "user":
        last = processed[-1]
        last["content"] = content_to_string(last.get("content", "")) + hint

    # Flatten tool/function messages and normalise content to strings
    final_msgs = []
    for m in processed:
        if m.get("role") in ("tool", "function"):
            tool_name = m.get("name", "unknown")
            # content_to_string always returns a str, so no dict handling is
            # needed here.
            tool_content = content_to_string(m.get("content", ""))

            # Simplified tool-result message; an empty result is reported as
            # "completed" instead of a dangling "result:" label.
            if tool_content.strip():
                content = f"工具 {tool_name} 结果: {tool_content}"
            else:
                content = f"工具 {tool_name} 执行完成"

            final_msgs.append({
                "role": "assistant",
                "content": content,
            })
        else:
            # Regular message: make sure the content is a string
            final_msg = dict(m)
            final_msg["content"] = content_to_string(final_msg.get("content", ""))
            final_msgs.append(final_msg)

    return final_msgs
388
+
389
# Tool Extraction Patterns
# Fenced ```json ... ``` block whose body is a JSON object
TOOL_CALL_FENCE_PATTERN = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
# Natural-language form: "调用函数: <name> 参数: {...}" (or "arguments: {...}")
FUNCTION_CALL_PATTERN = re.compile(r"调用函数\s*[::]\s*([\w\-\.]+)\s*(?:参数|arguments)[::]\s*(\{.*?\})", re.DOTALL)


def _find_json_object_end(text: str, start: int) -> int:
    """Return the index just past the '}' matching the '{' at ``text[start]``, or -1.

    Braces inside double-quoted strings are ignored and a backslash escapes
    the following character.
    """
    depth = 1
    in_string = False
    escape_next = False
    for pos in range(start + 1, len(text)):
        ch = text[pos]
        if escape_next:
            escape_next = False
        elif ch == '\\':
            escape_next = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string:
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return pos + 1
    return -1


def _load_tool_call_payload(candidate: str) -> Optional[Dict]:
    """Parse ``candidate`` as JSON; return it only if it is an object with a "tool_calls" key."""
    try:
        parsed = json.loads(candidate)
    except (ValueError, RecursionError):
        # Not JSON (or pathologically nested): not a tool-call payload
        return None
    if isinstance(parsed, dict) and "tool_calls" in parsed:
        return parsed
    return None


def _normalize_tool_calls(tool_calls) -> Optional[List[Dict]]:
    """Return the dict entries of ``tool_calls`` with string ``arguments``, or None.

    Non-list or empty input, and lists without any dict entry, yield None.
    Non-dict entries are dropped because callers read each call with dict
    accessors. ``function.arguments`` values that are not strings are
    serialised to JSON text.
    """
    if not tool_calls or not isinstance(tool_calls, list):
        return None
    valid_calls = [tc for tc in tool_calls if isinstance(tc, dict)]
    for tc in valid_calls:
        func = tc.get("function")
        if isinstance(func, dict) and "arguments" in func and not isinstance(func["arguments"], str):
            func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)
    return valid_calls or None


def extract_tool_invocations(text: str) -> Optional[List[Dict]]:
    """Extract tool invocations from response text.

    Three strategies are tried in order, over at most the first ``SCAN_LIMIT``
    characters:
      1. a fenced ```json block containing ``{"tool_calls": [...]}``;
      2. an inline JSON object with a ``tool_calls`` list, found by
         brace balancing;
      3. the natural-language form matched by ``FUNCTION_CALL_PATTERN``.

    Args:
        text: The model's response text.

    Returns:
        A list of tool-call dicts whose ``function.arguments`` are strings, or
        ``None`` when no tool call is found.
    """
    if not text:
        return None

    # Limit scan size for performance
    scannable_text = text[:SCAN_LIMIT]

    # Strategies 1 and 2 need a "tool_calls" key, so skip both scans when the
    # literal never occurs. (A key spelled with JSON unicode escapes would be
    # missed; that is not a form the prompt asks the model to produce.)
    if "tool_calls" in scannable_text:
        # Attempt 1: JSON code blocks
        for json_block in TOOL_CALL_FENCE_PATTERN.findall(scannable_text):
            payload = _load_tool_call_payload(json_block)
            if payload is not None:
                tool_calls = _normalize_tool_calls(payload.get("tool_calls"))
                if tool_calls:
                    return tool_calls

        # Attempt 2: inline JSON objects located by brace balancing
        pos = scannable_text.find('{')
        while pos != -1:
            end = _find_json_object_end(scannable_text, pos)
            if end != -1:
                payload = _load_tool_call_payload(scannable_text[pos:end])
                if payload is not None:
                    tool_calls = _normalize_tool_calls(payload.get("tool_calls"))
                    if tool_calls:
                        return tool_calls
            pos = scannable_text.find('{', pos + 1)

    # Attempt 3: natural-language function calls
    natural_lang_match = FUNCTION_CALL_PATTERN.search(scannable_text)
    if natural_lang_match:
        function_name = natural_lang_match.group(1).strip()
        arguments_str = natural_lang_match.group(2).strip()
        try:
            # Validate that the arguments are well-formed JSON
            json.loads(arguments_str)
        except json.JSONDecodeError:
            return None
        return [
            {
                "id": f"call_{int(time.time() * 1000000)}",
                "type": "function",
                "function": {"name": function_name, "arguments": arguments_str},
            }
        ]

    return None


def remove_tool_json_content(text: str) -> str:
    """Remove tool-call JSON from response text, leaving the ordinary prose.

    Both fenced ```json blocks and inline JSON objects are removed when they
    parse as an object with a ``tool_calls`` key; all other text, including
    other JSON, is kept.

    Args:
        text: The model's response text.

    Returns:
        The text without tool-call JSON, stripped of surrounding whitespace.
    """
    if not text:
        return ""
    # Nothing can be removed without the key; avoid scanning entirely.
    if "tool_calls" not in text:
        return text.strip()

    def remove_tool_call_block(match: re.Match) -> str:
        # Drop the fenced block only when it really is a tool-call payload
        if _load_tool_call_payload(match.group(1)) is not None:
            return ""
        return match.group(0)

    # Step 1: remove fenced tool JSON blocks
    cleaned_text = TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)

    # Step 2: remove inline tool JSON using brace balancing. Text is copied in
    # slices between removed objects rather than one character at a time.
    kept = []
    cursor = 0  # start of the text not yet copied to `kept`
    pos = cleaned_text.find('{')
    while pos != -1:
        end = _find_json_object_end(cleaned_text, pos)
        if end != -1 and _load_tool_call_payload(cleaned_text[pos:end]) is not None:
            # A tool call: keep what precedes it and skip the object itself
            kept.append(cleaned_text[cursor:pos])
            cursor = end
            pos = cleaned_text.find('{', end)
        else:
            # Not a tool call (or unbalanced): leave it in place
            pos = cleaned_text.find('{', pos + 1)
    kept.append(cleaned_text[cursor:])

    return ''.join(kept).strip()
552
+
553
async def make_request(method: str, url: str, headers: dict, json_data: dict = None,
                       stream: bool = False) -> httpx.Response:
    """Send an HTTP request to the upstream API.

    Args:
        method: HTTP method.
        url: Target URL.
        headers: Request headers.
        json_data: Optional JSON body.
        stream: When true, return the un-entered ``client.stream(...)`` context
            manager instead of a response; no timeout is applied in that mode.

    Returns:
        For non-stream requests, the fully read ``httpx.Response``. For stream
        requests, the streaming context manager — note the client that backs
        it is not closed by this function.

    Raises:
        httpx.HTTPStatusError: the upstream answered with an error status.
        Exception: any other failure while sending the request (logged first).
    """
    if stream:
        client = None
        try:
            client = create_http_client()
            # Streaming request: hand back the context manager
            return client.stream(method, url, headers=headers, json=json_data, timeout=None)
        except Exception as e:
            logger.error(f"请求异常: {e}")
            if client is not None:
                await client.aclose()
            raise

    try:
        # The body of a non-stream response is read before request() returns,
        # so the client (and its connection pool) can be closed right away
        # instead of being leaked on every call.
        async with create_http_client() as client:
            response = await client.request(method, url, headers=headers, json=json_data, timeout=REQUEST_TIMEOUT)

        # Log non-200 responses in detail
        if response.status_code != 200:
            logger.error(f"上游API返回错误状态码: {response.status_code}")
            logger.error(f"响应头: {dict(response.headers)}")
            try:
                error_body = response.text
                logger.error(f"错误响应体: {error_body}")
            except Exception:
                logger.error("无法读取错误响应体")

        response.raise_for_status()
        return response

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}")
        raise
    except Exception as e:
        logger.error(f"请求异常: {e}")
        raise
590
+
591
@app.get("/")
async def homepage():
    """Home page: report that the service is running and list its endpoints."""
    endpoints = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models"
    }
    status_payload = {
        "status": "success",
        "message": "K2Think API Proxy is running",
        "service": "K2Think API Gateway",
        "model": "MBZUAI-IFM/K2-Think",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
    return JSONResponse(content=status_payload)
605
+
606
@app.get("/health")
async def health_check():
    """Health check: return a healthy status and the current Unix timestamp."""
    now_seconds = int(time.time())
    body = {"status": "healthy", "timestamp": now_seconds}
    return JSONResponse(content=body)
613
+
614
@app.get("/favicon.ico")
async def favicon():
    """Serve an empty icon response for favicon requests."""
    empty_icon = Response(content="", media_type="image/x-icon")
    return empty_icon
618
+
619
@app.get("/v1/models")
async def get_models() -> ModelsResponse:
    """Return the model list, which holds the single K2-Think entry."""
    created_at = int(time.time())
    available_models = [
        ModelInfo(
            id="MBZUAI-IFM/K2-Think",
            created=created_at,
            owned_by="MBZUAI",
            root="mbzuai-k2-think-2508",
        )
    ]
    return ModelsResponse(data=available_models)
629
+
630
+
631
async def process_non_stream_response(k2think_payload: dict, headers: dict) -> tuple[str, dict]:
    """Send a non-streaming request upstream and extract the answer text.

    Args:
        k2think_payload: JSON body to POST to ``K2THINK_API_URL``.
        headers: Request headers.

    Returns:
        ``(content, usage)`` where ``content`` is the first choice's message
        content passed through ``extract_answer_content`` (``""`` when there
        is none) and ``usage`` is the upstream ``usage`` object, or an
        all-zero usage dict when the upstream supplies none.

    Raises:
        Exception: any request or parsing failure, re-raised after logging.
    """
    try:
        response = await make_request(
            "POST",
            K2THINK_API_URL,
            headers,
            k2think_payload,
            stream=False
        )

        try:
            # The upstream answers non-streaming requests with standard JSON
            result = response.json()
        finally:
            # Close the response even when the body is not valid JSON
            await response.aclose()

        # Extract the content of the first choice, if any
        full_content = ""
        choices = result.get('choices') or []
        if choices:
            message = choices[0].get('message') or {}
            raw_content = message.get('content')
            if raw_content:
                # Strip the <answer> wrapper (and thinking, if configured)
                full_content = extract_answer_content(raw_content)

        # Token usage; `or` also covers an explicit null from the upstream
        token_info = result.get('usage') or {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }

        return full_content, token_info

    except Exception as e:
        logger.error(f"处理非流式响应错误: {e}")
        raise
667
+
668
async def process_stream_response(k2think_payload: dict, headers: dict) -> AsyncGenerator[str, None]:
    """Produce a simulated streaming (SSE) response.

    The upstream is called in non-streaming mode; the complete answer is then
    emitted as ``chat.completion.chunk`` events, pausing ``STREAM_DELAY``
    seconds between content chunks. All chunks of one response share the same
    ``id`` and ``created`` values so clients can correlate them.

    Args:
        k2think_payload: Upstream request body (not modified; a copy is used).
        headers: Upstream request headers (not modified; a copy is used).

    Yields:
        SSE ``data:`` lines, ending with ``data: [DONE]``. On failure an error
        chunk carrying the exception text is emitted before ``[DONE]``.
    """
    completion_id = f"chatcmpl-{int(time.time() * 1000)}"
    created = int(time.time())

    def sse_chunk(delta: dict, finish_reason: Optional[str] = None) -> str:
        # One SSE event carrying a single-choice completion chunk
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "MBZUAI-IFM/K2-Think",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }
        return f"data: {json.dumps(payload)}\n\n"

    try:
        # Turn the streaming request into a non-streaming one
        k2think_payload_copy = k2think_payload.copy()
        k2think_payload_copy["stream"] = False

        headers_copy = headers.copy()
        headers_copy["accept"] = "application/json"

        # Fetch the complete response
        full_content, _token_info = await process_non_stream_response(k2think_payload_copy, headers_copy)

        if not full_content:
            yield "data: [DONE]\n\n"
            return

        # Opening chunk announces the assistant role
        yield sse_chunk({"role": "assistant", "content": ""})

        # Simulated streaming: slice the text by a dynamically chosen size
        chunk_size = calculate_dynamic_chunk_size(len(full_content))
        for offset in range(0, len(full_content), chunk_size):
            yield sse_chunk({"content": full_content[offset:offset + chunk_size]})
            # Small delay to mimic real streaming
            await asyncio.sleep(STREAM_DELAY)

        # Closing chunk
        yield sse_chunk({}, "stop")
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"流式请求失败: {e}")
        # Report the error to the client inside the stream
        yield sse_chunk({"content": f"Error: {str(e)}"}, "stop")
        yield "data: [DONE]\n\n"
761
+
762
async def process_stream_response_with_tools(k2think_payload: dict, headers: dict, has_tools: bool = False) -> AsyncGenerator[str, None]:
    """Produce a simulated streaming (SSE) response with tool-call support.

    An opening chunk is sent immediately, then the upstream is called in
    non-streaming mode. With ``has_tools`` set, tool calls found in the answer
    are emitted as ``tool_calls`` deltas (finish reason ``"tool_calls"``);
    otherwise the answer text — with tool JSON removed when ``has_tools`` is
    set — is emitted in chunks (finish reason ``"stop"``). All chunks of one
    response share the same ``id`` and ``created`` values.

    Args:
        k2think_payload: Upstream request body (not modified; a copy is used).
        headers: Upstream request headers (not modified; a copy is used).
        has_tools: Whether tool-call extraction applies to this request.

    Yields:
        SSE ``data:`` lines, ending with ``data: [DONE]``. On failure a chunk
        with finish reason ``"error"`` is emitted before ``[DONE]``.
    """
    completion_id = f"chatcmpl-{int(time.time() * 1000)}"
    created = int(time.time())

    def sse_chunk(delta: dict, finish_reason: Optional[str] = None) -> str:
        # One SSE event carrying a single-choice completion chunk
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "MBZUAI-IFM/K2-Think",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }
        return f"data: {json.dumps(payload)}\n\n"

    try:
        # Opening chunk, sent before the upstream call
        yield sse_chunk({"role": "assistant", "content": ""})

        # The upstream is always queried in non-streaming mode
        k2think_payload_copy = k2think_payload.copy()
        k2think_payload_copy["stream"] = False

        headers_copy = headers.copy()
        headers_copy["accept"] = "application/json"

        # Fetch the complete response
        full_content, _token_info = await process_non_stream_response(k2think_payload_copy, headers_copy)

        if not full_content:
            yield "data: [DONE]\n\n"
            return

        finish_reason = "stop"
        tool_calls = extract_tool_invocations(full_content) if has_tools else None

        if tool_calls:
            # Emit each tool call as its own delta
            for index, tc in enumerate(tool_calls):
                tool_call_delta = {
                    "index": index,
                    "id": tc.get("id"),
                    "type": tc.get("type", "function"),
                    "function": tc.get("function", {}),
                }
                yield sse_chunk({"tool_calls": [tool_call_delta]})
            finish_reason = "tool_calls"
        else:
            # Regular content. When tools were offered but not used, strip any
            # tool JSON fragments from the text first.
            text_to_stream = remove_tool_json_content(full_content) if has_tools else full_content
            if text_to_stream:
                chunk_size = calculate_dynamic_chunk_size(len(text_to_stream))
                for offset in range(0, len(text_to_stream), chunk_size):
                    yield sse_chunk({"content": text_to_stream[offset:offset + chunk_size]})
                    # Short delay so chunks are delivered separately
                    await asyncio.sleep(STREAM_DELAY)

        # Closing chunk
        yield sse_chunk({}, finish_reason)
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"流式响应处理错误: {e}")
        yield sse_chunk({}, "error")
        yield "data: [DONE]\n\n"
908
+
909
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest, auth_request: Request):
    """Handle an OpenAI-style chat completion request by proxying it to K2Think.

    The request's messages are flattened to plain-string content, optionally
    passed through ``process_messages_with_tools`` when tool calling is
    active, wrapped in a K2Think payload and forwarded upstream.

    Args:
        request: Parsed chat completion request body.
        auth_request: Raw incoming request; only its ``Authorization`` header
            is read here.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``request.stream``
        is truthy, otherwise a ``JSONResponse`` carrying a ``chat.completion``
        object.

    Raises:
        HTTPException: 401 when ``validate_api_key`` rejects the header; the
            upstream status code on ``httpx.HTTPStatusError``; 504 on
            ``httpx.TimeoutException``; 500 for any other failure. An
            ``HTTPException`` raised while handling the request is propagated
            unchanged instead of being re-wrapped as a generic 500.
    """
    # Validate the API key before doing any work.
    authorization = auth_request.headers.get("Authorization", "")
    if not validate_api_key(authorization):
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "Invalid API key provided",
                    "type": "authentication_error"
                }
            }
        )

    try:
        # Normalise incoming messages so that every content is a string.
        raw_messages = []
        for msg in request.messages:
            try:
                content = content_to_string(msg.content)
                raw_messages.append({
                    "role": msg.role,
                    "content": content,
                    "tool_calls": msg.tool_calls
                })
            except Exception as e:
                logger.error(f"处理消息时出错: {e}, 消息: {msg}")
                # Best-effort fallback: keep the message with a plain str() of its content.
                raw_messages.append({
                    "role": msg.role,
                    "content": str(msg.content) if msg.content else "",
                    "tool_calls": msg.tool_calls
                })

        # Tool calling is active only when enabled, tools were supplied and
        # the caller did not opt out. bool() keeps the flag a real boolean
        # (the bare `and` chain could yield None or an empty list).
        has_tools = bool(
            TOOL_SUPPORT
            and request.tools
            and len(request.tools) > 0
            and request.tool_choice != "none"
        )

        logger.info(f"🔧 工具调用状态: has_tools={has_tools}, tools_count={len(request.tools) if request.tools else 0}")
        logger.info(f"📥 接收到的原始消息数: {len(raw_messages)}")

        # Log how many messages of each role were received.
        role_count = {}
        for msg in raw_messages:
            role = msg.get("role", "unknown")
            role_count[role] = role_count.get(role, 0) + 1
        logger.info(f"📊 原始消息角色分布: {role_count}")

        if has_tools:
            processed_messages = process_messages_with_tools(
                raw_messages,
                request.tools,
                request.tool_choice
            )
            logger.info(f"🔄 消息处理完成,原始消息数: {len(raw_messages)}, 处理后消息数: {len(processed_messages)}")

            # Log the role distribution after tool processing.
            processed_role_count = {}
            for msg in processed_messages:
                role = msg.get("role", "unknown")
                processed_role_count[role] = processed_role_count.get(role, 0) + 1
            logger.info(f"📊 处理后消息角色分布: {processed_role_count}")
        else:
            processed_messages = raw_messages
            logger.info("⏭️ 无工具调用,直接使用原始消息")

        # Build the K2Think message list; only role and string content are forwarded.
        k2think_messages = []
        for msg in processed_messages:
            try:
                content = content_to_string(msg.get("content", ""))
                k2think_messages.append({
                    "role": msg["role"],
                    "content": content
                })
            except Exception as e:
                logger.error(f"构建K2Think消息时出错: {e}, 消息: {msg}")
                # Fall back to safe defaults so one bad message does not abort the request.
                k2think_messages.append({
                    "role": msg.get("role", "user"),
                    "content": str(msg.get("content", ""))
                })

        k2think_payload = {
            "stream": request.stream,
            "model": "MBZUAI-IFM/K2-Think",
            "messages": k2think_messages,
            "params": {},
            "tool_servers": [],
            "features": {
                "image_generation": False,
                "code_interpreter": False,
                "web_search": False
            },
            "variables": get_current_datetime_info(),
            "model_item": {
                "id": "MBZUAI-IFM/K2-Think",
                "object": "model",
                "owned_by": "MBZUAI",
                "root": "mbzuai-k2-think-2508",
                "parent": None,
                "status": "active",
                "connection_type": "external",
                "name": "MBZUAI-IFM/K2-Think"
            },
            "background_tasks": {
                "title_generation": True,
                "tags_generation": True
            },
            "chat_id": generate_chat_id(),
            "id": generate_session_id(),
            "session_id": generate_session_id()
        }

        # Verify the payload is JSON-serializable before sending it upstream.
        try:
            json.dumps(k2think_payload, ensure_ascii=False)
            logger.info(f"✅ K2Think请求体JSON序列化验证通过")
        except Exception as e:
            logger.error(f"❌ K2Think请求体JSON序列化失败: {e}")
            # Try to repair by stringifying whatever json cannot encode.
            try:
                k2think_payload = json.loads(json.dumps(k2think_payload, default=str, ensure_ascii=False))
                logger.info("🔧 使用default=str修复了序列化问题")
            except Exception as fix_error:
                logger.error(f"无法修复序列化问题: {fix_error}")
                # Same structured detail shape as the other error responses.
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": {
                            "message": "请求数据序列化失败",
                            "type": "api_error"
                        }
                    }
                ) from fix_error

        logger.info(f"发送到 K2Think 的消息数量: {len(k2think_payload['messages'])}")
        # isEnabledFor honours inherited levels; comparing logger.level
        # directly treats an unset logger (NOTSET == 0) as DEBUG.
        if DEBUG_LOGGING or logger.isEnabledFor(logging.DEBUG):
            for i, msg in enumerate(k2think_payload['messages']):
                content_preview = msg['content'][:200] + "..." if len(msg['content']) > 200 else msg['content']
                logger.debug(f"消息 {i+1} ({msg['role']}): {content_preview}")

        # Headers for the upstream request.
        headers = {
            "accept": "text/event-stream,application/json" if request.stream else "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {K2THINK_TOKEN}",
            "origin": "https://www.k2think.ai",
            "referer": "https://www.k2think.ai/c/" + k2think_payload["chat_id"],
            "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36 Edg/140.0.0.0"
        }

        if request.stream:
            # Streaming: hand the payload to the SSE generator.
            return StreamingResponse(
                process_stream_response_with_tools(k2think_payload, headers, has_tools),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"
                }
            )
        else:
            # Non-streaming: collect the whole answer, then shape the response.
            full_content, token_info = await process_non_stream_response(k2think_payload, headers)

            tool_calls = None
            finish_reason = "stop"
            message_content = full_content

            if has_tools:
                tool_calls = extract_tool_invocations(full_content)
                if tool_calls:
                    # Content must be null when tool_calls are present (OpenAI spec)
                    message_content = None
                    finish_reason = "tool_calls"
                    logger.info(f"提取到工具调用: {json.dumps(tool_calls, ensure_ascii=False)}")
                else:
                    # Strip tool JSON from the text; keep the original text
                    # if stripping leaves nothing.
                    message_content = remove_tool_json_content(full_content)
                    if not message_content:
                        message_content = full_content

            openai_response = {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "MBZUAI-IFM/K2-Think",
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": message_content,
                        **({"tool_calls": tool_calls} if tool_calls else {})
                    },
                    "finish_reason": finish_reason
                }],
                "usage": token_info
            }

            return JSONResponse(content=openai_response)

    except HTTPException:
        # Already a well-formed HTTP error: propagate it instead of letting
        # the generic handler below re-wrap it as a 500 api_error.
        raise
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP错误: {e.response.status_code}")
        raise HTTPException(
            status_code=e.response.status_code,
            detail={
                "error": {
                    "message": f"上游服务错误: {e.response.status_code}",
                    "type": "upstream_error"
                }
            }
        ) from e
    except httpx.TimeoutException as e:
        logger.error("请求超时")
        raise HTTPException(
            status_code=504,
            detail={
                "error": {
                    "message": "请求超时",
                    "type": "timeout_error"
                }
            }
        ) from e
    except Exception as e:
        logger.error(f"API转发错误: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "api_error"
                }
            }
        ) from e
1144
+
1145
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """Answer 404 errors handled by the app with a small JSON error body."""
    body = {"error": "Not Found"}
    return JSONResponse(content=body, status_code=404)
1151
+
1152
if __name__ == "__main__":
    import uvicorn

    # Bind address and port come from the environment, with defaults.
    server_options = {
        "host": os.getenv("HOST", "0.0.0.0"),
        "port": int(os.getenv("PORT", "8001")),
        "access_log": ENABLE_ACCESS_LOG,
        # Log verbosity follows the DEBUG_LOGGING switch.
        "log_level": "debug" if DEBUG_LOGGING else "info",
    }

    uvicorn.run(app, **server_options)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ httpx
4
+ pydantic
5
+ python-dotenv
6
+ pytz