# Source: zarev/openai.py — Hugging Face upload e8e0ef8 (verified).
"""OpenAI-compatible proxy server for chat.z.ai."""
from __future__ import annotations
import asyncio
import html
import json
import re
import time
import uuid
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from main import ZaiClient
# ── Session Pool ─────────────────────────────────────────────────────
class SessionPool:
    """Manages a single ZaiClient instance with automatic auth refresh."""

    def __init__(self) -> None:
        self._client = ZaiClient()
        self._lock = asyncio.Lock()
        self._authed = False

    async def close(self) -> None:
        """Release the underlying HTTP client."""
        await self._client.close()

    async def ensure_auth(self) -> None:
        """Authenticate if not already done.

        Guarded by the lock with a double-check so that concurrent first
        requests trigger exactly one guest authentication instead of one
        per request (the flag read/write was previously unsynchronized).
        """
        if self._authed:
            return
        async with self._lock:
            if not self._authed:
                await self._client.auth_as_guest()
                self._authed = True

    async def refresh_auth(self) -> None:
        """Force-refresh the guest token (locked to avoid concurrent rebuilds)."""
        async with self._lock:
            await self._client.auth_as_guest()
            self._authed = True

    async def get_models(self) -> list | dict:
        """Return the upstream model catalog, authenticating first if needed."""
        await self.ensure_auth()
        return await self._client.get_models()

    async def create_chat(self, user_message: str, model: str) -> dict:
        """Create a new upstream chat session seeded with *user_message*."""
        return await self._client.create_chat(user_message, model)

    def chat_completions(
        self,
        chat_id: str,
        messages: list[dict],
        prompt: str,
        *,
        model: str,
        tools: list[dict] | None = None,
    ):
        """Pass straight through to the client's streaming completion call."""
        return self._client.chat_completions(
            chat_id=chat_id,
            messages=messages,
            prompt=prompt,
            model=model,
            tools=tools,
        )
# Single shared session for the whole process.
pool = SessionPool()


# ── FastAPI app ──────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(_app: FastAPI):
    # Authenticate the shared guest session once at startup; close the
    # underlying HTTP client on shutdown.
    await pool.ensure_auth()
    yield
    await pool.close()


