gaila commited on
Commit
51cb8bb
·
verified ·
1 Parent(s): b093d94

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +15 -0
  2. claude_compat.py +250 -0
  3. docker-compose.yml +26 -0
  4. main.py +504 -0
  5. openai.py +1761 -0
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Keep the image lean: no .pyc files; unbuffered stdout so container logs
# appear immediately.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

WORKDIR /app

COPY . /app

# Runtime dependencies only; pip cache disabled to keep the layer small.
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir fastapi uvicorn httpx httpcore

# OpenAI-compatible proxy port (matches docker-compose.yml mapping).
EXPOSE 30016

CMD ["python", "openai.py"]
claude_compat.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Anthropic Claude Messages API (/v1/messages) helpers.
2
+
3
+ Converts between Anthropic Claude native format and the internal
4
+ OpenAI-like format already used by openai.py.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import uuid
11
+ from typing import Any
12
+
13
+
14
+ # ── Message conversion ───────────────────────────────────────────────
15
+
16
+
17
def claude_messages_to_openai(system: Any, messages: list[dict]) -> list[dict]:
    """Translate Claude-native (system, messages) into OpenAI chat messages.

    Handles the Claude-specific shapes: the top-level ``system`` field,
    assistant turns carrying ``tool_use`` blocks, and user turns carrying
    ``tool_result`` blocks.  Anything else degrades to plain text via
    ``_extract_text``.
    """
    converted: list[dict] = []

    # Claude keeps the system prompt outside the messages array.
    if system:
        if isinstance(system, str):
            converted.append({"role": "system", "content": system})
        elif isinstance(system, list):
            chunks = [
                blk.get("text", "")
                for blk in system
                if isinstance(blk, dict) and blk.get("type") == "text"
            ]
            if chunks:
                converted.append({"role": "system", "content": "\n".join(chunks)})

    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        blocks = content if isinstance(content, list) else None

        # Assistant turn made of content blocks (text and/or tool_use).
        if role == "assistant" and blocks is not None:
            texts: list[str] = []
            calls: list[dict] = []
            for blk in blocks:
                if not isinstance(blk, dict):
                    continue
                kind = blk.get("type")
                if kind == "text":
                    texts.append(blk.get("text", ""))
                elif kind == "tool_use":
                    calls.append({
                        "id": blk.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                        "type": "function",
                        "function": {
                            "name": blk.get("name", ""),
                            "arguments": json.dumps(blk.get("input", {}), ensure_ascii=False),
                        },
                    })
            # OpenAI uses None content for pure tool-call turns.
            entry: dict = {"role": "assistant", "content": " ".join(texts).strip() or None}
            if calls:
                entry["tool_calls"] = calls
            converted.append(entry)
            continue

        # User turn that feeds tool results back: each tool_result becomes
        # a "tool" role message; interleaved text blocks stay "user".
        if role == "user" and blocks is not None and any(
            isinstance(b, dict) and b.get("type") == "tool_result" for b in blocks
        ):
            for blk in blocks:
                if not isinstance(blk, dict):
                    continue
                kind = blk.get("type")
                if kind == "tool_result":
                    result = blk.get("content", "")
                    if isinstance(result, str):
                        text = result
                    elif isinstance(result, list):
                        text = " ".join(
                            part.get("text", "")
                            for part in result
                            if isinstance(part, dict) and part.get("type") == "text"
                        )
                    else:
                        text = str(result)
                    converted.append({
                        "role": "tool",
                        "tool_call_id": blk.get("tool_use_id", ""),
                        "content": text,
                    })
                elif kind == "text":
                    converted.append({"role": "user", "content": blk.get("text", "")})
            continue

        # Default: flatten whatever is left into plain text.
        converted.append({"role": role, "content": _extract_text(content)})

    return converted
93
+
94
+
95
def claude_tools_to_openai(tools: list[dict] | None) -> list[dict] | None:
    """Map Anthropic tool definitions onto OpenAI function-tool entries.

    Returns ``None`` (rather than an empty list) when there is nothing to
    convert, so callers can simply omit the field.
    """
    if not tools:
        return None
    converted = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        converted.append({
            "type": "function",
            "function": {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                # Anthropic's "input_schema" is OpenAI's "parameters".
                "parameters": tool.get("input_schema", {}),
            },
        })
    return converted or None
112
+
113
+
114
def claude_tool_choice_prompt(tool_choice: Any) -> str:
    """Render a Claude ``tool_choice`` constraint as a prompt suffix.

    Returns the empty string when no forcing instruction applies
    (non-dict input, ``auto`` mode, or a ``tool`` choice without a name).
    """
    if not isinstance(tool_choice, dict):
        return ""
    mode = tool_choice.get("type", "auto")
    if mode == "any":
        return "\nIMPORTANT: You MUST call at least one tool in your next response."
    if mode == "tool":
        forced = tool_choice.get("name", "")
        if forced:
            return f"\nIMPORTANT: You MUST call this tool: {forced}"
    return ""
125
+
126
+
127
+ # ── Response builders ────────────────────────────────────────────────
128
+
129
+
130
def make_claude_id() -> str:
    """Generate a fresh Anthropic-style message id (``msg_`` + 24 hex chars)."""
    suffix = uuid.uuid4().hex[:24]
    return "msg_" + suffix
132
+
133
+
134
def build_tool_call_blocks(
    tool_calls: list[dict],
) -> list[dict]:
    """Convert internal (OpenAI-style) tool call dicts to Claude ``tool_use`` blocks.

    Arguments that fail to parse as JSON degrade to an empty input dict
    rather than raising.  An OpenAI ``call_`` id *prefix* is rewritten to
    Anthropic's ``toolu_`` prefix; a missing/empty id gets a fresh one.
    """
    blocks = []
    for tc in tool_calls:
        fn = tc.get("function", {}) if isinstance(tc.get("function"), dict) else {}
        args_str = fn.get("arguments", "{}")
        try:
            args_obj = json.loads(args_str) if isinstance(args_str, str) else args_str
        except Exception:
            # Malformed arguments from upstream: fall back to no input.
            args_obj = {}
        block_id = tc.get("id") or f"toolu_{uuid.uuid4().hex[:20]}"
        # BUG FIX: the original used str.replace("call_", "toolu_"), which
        # rewrites EVERY occurrence of "call_" inside the id, not just the
        # prefix; only the prefix is the OpenAI convention.
        if block_id.startswith("call_"):
            block_id = "toolu_" + block_id[len("call_"):]
        blocks.append({
            "type": "tool_use",
            "id": block_id,
            "name": fn.get("name", ""),
            "input": args_obj,
        })
    return blocks
153
+
154
+
155
def build_non_stream_response(
    msg_id: str,
    model: str,
    reasoning_parts: list[str],
    answer_text: str,
    tool_calls: list[dict] | None,
    input_tokens: int,
    output_tokens: int,
) -> dict:
    """Assemble a full (non-streaming) Anthropic Messages API response body.

    Content order is: optional thinking block, text block, then any
    tool_use blocks.  A text-less, tool-less reply still gets one empty
    text block so the content array is never empty.
    """
    blocks: list[dict] = []
    if reasoning_parts:
        blocks.append({"type": "thinking", "thinking": "".join(reasoning_parts)})
    if answer_text:
        blocks.append({"type": "text", "text": answer_text})
    elif not tool_calls:
        blocks.append({"type": "text", "text": ""})
    if tool_calls:
        blocks.extend(build_tool_call_blocks(tool_calls))

    stop = "tool_use" if tool_calls else "end_turn"
    return {
        "id": msg_id,
        "type": "message",
        "role": "assistant",
        "content": blocks,
        "model": model,
        "stop_reason": stop,
        "stop_sequence": None,
        "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
    }
185
+
186
+
187
+ # ── SSE event helpers ────────────────────────────────────────────────
188
+
189
+
190
def sse(event: str, data: dict) -> str:
    """Serialize one server-sent-event frame: event line + JSON data line."""
    payload = json.dumps(data, ensure_ascii=False)
    return f"event: {event}\ndata: {payload}\n\n"


def sse_message_start(msg_id: str, model: str, input_tokens: int) -> str:
    """Opening frame: an empty assistant message shell with input token count."""
    message = {
        "id": msg_id,
        "type": "message",
        "role": "assistant",
        "content": [],
        "model": model,
        "stop_reason": None,
        "stop_sequence": None,
        "usage": {"input_tokens": input_tokens, "output_tokens": 0},
    }
    return sse("message_start", {"type": "message_start", "message": message})


def sse_ping() -> str:
    """Keep-alive frame."""
    return sse("ping", {"type": "ping"})


def sse_content_block_start(index: int, block: dict) -> str:
    """Announce a new content block at *index*."""
    payload = {"type": "content_block_start", "index": index, "content_block": block}
    return sse("content_block_start", payload)


def sse_content_block_delta(index: int, delta: dict) -> str:
    """Incremental update for the content block at *index*."""
    payload = {"type": "content_block_delta", "index": index, "delta": delta}
    return sse("content_block_delta", payload)


def sse_content_block_stop(index: int) -> str:
    """Close the content block at *index*."""
    return sse("content_block_stop", {"type": "content_block_stop", "index": index})


def sse_message_delta(stop_reason: str, output_tokens: int) -> str:
    """Trailing metadata frame: stop reason plus output token usage."""
    payload = {
        "type": "message_delta",
        "delta": {"stop_reason": stop_reason, "stop_sequence": None},
        "usage": {"output_tokens": output_tokens},
    }
    return sse("message_delta", payload)


def sse_message_stop() -> str:
    """Final frame of a message stream."""
    return sse("message_stop", {"type": "message_stop"})


def sse_error(error_type: str, message: str) -> str:
    """Error frame in Anthropic's error envelope shape."""
    detail = {"type": error_type, "message": message}
    return sse("error", {"type": "error", "error": detail})
236
+
237
+
238
+ # ── Private ──────────────────────────────────────────────────────────
239
+
240
+
241
+ def _extract_text(content: object) -> str:
242
+ if isinstance(content, str):
243
+ return content
244
+ if isinstance(content, list):
245
+ return " ".join(
246
+ str(b.get("text", ""))
247
+ for b in content
248
+ if isinstance(b, dict) and b.get("type") == "text"
249
+ ).strip()
250
+ return str(content) if content else ""
docker-compose.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
services:
  zai-openai:
    build:
      context: .
      dockerfile: Dockerfile
    image: zai-openai:local
    container_name: zai-openai
    working_dir: /app
    restart: unless-stopped
    environment:
      # Logging / debugging
      - LOG_LEVEL=DEBUG
      - HTTP_DEBUG=0
      # Upstream behaviour
      - ENABLE_THINKING=1
      - UPSTREAM_FIRST_EVENT_TIMEOUT=60
      - UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX=2
      - TOKEN_MAX_AGE=480
      # Guest-account pool sizing
      - POOL_MIN_SIZE=5
      - POOL_MAX_SIZE=24
      - POOL_TARGET_INFLIGHT_PER_ACCOUNT=2
      - POOL_MAINTAIN_INTERVAL=10
      - POOL_SCALE_DOWN_IDLE_ROUNDS=3
    volumes:
      # Live-mount the source so edits don't need an image rebuild.
      - ./:/app
    command: ["python", "openai.py"]
    ports:
      - "30016:30016"