app = FastAPI(lifespan=lifespan)
# ── Helpers ──────────────────────────────────────────────────────────
def _make_id() -> str:
return f"chatcmpl-{uuid.uuid4().hex[:29]}"
def _openai_chunk(
completion_id: str,
model: str,
*,
content: str | None = None,
reasoning_content: str | None = None,
finish_reason: str | None = None,
) -> dict:
delta: dict = {}
if content is not None:
delta["content"] = content
if reasoning_content is not None:
delta["reasoning_content"] = reasoning_content
return {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
def _openai_completion(
completion_id: str,
model: str,
content: str,
reasoning_content: str,
) -> dict:
message: dict = {"role": "assistant", "content": content}
if reasoning_content:
message["reasoning_content"] = reasoning_content
return {
"id": completion_id,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": message,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
}
def _extract_text_from_content(content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for p in content:
if isinstance(p, dict) and p.get("type") == "text":
parts.append(str(p.get("text", "")))
return " ".join(parts).strip()
if content is None:
return ""
try:
return json.dumps(content, ensure_ascii=False)
except Exception:
return str(content)
def _build_tool_call_index(messages: list[dict]) -> dict[str, str]:
index: dict[str, str] = {}
for msg in messages:
if msg.get("role") != "assistant":
continue
tool_calls = msg.get("tool_calls")
if not isinstance(tool_calls, list):
continue
for tc in tool_calls:
if not isinstance(tc, dict):
continue
tc_id = tc.get("id")
fn = tc.get("function", {})
name = fn.get("name") if isinstance(fn, dict) else None
if isinstance(tc_id, str) and isinstance(name, str):
index[tc_id] = name
return index
def _render_assistant_tool_calls_xml(tool_calls: list[dict]) -> str:
blocks: list[str] = []
for tc in tool_calls:
if not isinstance(tc, dict):
continue
fn = tc.get("function", {})
if not isinstance(fn, dict):
continue
name = str(fn.get("name", "")).strip()
args_raw = fn.get("arguments", "{}")
if not name:
continue
if isinstance(args_raw, str):
args_text = args_raw
else:
try:
args_text = json.dumps(args_raw, ensure_ascii=False)
except Exception:
args_text = "{}"
blocks.append(
"<function_call>\n"
f"<name>{name}</name>\n"
f"<arguments>{args_text}</arguments>\n"
"</function_call>"
)
if not blocks:
return ""
return "<function_calls>\n" + "\n".join(blocks) + "\n</function_calls>"
def _flatten_messages_for_zai(messages: list[dict]) -> list[dict]:
    """Collapse an OpenAI conversation into a single Zai user message.

    Each turn becomes a ``<ROLE>...</ROLE>`` span; assistant tool calls are
    rendered as XML blocks and tool results are wrapped in ``<TOOL_RESULT>``
    tags carrying the tool name and call id when known.
    """
    call_names = _build_tool_call_index(messages)
    rendered_turns: list[str] = []
    for msg in messages:
        role = str(msg.get("role", "user")).lower()
        text = _extract_text_from_content(msg.get("content", ""))
        if role == "assistant" and isinstance(msg.get("tool_calls"), list):
            xml = _render_assistant_tool_calls_xml(msg["tool_calls"])
            if xml:
                text = f"{text}\n{xml}" if text else xml
        elif role == "tool":
            call_id = msg.get("tool_call_id")
            name = msg.get("name")
            # Recover the tool name from the assistant turn that issued the call.
            if not name and isinstance(call_id, str):
                name = call_names.get(call_id, "")
            attrs: list[str] = []
            if name:
                attrs.append(f'name="{name}"')
            if call_id:
                attrs.append(f'tool_call_id="{call_id}"')
            attr_str = (" " + " ".join(attrs)) if attrs else ""
            text = f"<TOOL_RESULT{attr_str}>\n{text}\n</TOOL_RESULT>"
        tag = role.upper()
        rendered_turns.append(f"<{tag}>{text}</{tag}>")
    return [{"role": "user", "content": "\n".join(rendered_turns)}]
def _tool_definitions_xml(tools: list[dict]) -> str:
blocks: list[str] = []
for t in tools:
if not isinstance(t, dict):
continue
if t.get("type") != "function":
continue
fn = t.get("function", {})
if not isinstance(fn, dict):
continue
name = str(fn.get("name", "")).strip()
if not name:
continue
desc = str(fn.get("description", "")).strip()
params = fn.get("parameters", {})
try:
params_json = json.dumps(params, ensure_ascii=False)
except Exception:
params_json = "{}"
blocks.append(
"<tool>\n"
f"<name>{name}</name>\n"
f"<description>{desc}</description>\n"
f"<parameters>{params_json}</parameters>\n"
"</tool>"
)
return "\n".join(blocks)
def _build_prompt_xml_instruction(tools: list[dict]) -> str:
    """Build the system-prompt text that teaches the model the XML tool protocol."""
    tools_xml = _tool_definitions_xml(tools)
    header = (
        "You can call tools. Available tools:\n"
        "<tools>\n"
        f"{tools_xml}\n"
        "</tools>\n\n"
    )
    example = (
        "If you need tools, respond ONLY with this exact XML format:\n"
        "<function_calls>\n"
        " <function_call>\n"
        " <name>tool_name</name>\n"
        " <arguments>{\"key\":\"value\"}</arguments>\n"
        " </function_call>\n"
        "</function_calls>\n\n"
    )
    rules = (
        "Rules:\n"
        "1) arguments MUST be valid JSON object string.\n"
        "2) Multiple calls: include multiple <function_call> inside one <function_calls>.\n"
        "3) If no tool is needed, answer normally (no XML)."
    )
    return header + example + rules
def _inject_prompt_xml_system(messages: list[dict], tools: list[dict]) -> list[dict]:
    """Return a shallow copy of *messages* with the tool-protocol system prompt first."""
    system_msg = {"role": "system", "content": _build_prompt_xml_instruction(tools)}
    return [system_msg, *messages]
def _clean_xml_text(text: str) -> str:
s = text.strip()
if s.startswith("```"):
s = re.sub(r"^```(?:xml|json)?\s*", "", s, flags=re.IGNORECASE)
s = re.sub(r"\s*```$", "", s)
return html.unescape(s.strip())
def _parse_prompt_xml_tool_calls(text: str) -> tuple[list[dict], str]:
    """Return (tool_calls, cleaned_text_without_xml_block).

    Scans *text* for ``<function_calls>`` blocks emitted under the
    prompt-xml protocol. Only the LAST block is parsed and removed from
    the text; each extracted call gets a fresh OpenAI-style ``call_`` id
    and its arguments are normalized to a JSON-object string.
    """
    if not text:
        return [], text
    pattern = re.compile(r"<function_calls\b[^>]*>(.*?)</function_calls>", re.IGNORECASE | re.DOTALL)
    matches = list(pattern.finditer(text))
    if not matches:
        return [], text
    # Parse only the last block; earlier blocks stay in the returned text.
    last = matches[-1]
    inner = last.group(1)
    remaining = (text[: last.start()] + text[last.end() :]).strip()
    call_pattern = re.compile(r"<function_call\b[^>]*>(.*?)</function_call>", re.IGNORECASE | re.DOTALL)
    name_pattern = re.compile(r"<name\b[^>]*>(.*?)</name>", re.IGNORECASE | re.DOTALL)
    args_pattern = re.compile(r"<arguments\b[^>]*>(.*?)</arguments>", re.IGNORECASE | re.DOTALL)
    tool_calls: list[dict] = []
    for m in call_pattern.finditer(inner):
        block = m.group(1)
        name_m = name_pattern.search(block)
        args_m = args_pattern.search(block)
        if not name_m:
            # A call without a <name> is unusable; skip it.
            continue
        name = _clean_xml_text(name_m.group(1))
        args_text = _clean_xml_text(args_m.group(1) if args_m else "{}")
        # Normalize arguments to JSON string object when possible.
        if not args_text:
            args_text = "{}"
        else:
            try:
                parsed = json.loads(args_text)
                if isinstance(parsed, dict):
                    args_text = json.dumps(parsed, ensure_ascii=False)
                else:
                    # Non-object JSON (number, list, ...) is wrapped as {"value": ...}.
                    args_text = json.dumps({"value": parsed}, ensure_ascii=False)
            except Exception:
                # Unparseable argument text is preserved verbatim under "raw".
                args_text = json.dumps({"raw": args_text}, ensure_ascii=False)
        tool_calls.append(
            {
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "type": "function",
                "function": {"name": name, "arguments": args_text},
            }
        )
    return tool_calls, remaining
# ── /v1/models ───────────────────────────────────────────────────────
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing backed by Zai's model catalog."""
    models_resp = await pool.get_models()
    # Upstream may return either {"data": [...]} or a bare list.
    if isinstance(models_resp, dict) and "data" in models_resp:
        models_list = models_resp["data"]
    elif isinstance(models_resp, list):
        models_list = models_resp
    else:
        models_list = []
    data = [
        {
            "id": m.get("id") or m.get("name", "unknown"),
            "object": "model",
            "created": 0,
            "owned_by": "z.ai",
        }
        for m in models_list
    ]
    return {"object": "list", "data": data}
# ── /v1/chat/completions ────────────────────────────────────────────
async def _do_request(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
):
    """Create a new chat and return (chat_id, async generator).

    Raises on Zai errors so the caller can retry.
    """
    chat = await pool.create_chat(prompt, model)
    chat_id = chat["id"]
    stream = pool.chat_completions(
        chat_id=chat_id,
        messages=messages,
        prompt=prompt,
        model=model,
        tools=tools,
    )
    return chat_id, stream
async def _stream_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
    *,
    toolify_mode: str = "off",
):
    """SSE generator with one retry on error.

    Yields OpenAI-style ``data: ...`` lines. On any exception the guest
    auth is refreshed and the whole request is retried once; a second
    failure emits an error event and terminates the stream.

    NOTE(review): if a failure happens after chunks were already yielded,
    the retry re-emits the role chunk and re-streams from the start,
    duplicating output on the client side — confirm this is acceptable.
    """
    completion_id = _make_id()
    retried = False
    while True:
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)
            # Initial delta announcing the assistant role, per OpenAI SSE shape.
            role_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(role_chunk, ensure_ascii=False)}\n\n"
            # prompt-xml mode buffers content to avoid leaking raw XML tokens in stream.
            buffer_for_xml = toolify_mode == "prompt-xml" and bool(tools)
            tool_call_idx = 0
            reasoning_parts: list[str] = []
            content_parts: list[str] = []
            native_tool_calls: list[dict] = []
            async for data in gen:
                phase = data.get("phase", "")
                delta = data.get("delta_content", "")
                if data.get("tool_calls"):
                    # Upstream emitted native tool calls: collect them and
                    # skip the phase handling for this event.
                    for tc in data["tool_calls"]:
                        native_tool_calls.append(
                            {
                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                "type": "function",
                                "function": {
                                    "name": tc.get("function", {}).get("name", ""),
                                    "arguments": tc.get("function", {}).get("arguments", ""),
                                },
                            }
                        )
                    continue
                if phase == "thinking" and delta:
                    if buffer_for_xml:
                        reasoning_parts.append(delta)
                    else:
                        chunk = _openai_chunk(completion_id, model, reasoning_content=delta)
                        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "answer" and delta:
                    if buffer_for_xml:
                        content_parts.append(delta)
                    else:
                        chunk = _openai_chunk(completion_id, model, content=delta)
                        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "done":
                    break
            if native_tool_calls:
                # Native tool calls take precedence over prompt-xml parsing.
                for tc in native_tool_calls:
                    tc_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"tool_calls": [{"index": tool_call_idx, **tc}]},
                                "finish_reason": None,
                            }
                        ],
                    }
                    yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
                    tool_call_idx += 1
                finish_chunk = _openai_chunk(completion_id, model, finish_reason="tool_calls")
                yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return
            if buffer_for_xml:
                # Everything was buffered; emit reasoning first, then either
                # parsed tool calls or the cleaned answer text.
                reasoning_text = "".join(reasoning_parts)
                content_text = "".join(content_parts)
                parsed_tool_calls, cleaned_content = _parse_prompt_xml_tool_calls(content_text)
                if reasoning_text:
                    chunk = _openai_chunk(completion_id, model, reasoning_content=reasoning_text)
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                if parsed_tool_calls:
                    for tc in parsed_tool_calls:
                        tc_chunk = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"tool_calls": [{"index": tool_call_idx, **tc}]},
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
                        tool_call_idx += 1
                    finish_chunk = _openai_chunk(completion_id, model, finish_reason="tool_calls")
                    yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                if cleaned_content:
                    chunk = _openai_chunk(completion_id, model, content=cleaned_content)
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                finish_chunk = _openai_chunk(completion_id, model, finish_reason="stop")
                yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return
            finish_chunk = _openai_chunk(completion_id, model, finish_reason="stop")
            yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"
            return
        except Exception:
            if retried:
                # Second failure: surface an error event and end the stream.
                error = {"error": {"message": "Upstream Zai error after retry", "type": "server_error"}}
                yield f"data: {json.dumps(error)}\n\n"
                yield "data: [DONE]\n\n"
                return
            retried = True
            await pool.refresh_auth()