main.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """chat.z.ai reverse-engineered Python client."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import base64
7
+ import hashlib
8
+ import hmac
9
+ import json
10
+ import logging
11
+ import os
12
+ import time
13
+ import uuid
14
+ from datetime import datetime, timezone, timedelta
15
+ from urllib.parse import urlencode
16
+
17
+ import httpx
18
+
19
+ logger = logging.getLogger("zai.client")
20
+
21
+ BASE_URL = "https://chat.z.ai"
22
+ HMAC_SECRET = "key-@@@@)))()((9))-xxxx&&&%%%%%"
23
+ FE_VERSION = "prod-fe-1.0.231"
24
+ CLIENT_VERSION = "0.0.1"
25
+ DEFAULT_MODEL = "glm-5"
26
+ ENABLE_THINKING_DEFAULT = os.getenv("ENABLE_THINKING", "1") == "1"
27
+ USER_AGENT = (
28
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
29
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
30
+ "Chrome/144.0.0.0 Safari/537.36"
31
+ )
32
+
33
+
34
class ZaiClient:
    """Async client for the chat.z.ai private web API (guest sessions)."""

    def __init__(self) -> None:
        # Split timeouts: fail fast on connect, allow long reads for
        # streamed generations.
        timeouts = httpx.Timeout(
            connect=5.0,   # connection establishment
            read=180.0,    # long streaming responses (3 min)
            write=10.0,
            pool=5.0,      # waiting for a pooled connection
        )
        self.client = httpx.AsyncClient(
            base_url=BASE_URL,
            timeout=timeouts,
            headers={
                "User-Agent": USER_AGENT,
                "Accept-Language": "zh-CN",
                "Referer": f"{BASE_URL}/",
                "Origin": BASE_URL,
            },
            # Bounded pool so connections can't leak without limit.
            limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
        )
        self.token: str | None = None
        self.user_id: str | None = None
        self.username: str | None = None

    async def close(self) -> None:
        await self.client.aclose()

    def _json_headers(self) -> dict:
        """Common JSON request headers, with bearer auth when a token is held."""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    # ── auth ────────────────────────────────────────────────────────

    async def auth_as_guest(self) -> dict:
        """GET /api/v1/auths/ — creates a guest session and returns user info."""
        resp = await self.client.get(
            "/api/v1/auths/",
            headers={"Content-Type": "application/json"},
        )
        resp.raise_for_status()
        data = resp.json()
        self.token = data["token"]
        self.user_id = data["id"]
        # Fall back to the mailbox local-part when no display name exists.
        self.username = data.get("name") or data.get("email", "").split("@")[0]
        return data

    # ── models ──────────────────────────────────────────────────────

    async def get_models(self) -> list:
        """GET /api/models — returns available model list."""
        resp = await self.client.get("/api/models", headers=self._json_headers())
        resp.raise_for_status()
        return resp.json()

    # ── chat CRUD ───────────────────────────────────────────────────

    async def create_chat(
        self,
        user_message: str,
        model: str = DEFAULT_MODEL,
        *,
        enable_thinking: bool | None = None,
    ) -> dict:
        """POST /api/v1/chats/new — creates a new chat session.

        The content placed in history only initializes the session; the
        real conversation is sent through chat_completions later.  Long
        prompts are truncated here to avoid 400s from the upstream API.
        """
        # Huge prompts (tool definitions, multi-turn history, big system
        # prompts) make the chat-creation endpoint reject with 400.
        MAX_INIT_CONTENT_LEN = 500
        init_content = user_message
        if len(init_content) > MAX_INIT_CONTENT_LEN:
            init_content = init_content[:MAX_INIT_CONTENT_LEN] + "..."

        msg_id = str(uuid.uuid4())
        ts = int(time.time())
        thinking = (
            ENABLE_THINKING_DEFAULT
            if enable_thinking is None
            else bool(enable_thinking)
        )
        body = {
            "chat": {
                "id": "",
                "title": "新聊天",
                "models": [model],
                "params": {},
                "history": {
                    "messages": {
                        msg_id: {
                            "id": msg_id,
                            "parentId": None,
                            "childrenIds": [],
                            "role": "user",
                            "content": init_content,
                            "timestamp": ts,
                            "models": [model],
                        }
                    },
                    "currentId": msg_id,
                },
                "tags": [],
                "flags": [],
                "features": [
                    {
                        "type": "tool_selector",
                        "server": "tool_selector_h",
                        "status": "hidden",
                    }
                ],
                "mcp_servers": [],
                "enable_thinking": thinking,
                "auto_web_search": False,
                "message_version": 1,
                "extra": {},
                "timestamp": int(time.time() * 1000),
            }
        }
        resp = await self.client.post(
            "/api/v1/chats/new",
            headers=self._json_headers(),
            json=body,
        )
        if resp.status_code != 200:
            # Log enough context to diagnose prompt-length-related 400s.
            logger.warning(
                "create_chat failed: status=%d body=%s (prompt_len=%d, truncated_len=%d)",
                resp.status_code, resp.text[:500], len(user_message), len(init_content),
            )
        resp.raise_for_status()
        return resp.json()

    # ── chat cleanup ─────────────────────────────────────────────────

    async def delete_chat(self, chat_id: str) -> bool:
        """DELETE /api/v1/chats/{chat_id} — deletes a chat session.

        Returns True if deleted successfully, False otherwise.  Called
        after each request to free up upstream concurrency slots; never
        raises (cleanup is best-effort).
        """
        try:
            resp = await self.client.delete(
                f"/api/v1/chats/{chat_id}",
                headers=self._json_headers(),
            )
            if resp.status_code == 200:
                return True
            logger.debug(
                "delete_chat %s: status=%d body=%s",
                chat_id, resp.status_code, resp.text[:200],
            )
            return False
        except Exception as e:
            logger.debug("delete_chat %s failed: %s", chat_id, e)
            return False

    async def delete_all_chats(self) -> bool:
        """DELETE /api/v1/chats/ — deletes all chats for the current user.

        Useful for cleaning up accumulated chats when hitting concurrency
        limits.  Best-effort: never raises.
        """
        try:
            resp = await self.client.delete(
                "/api/v1/chats/",
                headers=self._json_headers(),
            )
            if resp.status_code == 200:
                logger.info("delete_all_chats: success")
                return True
            logger.warning(
                "delete_all_chats: status=%d body=%s",
                resp.status_code, resp.text[:200],
            )
            return False
        except Exception as e:
            logger.warning("delete_all_chats failed: %s", e)
            return False

    # ── signature ───────────────────────────────────────────────────

    @staticmethod
    def _generate_signature(
        sorted_payload: str, prompt: str, timestamp: str
    ) -> str:
        """Two-layer HMAC-SHA256 matching DLHfQWwv.js.

        1. b64_prompt = base64(utf8(prompt))
        2. message = "{sorted_payload}|{b64_prompt}|{timestamp}"
        3. time_bucket = floor(int(timestamp) / 300_000)
        4. derived_key = HMAC-SHA256(HMAC_SECRET, str(time_bucket)) → hex string
        5. signature = HMAC-SHA256(derived_key_hex_bytes, message) → hex
        """
        encoded_prompt = base64.b64encode(prompt.encode("utf-8")).decode("ascii")
        message = f"{sorted_payload}|{encoded_prompt}|{timestamp}"
        # Key rotates every 5 minutes (timestamp is in milliseconds).
        bucket = int(timestamp) // (5 * 60 * 1000)

        derived_key = hmac.new(
            HMAC_SECRET.encode("utf-8"),
            str(bucket).encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()

        return hmac.new(
            derived_key.encode("utf-8"),
            message.encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()

    def _build_query_and_signature(
        self, prompt: str, chat_id: str
    ) -> tuple[str, str]:
        """Build the full URL query string and X-Signature header.

        Returns (full_query_string, signature).
        """
        timestamp_ms = str(int(time.time() * 1000))
        request_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)

        # Core params — the only ones covered by sortedPayload.
        core = {
            "timestamp": timestamp_ms,
            "requestId": request_id,
            "user_id": self.user_id,
        }

        # sortedPayload: Object.entries(core).sort(by key).join(",")
        pairs = [f"{key},{core[key]}" for key in sorted(core)]
        sorted_payload = ",".join(pairs)

        signature = self._generate_signature(sorted_payload, prompt, timestamp_ms)

        # Browser/device fingerprint params mimicking a desktop Chrome.
        extra = {
            "version": CLIENT_VERSION,
            "platform": "web",
            "token": self.token or "",
            "user_agent": USER_AGENT,
            "language": "zh-CN",
            "languages": "zh-CN",
            "timezone": "Asia/Shanghai",
            "cookie_enabled": "true",
            "screen_width": "1920",
            "screen_height": "1080",
            "screen_resolution": "1920x1080",
            "viewport_height": "919",
            "viewport_width": "944",
            "viewport_size": "944x919",
            "color_depth": "24",
            "pixel_ratio": "1.25",
            "current_url": f"{BASE_URL}/c/{chat_id}",
            "pathname": f"/c/{chat_id}",
            "search": "",
            "hash": "",
            "host": "chat.z.ai",
            "hostname": "chat.z.ai",
            "protocol": "https:",
            "referrer": "",
            "title": "Z.ai - Free AI Chatbot & Agent powered by GLM-5 & GLM-4.7",
            "timezone_offset": "-480",
            "local_time": now.strftime("%Y-%m-%dT%H:%M:%S.")
            + f"{now.microsecond // 1000:03d}Z",
            "utc_time": now.strftime("%a, %d %b %Y %H:%M:%S GMT"),
            "is_mobile": "false",
            "is_touch": "false",
            "max_touch_points": "10",
            "browser_name": "Chrome",
            "os_name": "Windows",
            "signature_timestamp": timestamp_ms,
        }

        query_string = urlencode({**core, **extra})
        return query_string, signature

    # ── chat completions (SSE) ──────────────────────────────────────

    async def chat_completions(
        self,
        chat_id: str,
        messages: list[dict],
        prompt: str,
        *,
        model: str = DEFAULT_MODEL,
        parent_message_id: str | None = None,
        tools: list[dict] | None = None,
        enable_thinking: bool | None = None,
    ):
        """POST /api/v2/chat/completions — streams SSE response.

        Yields the full event ``data`` dict for each SSE frame.
        """
        query_string, signature = self._build_query_and_signature(prompt, chat_id)

        msg_id = str(uuid.uuid4())
        user_msg_id = str(uuid.uuid4())

        # Template variables are rendered in UTC+8 (site's home timezone).
        now = datetime.now(timezone(timedelta(hours=8)))
        variables = {
            "{{USER_NAME}}": self.username or "Guest",
            "{{USER_LOCATION}}": "Unknown",
            "{{CURRENT_DATETIME}}": now.strftime("%Y-%m-%d %H:%M:%S"),
            "{{CURRENT_DATE}}": now.strftime("%Y-%m-%d"),
            "{{CURRENT_TIME}}": now.strftime("%H:%M:%S"),
            "{{CURRENT_WEEKDAY}}": now.strftime("%A"),
            "{{CURRENT_TIMEZONE}}": "Asia/Shanghai",
            "{{USER_LANGUAGE}}": "zh-CN",
        }

        body = {
            "stream": True,
            "model": model,
            "messages": messages,
            "signature_prompt": prompt,
            "params": {},
            "extra": {},
            "features": {
                "image_generation": False,
                "web_search": False,
                "auto_web_search": False,
                "preview_mode": True,
                "flags": [],
                "enable_thinking": (
                    ENABLE_THINKING_DEFAULT
                    if enable_thinking is None
                    else bool(enable_thinking)
                ),
            },
            "variables": variables,
            "chat_id": chat_id,
            "id": msg_id,
            "current_user_message_id": user_msg_id,
            "current_user_message_parent_id": parent_message_id,
            "background_tasks": {
                "title_generation": True,
                "tags_generation": True,
            },
        }
        if tools:
            body["tools"] = tools

        headers = {
            "Content-Type": "application/json",
            "Accept": "*/*",
            "Accept-Language": "zh-CN",
            "X-FE-Version": FE_VERSION,
            "X-Signature": signature,
        }
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        url = f"{BASE_URL}/api/v2/chat/completions?{query_string}"

        async with self.client.stream(
            "POST", url, headers=headers, json=body,
        ) as resp:
            if resp.status_code != 200:
                error_body = (await resp.aread()).decode("utf-8", errors="replace")
                raise httpx.HTTPStatusError(
                    f"chat/completions {resp.status_code}: {error_body[:500]}",
                    request=resp.request,
                    response=resp,
                )
            async for line in resp.aiter_lines():
                # Only SSE data frames matter; everything else is noise.
                if not line.startswith("data: "):
                    continue
                raw = line[6:]
                if raw.strip() == "[DONE]":
                    return
                try:
                    event = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                data = event.get("data", {})
                yield data
                if data.get("done"):
                    return
439
+
440
+
441
async def main() -> None:
    """Smoke-test driver: guest auth → model list → chat creation → streamed reply."""
    client = ZaiClient()
    try:
        # 1. Authenticate as guest
        print("[1] Authenticating as guest...")
        auth = await client.auth_as_guest()
        print(f" user_id : {auth['id']}")
        print(f" email : {auth.get('email', 'N/A')}")
        print(f" token : {auth['token'][:40]}...")

        # 2. Fetch models (response shape varies: dict-with-data or bare list)
        print("\n[2] Fetching models...")
        models_resp = await client.get_models()
        if isinstance(models_resp, dict) and "data" in models_resp:
            names = [m.get("id", m.get("name", "?")) for m in models_resp["data"]]
        elif isinstance(models_resp, list):
            names = [m.get("id", m.get("name", "?")) for m in models_resp]
        else:
            names = [str(models_resp)[:80]]
        print(f" models : {', '.join(names[:10])}")

        # 3. Create chat
        user_message = "Hello"
        print(f"\n[3] Creating chat with first message: {user_message!r}")
        chat = await client.create_chat(user_message)
        chat_id = chat["id"]
        print(f" chat_id : {chat_id}")

        # 4. Stream chat completions
        # (The original had a duplicate, unused `messages` assignment before
        # step 4 — removed; this single assignment is the one actually used.)
        print(f"\n[4] Streaming chat completions (model={DEFAULT_MODEL})...\n")
        messages = [{"role": "user", "content": user_message}]

        thinking_started = False
        answer_started = False
        async for data in client.chat_completions(
            chat_id=chat_id,
            messages=messages,
            prompt=user_message,
        ):
            phase = data.get("phase", "")
            delta = data.get("delta_content", "")
            if phase == "thinking":
                if not thinking_started:
                    print("[thinking] ", end="", flush=True)
                    thinking_started = True
                print(delta, end="", flush=True)
            elif phase == "answer":
                if not answer_started:
                    if thinking_started:
                        print("\n")
                    print("[answer] ", end="", flush=True)
                    answer_started = True
                print(delta, end="", flush=True)
            elif phase == "done":
                break
        print("\n\n[done]")

    finally:
        # Always release the underlying HTTP connection pool.
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
openai.py ADDED
@@ -0,0 +1,1761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI-compatible proxy server for chat.z.ai + Toolify-style function calling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import math
9
+ import os
10
+ import re
11
+ import secrets
12
+ import string
13
+ import time
14
+ import uuid
15
+ from contextlib import asynccontextmanager
16
+ from typing import Any
17
+
18
+ import httpcore
19
+ import httpx
20
+ import uvicorn
21
+ from fastapi import FastAPI, Request
22
+ from fastapi.responses import JSONResponse, StreamingResponse
23
+
24
+ from main import ZaiClient
25
+ from claude_compat import (
26
+ claude_messages_to_openai,
27
+ claude_tools_to_openai,
28
+ claude_tool_choice_prompt,
29
+ make_claude_id,
30
+ build_tool_call_blocks,
31
+ build_non_stream_response,
32
+ sse_message_start,
33
+ sse_ping,
34
+ sse_content_block_start,
35
+ sse_content_block_delta,
36
+ sse_content_block_stop,
37
+ sse_message_delta,
38
+ sse_message_stop,
39
+ sse_error,
40
+ )
41
+
42
+ # ── Logging ──────────────────────────────────────────────────────────
43
+
44
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
45
+ HTTP_DEBUG = os.getenv("HTTP_DEBUG", "0") == "1"
46
+ logging.basicConfig(
47
+ level=getattr(logging, LOG_LEVEL, logging.INFO),
48
+ format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
49
+ )
50
+ logger = logging.getLogger("zai.openai")
51
+ if not HTTP_DEBUG:
52
+ logging.getLogger("httpx").setLevel(logging.WARNING)
53
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
54
+
55
+
56
# ── Multi-Account Pool ───────────────────────────────────────────────

# Baseline pool size; also the default for POOL_MIN_SIZE.
POOL_SIZE = int(os.getenv("POOL_SIZE", "3"))
# Hard floor / ceiling on the number of pooled guest accounts.
POOL_MIN_SIZE = max(1, int(os.getenv("POOL_MIN_SIZE", str(POOL_SIZE))))
POOL_MAX_SIZE = max(POOL_MIN_SIZE, int(os.getenv("POOL_MAX_SIZE", str(max(POOL_MIN_SIZE, POOL_MIN_SIZE * 3)))))
# Desired concurrent requests per account; drives load-based scaling.
POOL_TARGET_INFLIGHT_PER_ACCOUNT = max(1, int(os.getenv("POOL_TARGET_INFLIGHT_PER_ACCOUNT", "2")))
# Seconds between background maintenance rounds (lower bound 5s).
POOL_MAINTAIN_INTERVAL = max(5, int(os.getenv("POOL_MAINTAIN_INTERVAL", "10")))
# Consecutive idle maintenance rounds required before scaling down.
POOL_SCALE_DOWN_IDLE_ROUNDS = max(1, int(os.getenv("POOL_SCALE_DOWN_IDLE_ROUNDS", "3")))
TOKEN_MAX_AGE = int(os.getenv("TOKEN_MAX_AGE", "480"))  # seconds before a guest token is considered stale
# Max seconds to wait for the FIRST upstream event, and how many times to
# retry on an empty/slow first event.
UPSTREAM_FIRST_EVENT_TIMEOUT = max(1.0, float(os.getenv("UPSTREAM_FIRST_EVENT_TIMEOUT", "60")))
UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX = max(0, int(os.getenv("UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX", "2")))
67
+
68
+
69
def _compute_pool_target_by_load(in_flight: int) -> int:
    """Estimate the desired pool size for the current concurrency load."""
    if in_flight <= 0:
        return POOL_MIN_SIZE
    # One extra account of headroom so requests don't queue when every
    # account is already saturated.
    wanted = 1 + math.ceil(in_flight / POOL_TARGET_INFLIGHT_PER_ACCOUNT)
    floored = max(POOL_MIN_SIZE, wanted)
    return min(POOL_MAX_SIZE, floored)
76
+
77
+
78
class AccountInfo:
    """A single guest auth session."""

    __slots__ = ("token", "user_id", "username", "created_at", "active", "valid")

    def __init__(self, token: str, user_id: str, username: str) -> None:
        # Credentials for one upstream guest account.
        self.token = token
        self.user_id = user_id
        self.username = username
        # Bookkeeping used by the pool for rotation and load balancing.
        self.created_at = time.time()
        self.active = 0  # number of in-flight requests on this account
        self.valid = True  # cleared once the account starts failing

    def snapshot(self) -> dict[str, str]:
        """Return the auth fields as a plain dict (safe to hand out)."""
        return {
            "token": self.token,
            "user_id": self.user_id,
            "username": self.username,
        }

    @property
    def age(self) -> float:
        """Seconds elapsed since this account was created."""
        return time.time() - self.created_at
96
+
97
+
98
class SessionPool:
    """Pool of guest accounts for concurrent, seamless use.

    Accounts are created on demand, load-balanced by in-flight request
    count, aged out after TOKEN_MAX_AGE seconds, and replaced when they
    fail.  A background task (``_maintain``) grows and shrinks the pool
    between POOL_MIN_SIZE and POOL_MAX_SIZE based on observed load.
    """

    def __init__(self) -> None:
        # Guards structural mutation of self._accounts.
        self._lock = asyncio.Lock()
        self._accounts: list[AccountInfo] = []
        # Handle for the background _maintain() loop (set in initialize()).
        self._bg_task: asyncio.Task | None = None
        # Set by other methods to wake _maintain() before its timer fires.
        self._maintain_event = asyncio.Event()
        # Requested pool size; acts only as temporary upward pressure.
        self._target_size = POOL_MIN_SIZE
        # Consecutive idle maintenance rounds (gates scale-down).
        self._idle_rounds = 0

    # ── internal ─────────────────────────────────────────────────────

    def _valid_accounts(self, *, include_expired: bool = False) -> list[AccountInfo]:
        """Return usable accounts; optionally include ones past TOKEN_MAX_AGE."""
        if include_expired:
            return [a for a in self._accounts if a.valid]
        return [a for a in self._accounts if a.valid and a.age < TOKEN_MAX_AGE]

    def _raise_target_size(self, target_size: int) -> None:
        """Request a (clamped) larger pool and wake the maintenance loop."""
        clamped = min(POOL_MAX_SIZE, max(POOL_MIN_SIZE, target_size))
        if clamped > self._target_size:
            self._target_size = clamped
            self._maintain_event.set()

    async def _new_account(self) -> AccountInfo:
        """Create a fresh upstream guest account (client is always closed)."""
        c = ZaiClient()
        try:
            d = await c.auth_as_guest()
            acc = AccountInfo(d["token"], d["id"], d.get("name") or d.get("email", "").split("@")[0])
            logger.info("Pool: +account uid=%s (total=%d)", acc.user_id, len(self._accounts) + 1)
            return acc
        finally:
            await c.close()

    async def _del_account(self, acc: AccountInfo) -> None:
        """Best-effort upstream cleanup of an account's chats; never raises."""
        try:
            c = ZaiClient()
            c.token, c.user_id, c.username = acc.token, acc.user_id, acc.username
            await c.delete_all_chats()
            await c.close()
        except Exception:
            pass

    async def _maintain(self) -> None:
        """Background maintenance: scale with load and purge dead accounts."""
        while True:
            try:
                # Wait for either the interval or an explicit wake-up.
                try:
                    await asyncio.wait_for(self._maintain_event.wait(), timeout=POOL_MAINTAIN_INTERVAL)
                except asyncio.TimeoutError:
                    pass
                self._maintain_event.clear()

                to_delete: list[AccountInfo] = []
                to_add = 0
                cycle_target = POOL_MIN_SIZE

                async with self._lock:
                    # Remove invalid/expired accounts that are no longer busy.
                    dead = [a for a in self._accounts if (not a.valid or a.age > TOKEN_MAX_AGE) and a.active == 0]
                    for a in dead:
                        self._accounts.remove(a)
                        to_delete.append(a)

                    valid = self._valid_accounts()
                    valid_count = len(valid)
                    in_flight = sum(a.active for a in valid)

                    load_target = _compute_pool_target_by_load(in_flight)
                    desired = min(POOL_MAX_SIZE, max(POOL_MIN_SIZE, max(load_target, self._target_size)))

                    # Shrink only after consecutive idle rounds to avoid
                    # flapping on short load dips.
                    if in_flight == 0 and valid_count > desired:
                        self._idle_rounds += 1
                    else:
                        self._idle_rounds = 0

                    if self._idle_rounds >= POOL_SCALE_DOWN_IDLE_ROUNDS and valid_count > desired:
                        # Retire the oldest idle accounts first.
                        removable = [a for a in valid if a.active == 0]
                        removable.sort(key=lambda a: a.created_at)
                        shrink_by = min(valid_count - desired, len(removable))
                        for a in removable[:shrink_by]:
                            self._accounts.remove(a)
                            to_delete.append(a)
                        valid_count -= shrink_by
                        if valid_count <= desired:
                            self._idle_rounds = 0
                    else:
                        # Not shrinking: keep at least the current valid count.
                        desired = max(desired, valid_count)

                    cycle_target = desired
                    # _target_size is only a transient "raise" request; the
                    # next round recomputes purely from observed load.
                    self._target_size = load_target
                    to_add = max(0, desired - valid_count)

                # Upstream cleanup runs outside the lock, fire-and-forget.
                for a in to_delete:
                    asyncio.create_task(self._del_account(a))

                for _ in range(to_add):
                    try:
                        new_acc = await self._new_account()
                    except Exception as e:
                        logger.warning("Pool maintain add failed: %s", e)
                        break

                    async with self._lock:
                        # Re-check under the lock: another path may have
                        # filled the pool while we were creating.
                        valid_now = len(self._valid_accounts())
                        if valid_now >= cycle_target:
                            asyncio.create_task(self._del_account(new_acc))
                            continue
                        self._accounts.append(new_acc)
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.warning("Pool maintain loop error: %s", e)

    # ── public API ───────────────────────────────────────────────────

    async def initialize(self) -> None:
        """Create the initial accounts and start the maintenance task."""
        self._target_size = POOL_MIN_SIZE
        async with self._lock:
            results = await asyncio.gather(
                *[self._new_account() for _ in range(POOL_MIN_SIZE)],
                return_exceptions=True,
            )
            for r in results:
                if isinstance(r, AccountInfo):
                    self._accounts.append(r)
                else:
                    logger.warning("Pool init failed: %s", r)
            if not self._accounts:
                # All parallel creations failed; one last (raising) attempt.
                self._accounts.append(await self._new_account())
        logger.info("Pool: ready with %d accounts", len(self._accounts))
        self._bg_task = asyncio.create_task(self._maintain())
        self._maintain_event.set()

    async def close(self) -> None:
        """Stop maintenance and delete every pooled account."""
        if self._bg_task:
            self._bg_task.cancel()
            try:
                await self._bg_task
            except asyncio.CancelledError:
                pass
        for a in list(self._accounts):
            await self._del_account(a)
        self._accounts.clear()

    async def acquire(self) -> AccountInfo:
        """Get the least-busy valid account (creates one if needed)."""
        good = self._valid_accounts()
        if not good:
            async with self._lock:
                good = self._valid_accounts()
                if not good:
                    acc = await self._new_account()
                    self._accounts.append(acc)
                    good = [acc]
        acc = min(good, key=lambda a: a.active)
        acc.active += 1
        if acc.active >= POOL_TARGET_INFLIGHT_PER_ACCOUNT:
            # Saturated: ask maintenance for one more account.
            self._raise_target_size(len(good) + 1)
        return acc

    def release(self, acc: AccountInfo) -> None:
        """Return an account acquired via acquire()/get_auth_snapshot()."""
        acc.active = max(0, acc.active - 1)
        if acc.active == 0:
            self._maintain_event.set()

    async def report_failure(self, acc: AccountInfo) -> None:
        """Mark account invalid, schedule cleanup, add replacement."""
        acc.valid = False
        acc.active = max(0, acc.active - 1)
        self._raise_target_size(len(self._valid_accounts()) + 1)
        asyncio.create_task(self._del_account(acc))
        try:
            new = await self._new_account()
            async with self._lock:
                if len(self._valid_accounts(include_expired=True)) < POOL_MAX_SIZE:
                    self._accounts.append(new)
                else:
                    asyncio.create_task(self._del_account(new))
        except Exception as e:
            logger.warning("Pool replace failed: %s", e)
        self._maintain_event.set()

    async def get_models(self) -> list | dict:
        """Fetch the upstream model list using a pooled account."""
        acc = await self.acquire()
        c = ZaiClient()
        try:
            c.token, c.user_id, c.username = acc.token, acc.user_id, acc.username
            return await c.get_models()
        finally:
            self.release(acc)
            await c.close()

    # ── compat methods (called by request handlers) ──────────────────

    async def ensure_auth(self) -> None:
        """Ensure at least one valid account exists in the pool."""
        good = self._valid_accounts(include_expired=True)
        if not good:
            async with self._lock:
                good = self._valid_accounts(include_expired=True)
                if not good:
                    self._accounts.append(await self._new_account())
        if len(good) < POOL_MIN_SIZE:
            self._raise_target_size(POOL_MIN_SIZE)

    def get_auth_snapshot(self) -> dict[str, str]:
        """Get auth snapshot from the least-busy valid account.

        Increments the account's active count; the caller must balance it
        via _release_by_user_id()/refresh_auth().
        """
        good = self._valid_accounts()
        if not good:
            # Fall back to stale-but-valid accounts rather than failing.
            good = self._valid_accounts(include_expired=True)
        if not good:
            raise RuntimeError("No valid accounts in pool")
        acc = min(good, key=lambda a: a.active)
        acc.active += 1
        if acc.active >= POOL_TARGET_INFLIGHT_PER_ACCOUNT:
            self._raise_target_size(len(good) + 1)
        return acc.snapshot()

    def _release_by_user_id(self, user_id: str) -> None:
        """Release (decrement active) for the account matching user_id."""
        for a in self._accounts:
            if a.user_id == user_id:
                a.active = max(0, a.active - 1)
                if a.active == 0:
                    self._maintain_event.set()
                return

    async def refresh_auth(self, failed_user_id: str | None = None) -> None:
        """Invalidate the failed account (if given) and create a fresh one."""
        if failed_user_id:
            for a in self._accounts:
                if a.user_id == failed_user_id:
                    a.valid = False
                    a.active = max(0, a.active - 1)
                    asyncio.create_task(self._del_account(a))
                    logger.info("SessionPool: invalidated failed account uid=%s", failed_user_id)
                    break
        self._raise_target_size(len(self._valid_accounts()) + 1)
        try:
            acc = await self._new_account()
            async with self._lock:
                if len(self._valid_accounts(include_expired=True)) < POOL_MAX_SIZE:
                    self._accounts.append(acc)
                else:
                    asyncio.create_task(self._del_account(acc))
            logger.info("SessionPool: auth refreshed, new user_id=%s", acc.user_id)
        except Exception as e:
            logger.warning("SessionPool: refresh_auth failed: %s", e)
        self._maintain_event.set()

    async def cleanup_chats(self) -> None:
        """Clean up chats for idle accounts to free concurrency slots."""
        for a in list(self._accounts):
            if a.valid and a.active == 0:
                try:
                    c = ZaiClient()
                    c.token, c.user_id, c.username = a.token, a.user_id, a.username
                    await c.delete_all_chats()
                    await c.close()
                except Exception:
                    pass
362
+
363
+
364
# Process-wide account pool shared by all request handlers.
pool = SessionPool()
365
+
366
+
367
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """App lifespan: warm the account pool on startup, drain it on shutdown."""
    await pool.initialize()
    yield
    await pool.close()
372
+
373
+
374
# FastAPI app wired to the pool's startup/shutdown lifecycle.
app = FastAPI(lifespan=lifespan)
375
+
376
+
377
+ # ── Toolify-style helpers ─────────���──────────────────────────────────
378
+
379
+
380
+ def _generate_trigger_signal() -> str:
381
+ chars = string.ascii_letters + string.digits
382
+ rand = "".join(secrets.choice(chars) for _ in range(4))
383
+ return f"<Function_{rand}_Start/>"
384
+
385
+
386
# One random trigger signal per process; it is embedded into prompts and
# parsed back out of model output to detect tool calls.
GLOBAL_TRIGGER_SIGNAL = _generate_trigger_signal()
387
+
388
+
389
+ def _extract_text_from_content(content: object) -> str:
390
+ if isinstance(content, str):
391
+ return content
392
+ if isinstance(content, list):
393
+ parts: list[str] = []
394
+ for p in content:
395
+ if isinstance(p, dict) and p.get("type") == "text":
396
+ parts.append(str(p.get("text", "")))
397
+ return " ".join(parts).strip()
398
+ if content is None:
399
+ return ""
400
+ try:
401
+ return json.dumps(content, ensure_ascii=False)
402
+ except Exception:
403
+ return str(content)
404
+
405
+
406
+ def _build_tool_call_index_from_messages(messages: list[dict]) -> dict[str, dict[str, str]]:
407
+ idx: dict[str, dict[str, str]] = {}
408
+ for msg in messages:
409
+ if msg.get("role") != "assistant":
410
+ continue
411
+ tcs = msg.get("tool_calls")
412
+ if not isinstance(tcs, list):
413
+ continue
414
+ for tc in tcs:
415
+ if not isinstance(tc, dict):
416
+ continue
417
+ tc_id = tc.get("id")
418
+ fn = tc.get("function", {}) if isinstance(tc.get("function"), dict) else {}
419
+ name = str(fn.get("name", ""))
420
+ args = fn.get("arguments", "{}")
421
+ if not isinstance(args, str):
422
+ try:
423
+ args = json.dumps(args, ensure_ascii=False)
424
+ except Exception:
425
+ args = "{}"
426
+ if isinstance(tc_id, str) and name:
427
+ idx[tc_id] = {"name": name, "arguments": args}
428
+ return idx
429
+
430
+
431
+ def _format_tool_result_for_ai(tool_name: str, tool_arguments: str, result_content: str) -> str:
432
+ return (
433
+ "<tool_execution_result>\n"
434
+ f"<tool_name>{tool_name}</tool_name>\n"
435
+ f"<tool_arguments>{tool_arguments}</tool_arguments>\n"
436
+ f"<tool_output>{result_content}</tool_output>\n"
437
+ "</tool_execution_result>"
438
+ )
439
+
440
+
441
+ def _format_assistant_tool_calls_for_ai(tool_calls: list[dict], trigger_signal: str) -> str:
442
+ blocks: list[str] = []
443
+ for tc in tool_calls:
444
+ if not isinstance(tc, dict):
445
+ continue
446
+ fn = tc.get("function", {}) if isinstance(tc.get("function"), dict) else {}
447
+ name = str(fn.get("name", "")).strip()
448
+ if not name:
449
+ continue
450
+ args = fn.get("arguments", "{}")
451
+ if isinstance(args, str):
452
+ args_text = args
453
+ else:
454
+ try:
455
+ args_text = json.dumps(args, ensure_ascii=False)
456
+ except Exception:
457
+ args_text = "{}"
458
+ blocks.append(
459
+ "<function_call>\n"
460
+ f"<name>{name}</name>\n"
461
+ f"<args_json>{args_text}</args_json>\n"
462
+ "</function_call>"
463
+ )
464
+ if not blocks:
465
+ return ""
466
+ return f"{trigger_signal}\n<function_calls>\n" + "\n".join(blocks) + "\n</function_calls>"
467
+
468
+
469
def _preprocess_messages(messages: list[dict]) -> list[dict]:
    """Rewrite tool/developer messages into plain user/system/assistant turns."""
    call_index = _build_tool_call_index_from_messages(messages)
    converted: list[dict] = []

    for message in messages:
        if not isinstance(message, dict):
            continue
        role = message.get("role")

        if role == "tool":
            # Tool results become user turns wrapped in the result envelope.
            call_id = message.get("tool_call_id")
            body = _extract_text_from_content(message.get("content", ""))
            fallback = {"name": message.get("name", "unknown_tool"), "arguments": "{}"}
            info = call_index.get(call_id, fallback)
            converted.append(
                {
                    "role": "user",
                    "content": _format_tool_result_for_ai(info["name"], info["arguments"], body),
                }
            )
        elif role == "assistant" and isinstance(message.get("tool_calls"), list):
            # Re-emit assistant tool calls as trigger-signal XML after any text.
            xml = _format_assistant_tool_calls_for_ai(message["tool_calls"], GLOBAL_TRIGGER_SIGNAL)
            body = (_extract_text_from_content(message.get("content", "")) + "\n" + xml).strip()
            converted.append({"role": "assistant", "content": body})
        elif role == "developer":
            # OpenAI "developer" role maps onto a plain system message.
            as_system = dict(message)
            as_system["role"] = "system"
            converted.append(as_system)
        else:
            converted.append(message)

    return converted