async def _sync_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
    *,
    toolify_mode: str = "off",
) -> dict:
    """Non-streaming response with one retry on error.

    Drains the upstream stream into buffers, then returns a complete
    OpenAI-style ``chat.completion`` dict — or an ``{"error": ...}`` dict
    if the retry also fails.
    """
    completion_id = _make_id()
    for attempt in range(2):
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)
            content_parts: list[str] = []
            reasoning_parts: list[str] = []
            tool_calls: list[dict] = []
            async for data in gen:
                phase = data.get("phase", "")
                delta = data.get("delta_content", "")
                if data.get("tool_calls"):
                    # Native tool calls emitted by the upstream.
                    for tc in data["tool_calls"]:
                        tool_calls.append(
                            {
                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                "type": "function",
                                "function": {
                                    "name": tc.get("function", {}).get("name", ""),
                                    "arguments": tc.get("function", {}).get("arguments", ""),
                                },
                            }
                        )
                elif phase == "thinking" and delta:
                    reasoning_parts.append(delta)
                elif phase == "answer" and delta:
                    content_parts.append(delta)
                elif phase == "done":
                    break
            if tool_calls:
                # Native tool calls take precedence over prompt-xml parsing.
                message: dict = {"role": "assistant", "content": None, "tool_calls": tool_calls}
                if reasoning_parts:
                    message["reasoning_content"] = "".join(reasoning_parts)
                return {
                    "id": completion_id,
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                }
            answer_content = "".join(content_parts)
            if toolify_mode == "prompt-xml" and tools:
                # Try to recover XML-protocol tool calls from the answer text.
                parsed_tool_calls, cleaned_content = _parse_prompt_xml_tool_calls(answer_content)
                if parsed_tool_calls:
                    message: dict = {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": parsed_tool_calls,
                    }
                    if reasoning_parts:
                        message["reasoning_content"] = "".join(reasoning_parts)
                    return {
                        "id": completion_id,
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                    }
                answer_content = cleaned_content
            return _openai_completion(
                completion_id,
                model,
                answer_content,
                "".join(reasoning_parts),
            )
        except Exception:
            if attempt == 0:
                # First failure: refresh guest auth and try once more.
                await pool.refresh_auth()
                continue
            return {"error": {"message": "Upstream Zai error after retry", "type": "server_error"}}
    return {"error": {"message": "Unexpected error", "type": "server_error"}}
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint (streaming and sync)."""
    body = await request.json()
    model: str = body.get("model", "glm-5")
    messages: list[dict] = body.get("messages", [])
    stream: bool = body.get("stream", False)
    tools: list[dict] | None = body.get("tools")

    # Resolve the tool-calling mode: an explicit, valid `toolify_mode`
    # wins; otherwise the presence of tools implies prompt-xml.
    toolify_mode_raw = body.get("toolify_mode")
    if isinstance(toolify_mode_raw, str) and toolify_mode_raw in {"off", "prompt-xml"}:
        toolify_mode = toolify_mode_raw
    else:
        toolify_mode = "prompt-xml" if tools else "off"
    # Backward compatibility: old bool `toolify` switch.
    toolify_raw = body.get("toolify", body.get("features", {}).get("toolify"))
    if isinstance(toolify_raw, bool):
        toolify_mode = "prompt-xml" if toolify_raw else "off"

    # Extract the last user message as the signature prompt.
    prompt = next(
        (
            _extract_text_from_content(msg.get("content", ""))
            for msg in reversed(messages)
            if msg.get("role") == "user"
        ),
        "",
    )
    if not prompt:
        return JSONResponse(
            status_code=400,
            content={
                "error": {
                    "message": "No user message found in messages",
                    "type": "invalid_request_error",
                }
            },
        )

    model_messages = messages
    if toolify_mode == "prompt-xml" and tools:
        model_messages = _inject_prompt_xml_system(messages, tools)
    # Zai ignores native multi-turn context; flatten into one message.
    flat_messages = _flatten_messages_for_zai(model_messages)

    if stream:
        return StreamingResponse(
            _stream_response(flat_messages, model, prompt, tools, toolify_mode=toolify_mode),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    result = await _sync_response(flat_messages, model, prompt, tools, toolify_mode=toolify_mode)
    if "error" in result:
        return JSONResponse(status_code=502, content=result)
    return result
# ── Entry point ──────────────────────────────────────────────────────
if __name__ == "__main__":
    # Serve the proxy on all interfaces; 30016 is the project's default port.
    uvicorn.run(app, host="0.0.0.0", port=30016)