505
+
506
+
507
+ def _generate_function_prompt(tools: list[dict], trigger_signal: str) -> str:
508
+ tool_lines: list[str] = []
509
+ for i, t in enumerate(tools):
510
+ if not isinstance(t, dict) or t.get("type") != "function":
511
+ continue
512
+ fn = t.get("function", {}) if isinstance(t.get("function"), dict) else {}
513
+ name = str(fn.get("name", "")).strip()
514
+ if not name:
515
+ continue
516
+ desc = str(fn.get("description", "")).strip() or "None"
517
+ params = fn.get("parameters", {})
518
+ required = params.get("required", []) if isinstance(params, dict) else []
519
+ try:
520
+ params_json = json.dumps(params, ensure_ascii=False)
521
+ except Exception:
522
+ params_json = "{}"
523
+
524
+ tool_lines.append(
525
+ f"{i+1}. <tool name=\"{name}\">\n"
526
+ f" Description: {desc}\n"
527
+ f" Required: {', '.join(required) if isinstance(required, list) and required else 'None'}\n"
528
+ f" Parameters JSON Schema: {params_json}"
529
+ )
530
+
531
+ tools_block = "\n\n".join(tool_lines) if tool_lines else "(no tools)"
532
+
533
+ return (
534
+ "You have access to tools.\n\n"
535
+ "When you need to call tools, you MUST output exactly:\n"
536
+ f"{trigger_signal}\n"
537
+ "<function_calls>\n"
538
+ " <function_call>\n"
539
+ " <name>tool_name</name>\n"
540
+ " <args_json>{\"arg\":\"value\"}</args_json>\n"
541
+ " </function_call>\n"
542
+ "</function_calls>\n\n"
543
+ "Rules:\n"
544
+ "1) args_json MUST be valid JSON object\n"
545
+ "2) For multiple calls, output one <function_calls> with multiple <function_call> children\n"
546
+ "3) If no tool is needed, answer normally\n\n"
547
+ f"Available tools:\n{tools_block}"
548
+ )
549
+
550
+
551
+ def _safe_process_tool_choice(tool_choice: Any, tools: list[dict]) -> str:
552
+ if tool_choice is None:
553
+ return ""
554
+
555
+ if isinstance(tool_choice, str):
556
+ if tool_choice == "required":
557
+ return "\nIMPORTANT: You MUST call at least one tool in your next response."
558
+ if tool_choice == "none":
559
+ return "\nIMPORTANT: Do not call tools. Answer directly."
560
+ return ""
561
+
562
+ if isinstance(tool_choice, dict):
563
+ fn = tool_choice.get("function", {}) if isinstance(tool_choice.get("function"), dict) else {}
564
+ name = fn.get("name")
565
+ if isinstance(name, str) and name:
566
+ return f"\nIMPORTANT: You MUST call this tool: {name}"
567
+
568
+ return ""
569
+
570
+
571
def _flatten_messages_for_zai(messages: list[dict]) -> list[dict]:
    """Fold the whole conversation into one pseudo-XML user message for Z.ai."""
    rendered = []
    for message in messages:
        tag = str(message.get("role", "user")).upper()
        text = _extract_text_from_content(message.get("content", ""))
        rendered.append(f"<{tag}>{text}</{tag}>")
    return [{"role": "user", "content": "\n".join(rendered)}]
578
+
579
+
580
+ def _remove_think_blocks(text: str) -> str:
581
+ while "<think>" in text and "</think>" in text:
582
+ start = text.find("<think>")
583
+ if start == -1:
584
+ break
585
+ pos = start + 7
586
+ depth = 1
587
+ while pos < len(text) and depth > 0:
588
+ if text[pos : pos + 7] == "<think>":
589
+ depth += 1
590
+ pos += 7
591
+ elif text[pos : pos + 8] == "</think>":
592
+ depth -= 1
593
+ pos += 8
594
+ else:
595
+ pos += 1
596
+ if depth == 0:
597
+ text = text[:start] + text[pos:]
598
+ else:
599
+ break
600
+ return text
601
+
602
+
603
+ def _find_last_trigger_signal_outside_think(text: str, trigger_signal: str) -> int:
604
+ if not text or not trigger_signal:
605
+ return -1
606
+ i = 0
607
+ depth = 0
608
+ last = -1
609
+ while i < len(text):
610
+ if text.startswith("<think>", i):
611
+ depth += 1
612
+ i += 7
613
+ continue
614
+ if text.startswith("</think>", i):
615
+ depth = max(0, depth - 1)
616
+ i += 8
617
+ continue
618
+ if depth == 0 and text.startswith(trigger_signal, i):
619
+ last = i
620
+ i += 1
621
+ continue
622
+ i += 1
623
+ return last
624
+
625
+
626
def _drain_safe_answer_delta(
    answer_text: str,
    emitted_chars: int,
    *,
    has_fc: bool,
    trigger_signal: str,
) -> tuple[str, int, bool]:
    """Return ``(delta, new_emitted_count, trigger_seen)`` for streamed text.

    Without function calling every new character may be emitted.  With
    function calling the tail is held back by ``len(trigger_signal) - 1``
    characters so a trigger split across chunks never leaks; once the
    trigger is seen, only text before it may be emitted.
    """
    total = len(answer_text)
    if emitted_chars >= total:
        return "", emitted_chars, False

    if not has_fc:
        return answer_text[emitted_chars:], total, False

    marker = _find_last_trigger_signal_outside_think(answer_text, trigger_signal)
    seen = marker >= 0
    if seen:
        boundary = marker
    else:
        holdback = max(0, len(trigger_signal) - 1)
        boundary = max(0, total - holdback)

    if boundary <= emitted_chars:
        return "", emitted_chars, seen

    return answer_text[emitted_chars:boundary], boundary, seen
658
+
659
+
660
def _parse_function_calls_xml(xml_string: str, trigger_signal: str) -> list[dict]:
    """Parse trigger-signal XML into OpenAI-style tool_call dicts."""
    if not xml_string or trigger_signal not in xml_string:
        return []

    # Only text outside <think> blocks counts; take the LAST trigger.
    visible = _remove_think_blocks(xml_string)
    trigger_at = visible.rfind(trigger_signal)
    if trigger_at == -1:
        return []

    block_match = re.search(r"<function_calls>([\s\S]*?)</function_calls>", visible[trigger_at:])
    if block_match is None:
        return []

    calls: list[dict] = []
    for chunk in re.findall(r"<function_call>([\s\S]*?)</function_call>", block_match.group(1)):
        name_match = re.search(r"<name>([\s\S]*?)</name>", chunk)
        if name_match is None:
            continue
        args_match = re.search(r"<args_json>([\s\S]*?)</args_json>", chunk)
        raw = args_match.group(1).strip() if args_match else "{}"
        # Arguments must end up as a JSON *object* string; wrap scalars and
        # keep unparseable input under a "raw" key.
        try:
            args_obj = json.loads(raw) if raw else {}
            if not isinstance(args_obj, dict):
                args_obj = {"value": args_obj}
        except Exception:
            args_obj = {"raw": raw}

        calls.append(
            {
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "type": "function",
                "function": {
                    "name": name_match.group(1).strip(),
                    "arguments": json.dumps(args_obj, ensure_ascii=False),
                },
            }
        )

    return calls
701
+
702
+
703
+ # ── OpenAI response helpers ──────────────────────────────────────────
704
+
705
+
706
+ def _make_id() -> str:
707
+ return f"chatcmpl-{uuid.uuid4().hex[:29]}"
708
+
709
+
710
+ def _estimate_tokens(text: str) -> int:
711
+ if not text:
712
+ return 0
713
+ return max(1, math.ceil(len(text) / 2))
714
+
715
+
716
+ def _to_optional_bool(value: Any) -> bool | None:
717
+ if value is None:
718
+ return None
719
+ if isinstance(value, bool):
720
+ return value
721
+ if isinstance(value, (int, float)):
722
+ return bool(value)
723
+ if isinstance(value, str):
724
+ v = value.strip().lower()
725
+ if v in {"1", "true", "yes", "on"}:
726
+ return True
727
+ if v in {"0", "false", "no", "off"}:
728
+ return False
729
+ return None
730
+
731
+
732
def _build_usage(prompt_text: str, completion_text: str) -> dict:
    """Build an OpenAI-style ``usage`` dict from estimated token counts."""
    prompt_tokens = _estimate_tokens(prompt_text)
    completion_tokens = _estimate_tokens(completion_text)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
736
+
737
+
738
+ def _openai_chunk(
739
+ completion_id: str,
740
+ model: str,
741
+ *,
742
+ content: str | None = None,
743
+ reasoning_content: str | None = None,
744
+ finish_reason: str | None = None,
745
+ ) -> dict:
746
+ delta: dict = {}
747
+ if content is not None:
748
+ delta["content"] = content
749
+ if reasoning_content is not None:
750
+ delta["reasoning_content"] = reasoning_content
751
+ return {
752
+ "id": completion_id,
753
+ "object": "chat.completion.chunk",
754
+ "created": int(time.time()),
755
+ "model": model,
756
+ "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
757
+ }
758
+
759
+
760
+ def _extract_upstream_tool_calls(data: dict) -> list[dict]:
761
+ # Native Toolify/Z.ai style
762
+ tcs = data.get("tool_calls")
763
+ if isinstance(tcs, list):
764
+ return tcs
765
+
766
+ # OpenAI-like style: choices[0].delta.tool_calls or choices[0].message.tool_calls
767
+ choices = data.get("choices")
768
+ if isinstance(choices, list) and choices:
769
+ c0 = choices[0] if isinstance(choices[0], dict) else {}
770
+ delta = c0.get("delta") if isinstance(c0.get("delta"), dict) else {}
771
+ message = c0.get("message") if isinstance(c0.get("message"), dict) else {}
772
+ for candidate in (delta.get("tool_calls"), message.get("tool_calls")):
773
+ if isinstance(candidate, list):
774
+ return candidate
775
+
776
+ return []
777
+
778
+
779
+ def _extract_upstream_delta(data: dict) -> tuple[str, str]:
780
+ """Best-effort extract (phase, delta_text) from upstream event payload."""
781
+ phase = str(data.get("phase", "") or "")
782
+
783
+ # OpenAI-like envelope
784
+ choices = data.get("choices")
785
+ if isinstance(choices, list) and choices:
786
+ c0 = choices[0] if isinstance(choices[0], dict) else {}
787
+ delta_obj = c0.get("delta") if isinstance(c0.get("delta"), dict) else {}
788
+ msg_obj = c0.get("message") if isinstance(c0.get("message"), dict) else {}
789
+ if not phase:
790
+ phase = str(c0.get("phase", "") or "")
791
+ for v in (
792
+ delta_obj.get("reasoning_content"),
793
+ delta_obj.get("content"),
794
+ msg_obj.get("reasoning_content"),
795
+ msg_obj.get("content"),
796
+ ):
797
+ if isinstance(v, str) and v:
798
+ return phase, v
799
+
800
+ candidates = [
801
+ data.get("delta_content"),
802
+ data.get("content"),
803
+ data.get("delta"),
804
+ (data.get("message") or {}).get("content") if isinstance(data.get("message"), dict) else None,
805
+ ]
806
+
807
+ for v in candidates:
808
+ if isinstance(v, str) and v:
809
+ return phase, v
810
+
811
+ return phase, ""
812
+
813
+
814
+ async def _iter_upstream_with_first_event_timeout(upstream: Any, timeout_s: float):
815
+ """Wrap upstream iterator and enforce a timeout for the first event only."""
816
+ iterator = upstream.__aiter__()
817
+ try:
818
+ first = await asyncio.wait_for(iterator.__anext__(), timeout=timeout_s)
819
+ except StopAsyncIteration:
820
+ return
821
+ yield first
822
+ async for data in iterator:
823
+ yield data
824
+
825
+
826
+ # ── Endpoints ──────────��─────────────────────────────────────────────
827
+
828
+
829
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing backed by the upstream pool."""
    upstream = await pool.get_models()
    # Upstream may answer with {"data": [...]} or a bare list.
    if isinstance(upstream, dict) and "data" in upstream:
        raw_models = upstream["data"]
    elif isinstance(upstream, list):
        raw_models = upstream
    else:
        raw_models = []

    data = []
    for entry in raw_models:
        data.append(
            {
                "id": entry.get("id") or entry.get("name", "unknown"),
                "object": "model",
                "created": 0,
                "owned_by": "z.ai",
            }
        )
    return {"object": "list", "data": data}
851
+
852
+
853
+ @app.post("/v1/chat/completions")
854
+ async def chat_completions(request: Request):
855
+ body = await request.json()
856
+
857
+ model: str = body.get("model", "glm-5")
858
+ messages: list[dict] = body.get("messages", [])
859
+ stream: bool = body.get("stream", False)
860
+ tools: list[dict] | None = body.get("tools")
861
+ tool_choice = body.get("tool_choice")
862
+ enable_thinking = _to_optional_bool(body.get("enable_thinking"))
863
+
864
+ # signature prompt: last user message in original request
865
+ prompt = ""
866
+ for msg in reversed(messages):
867
+ if msg.get("role") == "user":
868
+ prompt = _extract_text_from_content(msg.get("content", ""))
869
+ break
870
+ if not prompt:
871
+ return JSONResponse(
872
+ status_code=400,
873
+ content={"error": {"message": "No user message found in messages", "type": "invalid_request_error"}},
874
+ )
875
+
876
+ processed_messages = _preprocess_messages(messages)
877
+
878
+ has_fc = bool(tools)
879
+ if has_fc:
880
+ fc_prompt = _generate_function_prompt(tools or [], GLOBAL_TRIGGER_SIGNAL)
881
+ fc_prompt += _safe_process_tool_choice(tool_choice, tools or [])
882
+ processed_messages.insert(0, {"role": "system", "content": fc_prompt})
883
+
884
+ flat_messages = _flatten_messages_for_zai(processed_messages)
885
+ usage_prompt_text = "\n".join(_extract_text_from_content(m.get("content", "")) for m in processed_messages)
886
+
887
+ req_id = f"req_{uuid.uuid4().hex[:10]}"
888
+ logger.info(
889
+ "[entry][%s] model=%s stream=%s tools=%d input_messages=%d flat_chars=%d est_prompt_tokens=%d first_event_timeout=%.1fs timeout_retry_max=%d",
890
+ req_id,
891
+ model,
892
+ stream,
893
+ len(tools or []),
894
+ len(messages),
895
+ len(flat_messages[0].get("content", "")),
896
+ _estimate_tokens(usage_prompt_text),
897
+ UPSTREAM_FIRST_EVENT_TIMEOUT,
898
+ UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX,
899
+ )
900
+
901
+ async def run_once(auth: dict[str, str], enable_thinking_override: bool | None):
902
+ client = ZaiClient()
903
+ try:
904
+ client.token = auth["token"]
905
+ client.user_id = auth["user_id"]
906
+ client.username = auth["username"]
907
+ create_chat_started = time.perf_counter()
908
+ chat = await client.create_chat(prompt, model, enable_thinking=enable_thinking_override)
909
+ create_chat_elapsed = time.perf_counter() - create_chat_started
910
+ chat_id = chat["id"]
911
+ upstream = client.chat_completions(
912
+ chat_id=chat_id,
913
+ messages=flat_messages,
914
+ prompt=prompt,
915
+ model=model,
916
+ tools=None,
917
+ enable_thinking=enable_thinking_override,
918
+ )
919
+ return upstream, client, chat_id, create_chat_elapsed
920
+ except Exception:
921
+ await client.close()
922
+ raise
923
+
924
+ if stream:
925
+
926
+ async def gen_sse():
927
+ completion_id = _make_id()
928
+ retried = False
929
+ first_event_timeout_retries = 0
930
+ empty_reply_retries = 0
931
+ current_uid: str | None = None
932
+ role_emitted = False
933
+
934
+ while True:
935
+ client: ZaiClient | None = None
936
+ chat_id: str | None = None
937
+ try:
938
+ phase_started = time.perf_counter()
939
+ await pool.ensure_auth()
940
+ ensure_auth_elapsed = time.perf_counter() - phase_started
941
+ auth = pool.get_auth_snapshot()
942
+ current_uid = auth["user_id"]
943
+ if not role_emitted:
944
+ yield f"data: {json.dumps({'id': completion_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]}, ensure_ascii=False)}\n\n"
945
+ role_emitted = True
946
+ upstream, client, chat_id, create_chat_elapsed = await run_once(auth, enable_thinking)
947
+ first_upstream_started = time.perf_counter()
948
+ first_event_logged = False
949
+
950
+ reasoning_parts: list[str] = []
951
+ answer_text = ""
952
+ emitted_answer_chars = 0
953
+ native_tool_calls: list[dict] = []
954
+
955
+ async for data in _iter_upstream_with_first_event_timeout(upstream, UPSTREAM_FIRST_EVENT_TIMEOUT):
956
+ if not first_event_logged:
957
+ first_upstream_elapsed = time.perf_counter() - first_upstream_started
958
+ logger.info(
959
+ "[stream][%s] phase ensure_auth=%.3fs create_chat=%.3fs first_upstream_event=%.3fs",
960
+ completion_id,
961
+ ensure_auth_elapsed,
962
+ create_chat_elapsed,
963
+ first_upstream_elapsed,
964
+ )
965
+ first_event_logged = True
966
+ phase, delta = _extract_upstream_delta(data)
967
+
968
+ upstream_tcs = _extract_upstream_tool_calls(data)
969
+ if upstream_tcs:
970
+ for tc in upstream_tcs:
971
+ native_tool_calls.append(
972
+ {
973
+ "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
974
+ "type": "function",
975
+ "function": {
976
+ "name": tc.get("function", {}).get("name", ""),
977
+ "arguments": tc.get("function", {}).get("arguments", ""),
978
+ },
979
+ }
980
+ )
981
+ continue
982
+
983
+ if phase == "thinking" and delta:
984
+ reasoning_parts.append(delta)
985
+ chunk = _openai_chunk(completion_id, model, reasoning_content=delta)
986
+ yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
987
+ elif delta:
988
+ answer_text += delta
989
+ safe_delta, emitted_answer_chars, _ = _drain_safe_answer_delta(
990
+ answer_text,
991
+ emitted_answer_chars,
992
+ has_fc=has_fc,
993
+ trigger_signal=GLOBAL_TRIGGER_SIGNAL,
994
+ )
995
+ if safe_delta:
996
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=safe_delta), ensure_ascii=False)}\n\n"
997
+
998
+ if not first_event_logged:
999
+ logger.info(
1000
+ "[stream][%s] phase ensure_auth=%.3fs create_chat=%.3fs first_upstream_event=EOF",
1001
+ completion_id,
1002
+ ensure_auth_elapsed,
1003
+ create_chat_elapsed,
1004
+ )
1005
+
1006
+ if native_tool_calls:
1007
+ logger.info("[stream][%s] native_tool_calls=%d", completion_id, len(native_tool_calls))
1008
+ for i, tc in enumerate(native_tool_calls):
1009
+ tc_chunk = {
1010
+ "id": completion_id,
1011
+ "object": "chat.completion.chunk",
1012
+ "created": int(time.time()),
1013
+ "model": model,
1014
+ "choices": [{"index": 0, "delta": {"tool_calls": [{"index": i, **tc}]}, "finish_reason": None}],
1015
+ }
1016
+ yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
1017
+ finish = _openai_chunk(completion_id, model, finish_reason="tool_calls")
1018
+ yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
1019
+ yield "data: [DONE]\n\n"
1020
+ return
1021
+
1022
+ logger.info(
1023
+ "[stream][%s] collected answer_len=%d reasoning_len=%d",
1024
+ completion_id,
1025
+ len(answer_text),
1026
+ len("".join(reasoning_parts)),
1027
+ )
1028
+ if not answer_text and not reasoning_parts:
1029
+ if empty_reply_retries >= UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX:
1030
+ yield f"data: {json.dumps({'error': {'message': 'Upstream returned empty reply after retry', 'type': 'empty_response_error'}}, ensure_ascii=False)}\n\n"
1031
+ yield "data: [DONE]\n\n"
1032
+ return
1033
+ empty_reply_retries += 1
1034
+ logger.warning(
1035
+ "[stream][%s] empty upstream reply, retrying... (%d/%d)",
1036
+ completion_id,
1037
+ empty_reply_retries,
1038
+ UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX,
1039
+ )
1040
+ await pool.refresh_auth(current_uid)
1041
+ current_uid = None
1042
+ continue
1043
+ parsed = _parse_function_calls_xml(answer_text, GLOBAL_TRIGGER_SIGNAL) if has_fc else []
1044
+
1045
+ if parsed:
1046
+ logger.info("[stream][%s] parsed_tool_calls=%d", completion_id, len(parsed))
1047
+ prefix_pos = _find_last_trigger_signal_outside_think(answer_text, GLOBAL_TRIGGER_SIGNAL)
1048
+ if prefix_pos > emitted_answer_chars:
1049
+ prefix_delta = answer_text[emitted_answer_chars:prefix_pos]
1050
+ if prefix_delta:
1051
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=prefix_delta), ensure_ascii=False)}\n\n"
1052
+
1053
+ for i, tc in enumerate(parsed):
1054
+ tc_chunk = {
1055
+ "id": completion_id,
1056
+ "object": "chat.completion.chunk",
1057
+ "created": int(time.time()),
1058
+ "model": model,
1059
+ "choices": [{"index": 0, "delta": {"tool_calls": [{"index": i, **tc}]}, "finish_reason": None}],
1060
+ }
1061
+ yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
1062
+
1063
+ finish = _openai_chunk(completion_id, model, finish_reason="tool_calls")
1064
+ yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
1065
+ yield "data: [DONE]\n\n"
1066
+ return
1067
+
1068
+ if emitted_answer_chars < len(answer_text):
1069
+ tail_delta = answer_text[emitted_answer_chars:]
1070
+ if tail_delta:
1071
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=tail_delta), ensure_ascii=False)}\n\n"
1072
+ else:
1073
+ # Never return an empty stream response body to clients.
1074
+ if not answer_text:
1075
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=''), ensure_ascii=False)}\n\n"
1076
+
1077
+ finish = _openai_chunk(completion_id, model, finish_reason="stop")
1078
+ yield f"data: {json.dumps(finish, ensure_ascii=False)}\n\n"
1079
+ yield "data: [DONE]\n\n"
1080
+ return
1081
+
1082
+ except asyncio.TimeoutError:
1083
+ logger.error(
1084
+ "[stream][%s] first upstream event timeout: %.1fs",
1085
+ completion_id,
1086
+ UPSTREAM_FIRST_EVENT_TIMEOUT,
1087
+ )
1088
+ if client is not None:
1089
+ if chat_id:
1090
+ await client.delete_chat(chat_id)
1091
+ await client.close()
1092
+ client = None
1093
+ if first_event_timeout_retries >= UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX:
1094
+ yield f"data: {json.dumps({'error': {'message': 'Upstream first event timeout after retry', 'type': 'timeout_error'}}, ensure_ascii=False)}\n\n"
1095
+ yield "data: [DONE]\n\n"
1096
+ return
1097
+ first_event_timeout_retries += 1
1098
+ logger.info(
1099
+ "[stream][%s] retrying after first-event timeout... (%d/%d)",
1100
+ completion_id,
1101
+ first_event_timeout_retries,
1102
+ UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX,
1103
+ )
1104
+ await pool.refresh_auth(current_uid)
1105
+ current_uid = None
1106
+ continue
1107
+ except (httpcore.RemoteProtocolError, httpx.RemoteProtocolError) as e:
1108
+ logger.error("[stream][%s] server disconnected: %s", completion_id, e)
1109
+ if client is not None:
1110
+ if chat_id:
1111
+ await client.delete_chat(chat_id)
1112
+ await client.close()
1113
+ client = None
1114
+ if retried:
1115
+ error_msg = "上游服务断开连接,请稍后重试"
1116
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=f'[{error_msg}]'), ensure_ascii=False)}\n\n"
1117
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, finish_reason='error'), ensure_ascii=False)}\n\n"
1118
+ yield "data: [DONE]\n\n"
1119
+ return
1120
+ retried = True
1121
+ logger.info("[stream][%s] switching account and retrying...", completion_id)
1122
+ await pool.refresh_auth(current_uid)
1123
+ current_uid = None
1124
+ continue
1125
+ except (httpcore.ReadTimeout, httpx.ReadTimeout) as e:
1126
+ logger.error("[stream][%s] read timeout: %s", completion_id, e)
1127
+ if client is not None:
1128
+ if chat_id:
1129
+ await client.delete_chat(chat_id)
1130
+ await client.close()
1131
+ client = None
1132
+
1133
+ if retried:
1134
+ error_msg = "上游服务响应超时,请稍后重试或减少消息长度"
1135
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, content=f'[{error_msg}]'), ensure_ascii=False)}\n\n"
1136
+ yield f"data: {json.dumps(_openai_chunk(completion_id, model, finish_reason='error'), ensure_ascii=False)}\n\n"
1137
+ yield "data: [DONE]\n\n"
1138
+ return
1139
+
1140
+ retried = True
1141
+ logger.info("[stream][%s] retrying after timeout...", completion_id)
1142
+ await pool.refresh_auth(current_uid)
1143
+ current_uid = None
1144
+ continue
1145
+ except httpx.HTTPStatusError as e:
1146
+ # Handle upstream 400 with concurrency limit (code 429)
1147
+ is_concurrency = False
1148
+ try:
1149
+ err_body = e.response.json() if e.response else {}
1150
+ is_concurrency = err_body.get("code") == 429
1151
+ except Exception:
1152
+ pass
1153
+
1154
+ logger.error("[stream][%s] HTTP %s (concurrency=%s): %s", completion_id, e.response.status_code if e.response else '?', is_concurrency, e)
1155
+ if client is not None:
1156
+ if chat_id:
1157
+ await client.delete_chat(chat_id)
1158
+ await client.close()
1159
+ client = None
1160
+
1161
+ if retried:
1162
+ yield f"data: {json.dumps({'error': {'message': 'Upstream concurrency limit' if is_concurrency else 'Upstream error after retry', 'type': 'server_error'}}, ensure_ascii=False)}\n\n"
1163
+ yield "data: [DONE]\n\n"
1164
+ return
1165
+
1166
+ retried = True
1167
+ if is_concurrency:
1168
+ logger.info("[stream][%s] concurrency limit hit, cleaning up chats...", completion_id)
1169
+ await pool.cleanup_chats()
1170
+ await asyncio.sleep(1)
1171
+ await pool.refresh_auth(current_uid)
1172
+ current_uid = None
1173
+ continue
1174
+ except Exception as e:
1175
+ logger.exception("[stream][%s] exception: %s", completion_id, e)
1176
+ if client is not None:
1177
+ if chat_id:
1178
+ await client.delete_chat(chat_id)
1179
+ await client.close()
1180
+ client = None
1181
+
1182
+ if retried:
1183
+ yield f"data: {json.dumps({'error': {'message': 'Upstream Zai error after retry', 'type': 'server_error'}}, ensure_ascii=False)}\n\n"
1184
+ yield "data: [DONE]\n\n"
1185
+ return
1186
+
1187
+ retried = True
1188
+ logger.info("[stream][%s] refreshing auth and retrying...", completion_id)
1189
+ await pool.refresh_auth(current_uid)
1190
+ current_uid = None
1191
+ continue
1192
+ finally:
1193
+ if client is not None:
1194
+ if chat_id:
1195
+ await client.delete_chat(chat_id)
1196
+ await client.close()
1197
+ if current_uid:
1198
+ pool._release_by_user_id(current_uid)
1199
+ current_uid = None
1200
+
1201
+ return StreamingResponse(
1202
+ gen_sse(),
1203
+ media_type="text/event-stream",
1204
+ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
1205
+ )
1206
+
1207
+ completion_id = _make_id()
1208
+ client: ZaiClient | None = None
1209
+ chat_id: str | None = None
1210
+ current_uid: str | None = None
1211
+
1212
+ max_sync_attempts = max(2, UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX + 1)
1213
+ for attempt in range(max_sync_attempts):
1214
+ try:
1215
+ phase_started = time.perf_counter()
1216
+ await pool.ensure_auth()
1217
+ ensure_auth_elapsed = time.perf_counter() - phase_started
1218
+ auth = pool.get_auth_snapshot()
1219
+ current_uid = auth["user_id"]
1220
+ upstream, client, chat_id, create_chat_elapsed = await run_once(auth, enable_thinking)
1221
+ first_upstream_started = time.perf_counter()
1222
+ first_event_logged = False
1223
+ reasoning_parts: list[str] = []
1224
+ answer_parts: list[str] = []
1225
+ native_tool_calls: list[dict] = []
1226
+
1227
+ async for data in _iter_upstream_with_first_event_timeout(upstream, UPSTREAM_FIRST_EVENT_TIMEOUT):
1228
+ if not first_event_logged:
1229
+ first_upstream_elapsed = time.perf_counter() - first_upstream_started
1230
+ logger.info(
1231
+ "[sync][%s] phase ensure_auth=%.3fs create_chat=%.3fs first_upstream_event=%.3fs",
1232
+ completion_id,
1233
+ ensure_auth_elapsed,
1234
+ create_chat_elapsed,
1235
+ first_upstream_elapsed,
1236
+ )
1237
+ first_event_logged = True
1238
+ phase, delta = _extract_upstream_delta(data)
1239
+
1240
+ upstream_tcs = _extract_upstream_tool_calls(data)
1241
+ if upstream_tcs:
1242
+ for tc in upstream_tcs:
1243
+ native_tool_calls.append(
1244
+ {
1245
+ "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
1246
+ "type": "function",
1247
+ "function": {
1248
+ "name": tc.get("function", {}).get("name", ""),
1249
+ "arguments": tc.get("function", {}).get("arguments", ""),
1250
+ },
1251
+ }
1252
+ )
1253
+ elif phase == "thinking" and delta:
1254
+ reasoning_parts.append(delta)
1255
+ elif delta:
1256
+ answer_parts.append(delta)
1257
+
1258
+ if not first_event_logged:
1259
+ logger.info(
1260
+ "[sync][%s] phase ensure_auth=%.3fs create_chat=%.3fs first_upstream_event=EOF",
1261
+ completion_id,
1262
+ ensure_auth_elapsed,
1263
+ create_chat_elapsed,
1264
+ )
1265
+
1266
+ if native_tool_calls:
1267
+ message: dict = {"role": "assistant", "content": None, "tool_calls": native_tool_calls}
1268
+ if reasoning_parts:
1269
+ message["reasoning_content"] = "".join(reasoning_parts)
1270
+ usage = _build_usage(usage_prompt_text, "".join(reasoning_parts))
1271
+ return {
1272
+ "id": completion_id,
1273
+ "object": "chat.completion",
1274
+ "created": int(time.time()),
1275
+ "model": model,
1276
+ "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
1277
+ "usage": usage,
1278
+ }
1279
+
1280
+ answer_text = "".join(answer_parts)
1281
+ if not answer_text and not reasoning_parts:
1282
+ if attempt < max_sync_attempts - 1:
1283
+ logger.warning(
1284
+ "[sync][%s] empty upstream reply, retrying... (%d/%d)",
1285
+ completion_id,
1286
+ attempt + 1,
1287
+ max_sync_attempts - 1,
1288
+ )
1289
+ await pool.refresh_auth(current_uid)
1290
+ current_uid = None
1291
+ continue
1292
+ return JSONResponse(
1293
+ status_code=502,
1294
+ content={"error": {"message": "Upstream returned empty reply after retry", "type": "empty_response_error"}},
1295
+ )
1296
+ parsed = _parse_function_calls_xml(answer_text, GLOBAL_TRIGGER_SIGNAL) if has_fc else []
1297
+ if parsed:
1298
+ prefix_pos = _find_last_trigger_signal_outside_think(answer_text, GLOBAL_TRIGGER_SIGNAL)
1299
+ prefix_text = answer_text[:prefix_pos].rstrip() if prefix_pos > 0 else None
1300
+ message = {"role": "assistant", "content": prefix_text or None, "tool_calls": parsed}
1301
+ if reasoning_parts:
1302
+ message["reasoning_content"] = "".join(reasoning_parts)
1303
+ usage = _build_usage(usage_prompt_text, (prefix_text or "") + "".join(reasoning_parts))
1304
+ return {
1305
+ "id": completion_id,
1306
+ "object": "chat.completion",
1307
+ "created": int(time.time()),
1308
+ "model": model,
1309
+ "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
1310
+ "usage": usage,
1311
+ }
1312
+
1313
+ usage = _build_usage(usage_prompt_text, answer_text + "".join(reasoning_parts))
1314
+ msg: dict = {"role": "assistant", "content": answer_text}
1315
+ if reasoning_parts:
1316
+ msg["reasoning_content"] = "".join(reasoning_parts)
1317
+ return {
1318
+ "id": completion_id,
1319
+ "object": "chat.completion",
1320
+ "created": int(time.time()),
1321
+ "model": model,
1322
+ "choices": [{"index": 0, "message": msg, "finish_reason": "stop"}],
1323
+ "usage": usage,
1324
+ }
1325
+
1326
+ except asyncio.TimeoutError:
1327
+ logger.error(
1328
+ "[sync][%s] first upstream event timeout: %.1fs",
1329
+ completion_id,
1330
+ UPSTREAM_FIRST_EVENT_TIMEOUT,
1331
+ )
1332
+ if client is not None:
1333
+ if chat_id:
1334
+ await client.delete_chat(chat_id)
1335
+ await client.close()
1336
+ client = None
1337
+ chat_id = None
1338
+ if attempt < UPSTREAM_FIRST_EVENT_TIMEOUT_RETRY_MAX:
1339
+ await pool.refresh_auth(current_uid)
1340
+ current_uid = None
1341
+ continue
1342
+ return JSONResponse(
1343
+ status_code=504,
1344
+ content={"error": {"message": "Upstream first event timeout after retry", "type": "timeout_error"}},
1345
+ )
1346
+ except httpx.HTTPStatusError as e:
1347
+ is_concurrency = False
1348
+ try:
1349
+ err_body = e.response.json() if e.response else {}
1350
+ is_concurrency = err_body.get("code") == 429
1351
+ except Exception:
1352
+ pass
1353
+ logger.error("[sync][%s] HTTP %s (concurrency=%s): %s", completion_id, e.response.status_code if e.response else '?', is_concurrency, e)
1354
+ if client is not None:
1355
+ if chat_id:
1356
+ await client.delete_chat(chat_id)
1357
+ await client.close()
1358
+ client = None
1359
+ chat_id = None
1360
+ if attempt == 0:
1361
+ if is_concurrency:
1362
+ await pool.cleanup_chats()
1363
+ await asyncio.sleep(1)
1364
+ await pool.refresh_auth(current_uid)
1365
+ current_uid = None
1366
+ continue
1367
+ return JSONResponse(
1368
+ status_code=502,
1369
+ content={"error": {"message": "Upstream concurrency limit" if is_concurrency else "Upstream error after retry", "type": "server_error"}},
1370
+ )
1371
+ except Exception as e:
1372
+ logger.exception("[sync][%s] exception: %s", completion_id, e)
1373
+ if client is not None:
1374
+ if chat_id:
1375
+ await client.delete_chat(chat_id)
1376
+ await client.close()
1377
+ client = None
1378
+ chat_id = None
1379
+
1380
+ if attempt == 0:
1381
+ await pool.refresh_auth(current_uid)
1382
+ current_uid = None
1383
+ continue
1384
+ return JSONResponse(
1385
+ status_code=502,
1386
+ content={"error": {"message": "Upstream Zai error after retry", "type": "server_error"}},
1387
+ )
1388
+ finally:
1389
+ if client is not None:
1390
+ if chat_id:
1391
+ await client.delete_chat(chat_id)
1392
+ await client.close()
1393
+ if current_uid:
1394
+ pool._release_by_user_id(current_uid)
1395
+ current_uid = None
1396
+
1397
+ return JSONResponse(status_code=502, content={"error": {"message": "Unexpected error", "type": "server_error"}})
1398
+
1399
+
1400
+ # ── Anthropic Claude Messages Endpoint ───────────────────────────────
1401
+
1402
+
1403
@app.post("/v1/messages")
async def claude_messages(request: Request):
    """Anthropic Claude Messages API compatible endpoint for new-api.

    Translates a Claude-native request body (system, messages, tools,
    tool_choice, thinking) into the internal OpenAI-style format, then
    serves the reply either as Anthropic SSE (stream=true) or as a single
    Messages-API JSON object.
    """
    body = await request.json()
    model: str = body.get("model", "glm-5")
    claude_msgs: list[dict] = body.get("messages", [])
    system = body.get("system")
    stream: bool = body.get("stream", False)
    tools_claude: list[dict] | None = body.get("tools")
    tool_choice = body.get("tool_choice")
    enable_thinking = _to_optional_bool(body.get("enable_thinking"))
    # Fix: also honor the official Anthropic `thinking` request parameter
    # ({"type": "enabled", "budget_tokens": ...} or {"type": "disabled"})
    # when the non-standard `enable_thinking` flag is not supplied.
    if enable_thinking is None:
        thinking_cfg = body.get("thinking")
        if isinstance(thinking_cfg, dict):
            if thinking_cfg.get("type") == "enabled":
                enable_thinking = True
            elif thinking_cfg.get("type") == "disabled":
                enable_thinking = False

    openai_messages = claude_messages_to_openai(system, claude_msgs)
    openai_tools = claude_tools_to_openai(tools_claude)

    # The upstream chat-creation call needs a seed prompt; use the most
    # recent user message.
    prompt = ""
    for msg in reversed(openai_messages):
        if msg.get("role") == "user":
            prompt = _extract_text_from_content(msg.get("content", ""))
            break
    if not prompt:
        return JSONResponse(
            status_code=400,
            content={"type": "error", "error": {"type": "invalid_request_error", "message": "No user message"}},
        )

    processed_messages = _preprocess_messages(openai_messages)
    has_fc = bool(openai_tools)
    if has_fc:
        # Tool calling is emulated via a system prompt describing the tools
        # plus an XML trigger-signal protocol parsed out of the reply.
        fc_prompt = _generate_function_prompt(openai_tools, GLOBAL_TRIGGER_SIGNAL)
        fc_prompt += claude_tool_choice_prompt(tool_choice)
        processed_messages.insert(0, {"role": "system", "content": fc_prompt})

    flat_messages = _flatten_messages_for_zai(processed_messages)
    # Concatenated prompt text, used only for input-token estimation.
    usage_prompt = "\n".join(_extract_text_from_content(m.get("content", "")) for m in processed_messages)

    msg_id = make_claude_id()
    req_id = f"req_{uuid.uuid4().hex[:10]}"
    logger.info("[claude][%s] model=%s stream=%s tools=%d", req_id, model, stream, len(openai_tools or []))

    async def _run(auth):
        # Create an upstream chat under `auth` and open its completion
        # stream; returns (upstream iterator, client, chat_id). The client
        # is closed here only on failure — callers own it on success.
        c = ZaiClient()
        try:
            c.token, c.user_id, c.username = auth["token"], auth["user_id"], auth["username"]
            chat = await c.create_chat(prompt, model, enable_thinking=enable_thinking)
            chat_id = chat["id"]
            up = c.chat_completions(
                chat_id=chat_id,
                messages=flat_messages,
                prompt=prompt,
                model=model,
                enable_thinking=enable_thinking,
            )
            return up, c, chat_id
        except Exception:
            await c.close()
            raise

    if stream:
        return StreamingResponse(
            _claude_stream(msg_id, model, _run, has_fc, usage_prompt),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    return await _claude_sync(msg_id, model, _run, has_fc, usage_prompt)
1469
+
1470
+
1471
async def _claude_stream(msg_id: str, model: str, run_once, has_fc: bool, usage_prompt: str):
    """Generator for Claude SSE streaming.

    Pulls the upstream reply via `run_once` and re-emits it as Anthropic
    Messages-API SSE events (message_start, content_block_* for thinking /
    text / tool_use, message_delta, message_stop). On upstream failures it
    refreshes the account auth and retries once; `message_start` is only
    sent on the first attempt (guarded by `started`).
    """
    retried = False
    current_uid: str | None = None
    started = False
    while True:
        client = None
        chat_id = None
        try:
            await pool.ensure_auth()
            auth = pool.get_auth_snapshot()
            current_uid = auth["user_id"]
            input_tk = _estimate_tokens(usage_prompt)
            # Emit message_start/ping only once, even across retries — the
            # client has already received them if a retry happens mid-stream.
            if not started:
                yield sse_message_start(msg_id, model, input_tk)
                yield sse_ping()
                started = True
            upstream, client, chat_id = await run_once(auth)

            # Per-attempt accumulation state.
            r_parts: list[str] = []          # thinking deltas
            answer_text = ""                 # full visible answer so far
            emitted_answer_chars = 0         # chars of answer_text already sent
            bidx = 0                         # Anthropic content-block index
            thinking_on = False              # a "thinking" block is open
            text_on = False                  # a "text" block is open
            native_tcs: list[dict] = []      # tool calls reported by upstream

            async for data in upstream:
                phase, delta = _extract_upstream_delta(data)
                up_tcs = _extract_upstream_tool_calls(data)
                if up_tcs:
                    native_tcs.extend(up_tcs)
                    continue
                if phase == "thinking" and delta:
                    # Open a thinking block only before any text has started.
                    if not thinking_on and not text_on:
                        yield sse_content_block_start(bidx, {"type": "thinking", "thinking": ""})
                        thinking_on = True
                    r_parts.append(delta)
                    if thinking_on:
                        yield sse_content_block_delta(bidx, {"type": "thinking_delta", "thinking": delta})
                elif delta:
                    answer_text += delta
                    # Hold back text that might be the start of a tool-call
                    # trigger signal; only emit the "safe" prefix.
                    safe_delta, emitted_answer_chars, _ = _drain_safe_answer_delta(
                        answer_text,
                        emitted_answer_chars,
                        has_fc=has_fc,
                        trigger_signal=GLOBAL_TRIGGER_SIGNAL,
                    )
                    if safe_delta:
                        # Transition thinking -> text closes the thinking block.
                        if thinking_on:
                            yield sse_content_block_stop(bidx)
                            bidx += 1
                            thinking_on = False
                        if not text_on:
                            yield sse_content_block_start(bidx, {"type": "text", "text": ""})
                            text_on = True
                        yield sse_content_block_delta(bidx, {"type": "text_delta", "text": safe_delta})

            # close thinking block
            if thinking_on:
                yield sse_content_block_stop(bidx)
                bidx += 1

            # Prefer native upstream tool calls; otherwise try to parse the
            # XML tool-call protocol out of the collected answer text.
            all_tcs = native_tcs
            parsed_tcs: list[dict] = []
            if not all_tcs and has_fc:
                parsed_tcs = _parse_function_calls_xml(answer_text, GLOBAL_TRIGGER_SIGNAL)
                all_tcs = parsed_tcs

            if all_tcs:
                answer_visible = answer_text
                if parsed_tcs:
                    # Text before the trigger signal is still user-visible;
                    # flush any of it not yet emitted.
                    prefix_pos = _find_last_trigger_signal_outside_think(answer_text, GLOBAL_TRIGGER_SIGNAL)
                    if prefix_pos < 0:
                        prefix_pos = 0
                    if prefix_pos > emitted_answer_chars:
                        prefix_delta = answer_text[emitted_answer_chars:prefix_pos]
                        if prefix_delta:
                            if not text_on:
                                yield sse_content_block_start(bidx, {"type": "text", "text": ""})
                                text_on = True
                            yield sse_content_block_delta(bidx, {"type": "text_delta", "text": prefix_delta})
                    answer_visible = answer_text[:prefix_pos]
                elif emitted_answer_chars < len(answer_text):
                    # Native tool calls: flush the whole remaining answer tail.
                    tail_delta = answer_text[emitted_answer_chars:]
                    if tail_delta:
                        if not text_on:
                            yield sse_content_block_start(bidx, {"type": "text", "text": ""})
                            text_on = True
                        yield sse_content_block_delta(bidx, {"type": "text_delta", "text": tail_delta})

                if text_on:
                    yield sse_content_block_stop(bidx)
                    bidx += 1
                    text_on = False
                # Emit each tool call as its own tool_use content block.
                for tc in all_tcs:
                    fn = tc.get("function", {}) if isinstance(tc.get("function"), dict) else tc
                    nm = fn.get("name", tc.get("name", ""))
                    args_s = fn.get("arguments", "{}")
                    # Normalize OpenAI-style "call_" ids to Anthropic "toolu_".
                    tid = tc.get("id", f"toolu_{uuid.uuid4().hex[:20]}").replace("call_", "toolu_")
                    yield sse_content_block_start(bidx, {"type": "tool_use", "id": tid, "name": nm, "input": {}})
                    yield sse_content_block_delta(bidx, {"type": "input_json_delta", "partial_json": args_s})
                    yield sse_content_block_stop(bidx)
                    bidx += 1
                out_tk = _estimate_tokens("".join(r_parts) + answer_visible)
                yield sse_message_delta("tool_use", out_tk)
                yield sse_message_stop()
                return

            # No tool calls: flush any held-back answer tail.
            if emitted_answer_chars < len(answer_text):
                tail_delta = answer_text[emitted_answer_chars:]
                if tail_delta:
                    if not text_on:
                        yield sse_content_block_start(bidx, {"type": "text", "text": ""})
                        text_on = True
                    yield sse_content_block_delta(bidx, {"type": "text_delta", "text": tail_delta})
            # Always emit at least one (possibly empty) text block so the
            # client never sees a message with zero content blocks.
            if not text_on:
                yield sse_content_block_start(bidx, {"type": "text", "text": ""})
            yield sse_content_block_stop(bidx)
            out_tk = _estimate_tokens("".join(r_parts) + answer_text)
            yield sse_message_delta("end_turn", out_tk)
            yield sse_message_stop()
            return

        except (httpcore.ReadTimeout, httpx.ReadTimeout) as e:
            logger.error("[claude-stream][%s] timeout: %s", msg_id, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            if retried:
                yield sse_error("overloaded_error", "Upstream timeout")
                return
            retried = True
            await pool.refresh_auth(current_uid)
            current_uid = None
            continue
        except (httpcore.RemoteProtocolError, httpx.RemoteProtocolError) as e:
            logger.error("[claude-stream][%s] server disconnected: %s", msg_id, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            if retried:
                yield sse_error("api_error", "Server disconnected, please retry")
                return
            retried = True
            await pool.refresh_auth(current_uid)
            current_uid = None
            continue
        except httpx.HTTPStatusError as e:
            # Upstream signals its concurrency limit with body code 429.
            is_concurrency = False
            try:
                err_body = e.response.json() if e.response else {}
                is_concurrency = err_body.get("code") == 429
            except Exception:
                pass
            logger.error("[claude-stream][%s] HTTP %s (concurrency=%s): %s", msg_id, e.response.status_code if e.response else '?', is_concurrency, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            if retried:
                yield sse_error("overloaded_error" if is_concurrency else "api_error", "Upstream concurrency limit" if is_concurrency else "Upstream error after retry")
                return
            retried = True
            if is_concurrency:
                logger.info("[claude-stream][%s] concurrency limit hit, cleaning up chats...", msg_id)
                await pool.cleanup_chats()
                await asyncio.sleep(1)
            await pool.refresh_auth(current_uid)
            current_uid = None
            continue
        except Exception as e:
            logger.exception("[claude-stream][%s] error: %s", msg_id, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            if retried:
                yield sse_error("api_error", "Upstream error after retry")
                return
            retried = True
            await pool.refresh_auth(current_uid)
            current_uid = None
            continue
        finally:
            # Runs on every loop exit (return / continue): best-effort chat
            # deletion, client close, and account-slot release.
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            if current_uid:
                pool._release_by_user_id(current_uid)
                current_uid = None
1669
+
1670
+
1671
async def _claude_sync(msg_id: str, model: str, run_once, has_fc: bool, usage_prompt: str):
    """Non-streaming Claude response.

    Drains the whole upstream stream, then builds a single Anthropic
    Messages-API JSON object (thinking + text + optional tool_use blocks).
    Makes at most two attempts, refreshing the account auth between them.
    """
    client = None
    chat_id = None
    current_uid: str | None = None
    for attempt in range(2):
        try:
            await pool.ensure_auth()
            auth = pool.get_auth_snapshot()
            current_uid = auth["user_id"]
            upstream, client, chat_id = await run_once(auth)
            # r_parts: thinking deltas; a_parts: visible answer deltas.
            r_parts, a_parts = [], []
            native_tcs: list[dict] = []

            async for data in upstream:
                phase, delta = _extract_upstream_delta(data)
                up_tcs = _extract_upstream_tool_calls(data)
                if up_tcs:
                    native_tcs.extend(up_tcs)
                elif phase == "thinking" and delta:
                    r_parts.append(delta)
                elif delta:
                    a_parts.append(delta)

            answer = "".join(a_parts)
            # Prefer native upstream tool calls; otherwise parse the XML
            # tool-call protocol out of the answer text.
            all_tcs = native_tcs
            if not all_tcs and has_fc:
                all_tcs = _parse_function_calls_xml(answer, GLOBAL_TRIGGER_SIGNAL)
            if all_tcs:
                # Keep only the visible text before the trigger signal.
                pp = _find_last_trigger_signal_outside_think(answer, GLOBAL_TRIGGER_SIGNAL)
                answer = answer[:pp].rstrip() if pp > 0 else ""

            in_tk = _estimate_tokens(usage_prompt)
            # NOTE(review): output tokens are estimated from the raw answer
            # parts, i.e. including any stripped tool-call XML.
            out_tk = _estimate_tokens("".join(r_parts) + "".join(a_parts))
            return build_non_stream_response(msg_id, model, r_parts, answer, all_tcs or None, in_tk, out_tk)

        except httpx.HTTPStatusError as e:
            # Upstream signals its concurrency limit with body code 429.
            is_concurrency = False
            try:
                err_body = e.response.json() if e.response else {}
                is_concurrency = err_body.get("code") == 429
            except Exception:
                pass
            logger.error("[claude-sync][%s] HTTP %s (concurrency=%s): %s", msg_id, e.response.status_code if e.response else '?', is_concurrency, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            chat_id = None
            if attempt == 0:
                if is_concurrency:
                    await pool.cleanup_chats()
                    await asyncio.sleep(1)
                await pool.refresh_auth(current_uid)
                current_uid = None
                continue
            return JSONResponse(
                status_code=500,
                content={"type": "error", "error": {"type": "overloaded_error" if is_concurrency else "api_error", "message": "Upstream concurrency limit" if is_concurrency else "Upstream error"}},
            )
        except Exception as e:
            logger.exception("[claude-sync][%s] error: %s", msg_id, e)
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            client = None
            chat_id = None
            if attempt == 0:
                await pool.refresh_auth(current_uid)
                current_uid = None
                continue
            return JSONResponse(
                status_code=500,
                content={"type": "error", "error": {"type": "api_error", "message": "Upstream error"}},
            )
        finally:
            # Runs on every attempt exit (return / continue): best-effort
            # chat deletion, client close, and account-slot release.
            if client:
                if chat_id:
                    await client.delete_chat(chat_id)
                await client.close()
            if current_uid:
                pool._release_by_user_id(current_uid)
                current_uid = None

    return JSONResponse(status_code=500, content={"type": "error", "error": {"type": "api_error", "message": "Unexpected"}})
1758
+
1759
+
1760
if __name__ == "__main__":
    # Bind address and port are configurable via environment variables;
    # the defaults preserve the original behavior (0.0.0.0:30016, matching
    # the Dockerfile's EXPOSE).
    import os

    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "30016")),
    )