| """OpenAI-compatible proxy server for chat.z.ai.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import html |
| import json |
| import re |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
|
|
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| from main import ZaiClient |
|
|
| |
|
|
|
|
class SessionPool:
    """Manages a single ZaiClient instance with automatic auth refresh."""

    def __init__(self) -> None:
        self._client = ZaiClient()
        # Serializes guest-token (re)builds so concurrent requests never
        # authenticate twice against the upstream.
        self._lock = asyncio.Lock()
        self._authed = False

    async def close(self) -> None:
        """Release the underlying HTTP client."""
        await self._client.close()

    async def ensure_auth(self) -> None:
        """Authenticate if not already done.

        Double-checked under the lock: the previous unlocked
        check-then-act allowed two concurrent first requests to both
        call ``auth_as_guest()`` and race on the token.
        """
        if self._authed:
            return
        async with self._lock:
            if not self._authed:
                await self._client.auth_as_guest()
                self._authed = True

    async def refresh_auth(self) -> None:
        """Force-refresh the guest token (locked to avoid concurrent rebuilds)."""
        async with self._lock:
            await self._client.auth_as_guest()
            self._authed = True

    async def get_models(self) -> list | dict:
        """Return the upstream model listing, authenticating first if needed."""
        await self.ensure_auth()
        return await self._client.get_models()

    async def create_chat(self, user_message: str, model: str) -> dict:
        """Create a new upstream chat seeded with *user_message*."""
        return await self._client.create_chat(user_message, model)

    def chat_completions(
        self,
        chat_id: str,
        messages: list[dict],
        prompt: str,
        *,
        model: str,
        tools: list[dict] | None = None,
    ):
        """Proxy to the client's streaming chat-completions generator."""
        return self._client.chat_completions(
            chat_id=chat_id,
            messages=messages,
            prompt=prompt,
            model=model,
            tools=tools,
        )
|
|
|
|
# Process-wide session pool shared by every request handler.
pool = SessionPool()
|
|
| |
|
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan: authenticate the shared pool at startup, close it at shutdown."""
    await pool.ensure_auth()
    yield
    await pool.close()
|
|
|
|
# Application instance; startup/shutdown handled by `lifespan` above.
app = FastAPI(lifespan=lifespan)
|
|
| |
|
|
|
|
| def _make_id() -> str: |
| return f"chatcmpl-{uuid.uuid4().hex[:29]}" |
|
|
|
|
| def _openai_chunk( |
| completion_id: str, |
| model: str, |
| *, |
| content: str | None = None, |
| reasoning_content: str | None = None, |
| finish_reason: str | None = None, |
| ) -> dict: |
| delta: dict = {} |
| if content is not None: |
| delta["content"] = content |
| if reasoning_content is not None: |
| delta["reasoning_content"] = reasoning_content |
| return { |
| "id": completion_id, |
| "object": "chat.completion.chunk", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [ |
| { |
| "index": 0, |
| "delta": delta, |
| "finish_reason": finish_reason, |
| } |
| ], |
| } |
|
|
|
|
| def _openai_completion( |
| completion_id: str, |
| model: str, |
| content: str, |
| reasoning_content: str, |
| ) -> dict: |
| message: dict = {"role": "assistant", "content": content} |
| if reasoning_content: |
| message["reasoning_content"] = reasoning_content |
| return { |
| "id": completion_id, |
| "object": "chat.completion", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [ |
| { |
| "index": 0, |
| "message": message, |
| "finish_reason": "stop", |
| } |
| ], |
| "usage": { |
| "prompt_tokens": 0, |
| "completion_tokens": 0, |
| "total_tokens": 0, |
| }, |
| } |
|
|
|
|
| def _extract_text_from_content(content: object) -> str: |
| if isinstance(content, str): |
| return content |
| if isinstance(content, list): |
| parts: list[str] = [] |
| for p in content: |
| if isinstance(p, dict) and p.get("type") == "text": |
| parts.append(str(p.get("text", ""))) |
| return " ".join(parts).strip() |
| if content is None: |
| return "" |
| try: |
| return json.dumps(content, ensure_ascii=False) |
| except Exception: |
| return str(content) |
|
|
|
|
| def _build_tool_call_index(messages: list[dict]) -> dict[str, str]: |
| index: dict[str, str] = {} |
| for msg in messages: |
| if msg.get("role") != "assistant": |
| continue |
| tool_calls = msg.get("tool_calls") |
| if not isinstance(tool_calls, list): |
| continue |
| for tc in tool_calls: |
| if not isinstance(tc, dict): |
| continue |
| tc_id = tc.get("id") |
| fn = tc.get("function", {}) |
| name = fn.get("name") if isinstance(fn, dict) else None |
| if isinstance(tc_id, str) and isinstance(name, str): |
| index[tc_id] = name |
| return index |
|
|
|
|
| def _render_assistant_tool_calls_xml(tool_calls: list[dict]) -> str: |
| blocks: list[str] = [] |
| for tc in tool_calls: |
| if not isinstance(tc, dict): |
| continue |
| fn = tc.get("function", {}) |
| if not isinstance(fn, dict): |
| continue |
| name = str(fn.get("name", "")).strip() |
| args_raw = fn.get("arguments", "{}") |
| if not name: |
| continue |
| if isinstance(args_raw, str): |
| args_text = args_raw |
| else: |
| try: |
| args_text = json.dumps(args_raw, ensure_ascii=False) |
| except Exception: |
| args_text = "{}" |
| blocks.append( |
| "<function_call>\n" |
| f"<name>{name}</name>\n" |
| f"<arguments>{args_text}</arguments>\n" |
| "</function_call>" |
| ) |
| if not blocks: |
| return "" |
| return "<function_calls>\n" + "\n".join(blocks) + "\n</function_calls>" |
|
|
|
|
def _flatten_messages_for_zai(messages: list[dict]) -> list[dict]:
    """Collapse an OpenAI message list into one pseudo-XML user message.

    Assistant tool calls are re-rendered as <function_calls> XML and
    tool results are wrapped in <TOOL_RESULT ...> tags, so the upstream
    model receives the whole conversation inside a single prompt.
    """
    call_names = _build_tool_call_index(messages)
    segments: list[str] = []

    for msg in messages:
        role = str(msg.get("role", "user")).lower()
        text = _extract_text_from_content(msg.get("content", ""))

        if role == "assistant" and isinstance(msg.get("tool_calls"), list):
            xml = _render_assistant_tool_calls_xml(msg["tool_calls"])
            if xml:
                text = f"{text}\n{xml}" if text else xml
        elif role == "tool":
            call_id = msg.get("tool_call_id")
            name = msg.get("name")
            # Fall back to the id->name index built from assistant messages.
            if not name and isinstance(call_id, str):
                name = call_names.get(call_id, "")
            attrs: list[str] = []
            if name:
                attrs.append(f'name="{name}"')
            if call_id:
                attrs.append(f'tool_call_id="{call_id}"')
            attr_str = (" " + " ".join(attrs)) if attrs else ""
            text = f"<TOOL_RESULT{attr_str}>\n{text}\n</TOOL_RESULT>"

        tag = role.upper()
        segments.append(f"<{tag}>{text}</{tag}>")

    return [{"role": "user", "content": "\n".join(segments)}]
|
|
|
|
| def _tool_definitions_xml(tools: list[dict]) -> str: |
| blocks: list[str] = [] |
| for t in tools: |
| if not isinstance(t, dict): |
| continue |
| if t.get("type") != "function": |
| continue |
| fn = t.get("function", {}) |
| if not isinstance(fn, dict): |
| continue |
| name = str(fn.get("name", "")).strip() |
| if not name: |
| continue |
| desc = str(fn.get("description", "")).strip() |
| params = fn.get("parameters", {}) |
| try: |
| params_json = json.dumps(params, ensure_ascii=False) |
| except Exception: |
| params_json = "{}" |
| blocks.append( |
| "<tool>\n" |
| f"<name>{name}</name>\n" |
| f"<description>{desc}</description>\n" |
| f"<parameters>{params_json}</parameters>\n" |
| "</tool>" |
| ) |
| return "\n".join(blocks) |
|
|
|
|
def _build_prompt_xml_instruction(tools: list[dict]) -> str:
    """Build the system prompt teaching the model the XML tool-call protocol."""
    tools_xml = _tool_definitions_xml(tools)
    header = (
        "You can call tools. Available tools:\n"
        "<tools>\n"
        f"{tools_xml}\n"
        "</tools>\n\n"
    )
    example = (
        "If you need tools, respond ONLY with this exact XML format:\n"
        "<function_calls>\n"
        " <function_call>\n"
        " <name>tool_name</name>\n"
        " <arguments>{\"key\":\"value\"}</arguments>\n"
        " </function_call>\n"
        "</function_calls>\n\n"
    )
    rules = (
        "Rules:\n"
        "1) arguments MUST be valid JSON object string.\n"
        "2) Multiple calls: include multiple <function_call> inside one <function_calls>.\n"
        "3) If no tool is needed, answer normally (no XML)."
    )
    return header + example + rules
|
|
|
|
def _inject_prompt_xml_system(messages: list[dict], tools: list[dict]) -> list[dict]:
    """Return a new message list with the tool-protocol system prompt prepended."""
    system_msg = {"role": "system", "content": _build_prompt_xml_instruction(tools)}
    return [system_msg] + list(messages)
|
|
|
|
| def _clean_xml_text(text: str) -> str: |
| s = text.strip() |
| if s.startswith("```"): |
| s = re.sub(r"^```(?:xml|json)?\s*", "", s, flags=re.IGNORECASE) |
| s = re.sub(r"\s*```$", "", s) |
| return html.unescape(s.strip()) |
|
|
|
|
| def _parse_prompt_xml_tool_calls(text: str) -> tuple[list[dict], str]: |
| """Return (tool_calls, cleaned_text_without_xml_block).""" |
| if not text: |
| return [], text |
|
|
| pattern = re.compile(r"<function_calls\b[^>]*>(.*?)</function_calls>", re.IGNORECASE | re.DOTALL) |
| matches = list(pattern.finditer(text)) |
| if not matches: |
| return [], text |
|
|
| last = matches[-1] |
| inner = last.group(1) |
| remaining = (text[: last.start()] + text[last.end() :]).strip() |
|
|
| call_pattern = re.compile(r"<function_call\b[^>]*>(.*?)</function_call>", re.IGNORECASE | re.DOTALL) |
| name_pattern = re.compile(r"<name\b[^>]*>(.*?)</name>", re.IGNORECASE | re.DOTALL) |
| args_pattern = re.compile(r"<arguments\b[^>]*>(.*?)</arguments>", re.IGNORECASE | re.DOTALL) |
|
|
| tool_calls: list[dict] = [] |
| for m in call_pattern.finditer(inner): |
| block = m.group(1) |
| name_m = name_pattern.search(block) |
| args_m = args_pattern.search(block) |
| if not name_m: |
| continue |
|
|
| name = _clean_xml_text(name_m.group(1)) |
| args_text = _clean_xml_text(args_m.group(1) if args_m else "{}") |
|
|
| |
| if not args_text: |
| args_text = "{}" |
| else: |
| try: |
| parsed = json.loads(args_text) |
| if isinstance(parsed, dict): |
| args_text = json.dumps(parsed, ensure_ascii=False) |
| else: |
| args_text = json.dumps({"value": parsed}, ensure_ascii=False) |
| except Exception: |
| args_text = json.dumps({"raw": args_text}, ensure_ascii=False) |
|
|
| tool_calls.append( |
| { |
| "id": f"call_{uuid.uuid4().hex[:24]}", |
| "type": "function", |
| "function": {"name": name, "arguments": args_text}, |
| } |
| ) |
|
|
| return tool_calls, remaining |
|
|
|
|
| |
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing backed by the upstream Zai models."""
    upstream = await pool.get_models()

    # Upstream may return either {"data": [...]} or a bare list.
    if isinstance(upstream, dict) and "data" in upstream:
        raw_models = upstream["data"]
    elif isinstance(upstream, list):
        raw_models = upstream
    else:
        raw_models = []

    data = [
        {
            "id": m.get("id") or m.get("name", "unknown"),
            "object": "model",
            "created": 0,
            "owned_by": "z.ai",
        }
        for m in raw_models
    ]
    return {"object": "list", "data": data}
|
|
|
|
| |
|
|
|
|
async def _do_request(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
):
    """Create a new chat and return (chat_id, async generator).

    Raises on Zai errors so the caller can retry.
    """
    chat = await pool.create_chat(prompt, model)
    new_chat_id = chat["id"]
    stream = pool.chat_completions(
        chat_id=new_chat_id,
        messages=messages,
        prompt=prompt,
        model=model,
        tools=tools,
    )
    return new_chat_id, stream
|
|
|
|
async def _stream_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
    *,
    toolify_mode: str = "off",
):
    """SSE generator with one retry on error.

    Yields OpenAI ``chat.completion.chunk`` events (``data: ...`` lines)
    terminated by a ``data: [DONE]`` sentinel. In ``prompt-xml`` mode with
    tools present, the whole answer is buffered so XML tool calls can be
    parsed out before anything beyond the role chunk is emitted. On any
    exception the guest auth is refreshed once and the request restarts.

    NOTE(review): a retry after chunks were already yielded re-emits the
    role chunk and earlier deltas — confirm clients tolerate this.
    """
    completion_id = _make_id()
    retried = False

    while True:
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)

            # Opening chunk carries only the assistant role, per OpenAI SSE shape.
            role_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(role_chunk, ensure_ascii=False)}\n\n"

            # Buffer deltas only when we may need to strip XML tool calls later.
            buffer_for_xml = toolify_mode == "prompt-xml" and bool(tools)
            tool_call_idx = 0
            reasoning_parts: list[str] = []
            content_parts: list[str] = []
            native_tool_calls: list[dict] = []

            async for data in gen:
                phase = data.get("phase", "")
                delta = data.get("delta_content", "")

                # Upstream-native tool calls are collected and take precedence.
                if data.get("tool_calls"):
                    for tc in data["tool_calls"]:
                        native_tool_calls.append(
                            {
                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                "type": "function",
                                "function": {
                                    "name": tc.get("function", {}).get("name", ""),
                                    "arguments": tc.get("function", {}).get("arguments", ""),
                                },
                            }
                        )
                    continue

                if phase == "thinking" and delta:
                    if buffer_for_xml:
                        reasoning_parts.append(delta)
                    else:
                        chunk = _openai_chunk(completion_id, model, reasoning_content=delta)
                        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "answer" and delta:
                    if buffer_for_xml:
                        content_parts.append(delta)
                    else:
                        chunk = _openai_chunk(completion_id, model, content=delta)
                        yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "done":
                    break

            # Native tool calls: emit them and finish with "tool_calls".
            if native_tool_calls:
                for tc in native_tool_calls:
                    tc_chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"tool_calls": [{"index": tool_call_idx, **tc}]},
                                "finish_reason": None,
                            }
                        ],
                    }
                    yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
                    tool_call_idx += 1

                finish_chunk = _openai_chunk(completion_id, model, finish_reason="tool_calls")
                yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Buffered (prompt-xml) path: flush reasoning first, then either
            # the parsed tool calls or the cleaned answer text.
            if buffer_for_xml:
                reasoning_text = "".join(reasoning_parts)
                content_text = "".join(content_parts)
                parsed_tool_calls, cleaned_content = _parse_prompt_xml_tool_calls(content_text)

                if reasoning_text:
                    chunk = _openai_chunk(completion_id, model, reasoning_content=reasoning_text)
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

                if parsed_tool_calls:
                    for tc in parsed_tool_calls:
                        tc_chunk = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"tool_calls": [{"index": tool_call_idx, **tc}]},
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
                        tool_call_idx += 1

                    finish_chunk = _openai_chunk(completion_id, model, finish_reason="tool_calls")
                    yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return

                if cleaned_content:
                    chunk = _openai_chunk(completion_id, model, content=cleaned_content)
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

                finish_chunk = _openai_chunk(completion_id, model, finish_reason="stop")
                yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Unbuffered path: deltas were streamed live, so just finish.
            finish_chunk = _openai_chunk(completion_id, model, finish_reason="stop")
            yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"
            return

        except Exception:
            # One retry after refreshing auth; a second failure becomes an
            # in-band SSE error object so the stream still terminates cleanly.
            if retried:
                error = {"error": {"message": "Upstream Zai error after retry", "type": "server_error"}}
                yield f"data: {json.dumps(error)}\n\n"
                yield "data: [DONE]\n\n"
                return
            retried = True
            await pool.refresh_auth()
|
|
|
|
async def _sync_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
    *,
    toolify_mode: str = "off",
) -> dict:
    """Non-streaming response with one retry on error.

    Drains the upstream generator, then returns (in priority order): a
    "tool_calls" completion for upstream-native tool calls, a
    "tool_calls" completion parsed from prompt-xml output, or a plain
    "stop" completion. Failures refresh auth once and retry; a second
    failure returns an ``{"error": ...}`` dict the caller maps to 502.
    """
    completion_id = _make_id()

    for attempt in range(2):
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)

            content_parts: list[str] = []
            reasoning_parts: list[str] = []
            tool_calls: list[dict] = []

            async for data in gen:
                phase = data.get("phase", "")
                delta = data.get("delta_content", "")

                # Upstream-native tool calls are collected verbatim.
                if data.get("tool_calls"):
                    for tc in data["tool_calls"]:
                        tool_calls.append(
                            {
                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                "type": "function",
                                "function": {
                                    "name": tc.get("function", {}).get("name", ""),
                                    "arguments": tc.get("function", {}).get("arguments", ""),
                                },
                            }
                        )
                elif phase == "thinking" and delta:
                    reasoning_parts.append(delta)
                elif phase == "answer" and delta:
                    content_parts.append(delta)
                elif phase == "done":
                    break

            # Native tool calls win: content is None, finish is "tool_calls".
            if tool_calls:
                message: dict = {"role": "assistant", "content": None, "tool_calls": tool_calls}
                if reasoning_parts:
                    message["reasoning_content"] = "".join(reasoning_parts)
                return {
                    "id": completion_id,
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                }

            # prompt-xml mode: try to lift tool calls out of the answer text.
            answer_content = "".join(content_parts)
            if toolify_mode == "prompt-xml" and tools:
                parsed_tool_calls, cleaned_content = _parse_prompt_xml_tool_calls(answer_content)
                if parsed_tool_calls:
                    message: dict = {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": parsed_tool_calls,
                    }
                    if reasoning_parts:
                        message["reasoning_content"] = "".join(reasoning_parts)
                    return {
                        "id": completion_id,
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [{"index": 0, "message": message, "finish_reason": "tool_calls"}],
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
                    }
                # No calls found: serve the answer with the XML block removed.
                answer_content = cleaned_content

            return _openai_completion(
                completion_id,
                model,
                answer_content,
                "".join(reasoning_parts),
            )

        except Exception:
            # First failure: refresh guest auth and retry; second: give up.
            if attempt == 0:
                await pool.refresh_auth()
                continue
            return {"error": {"message": "Upstream Zai error after retry", "type": "server_error"}}

    # Unreachable in practice (loop always returns), kept as a safety net.
    return {"error": {"message": "Unexpected error", "type": "server_error"}}
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint (streaming and sync).

    Flattens the conversation into a single Zai prompt, optionally
    injecting the prompt-xml tool protocol, then dispatches to the
    streaming or synchronous helper.
    """
    body = await request.json()

    model: str = body.get("model", "glm-5")
    messages: list[dict] = body.get("messages", [])
    stream: bool = body.get("stream", False)
    tools: list[dict] | None = body.get("tools")

    # Resolve the toolify mode: explicit "toolify_mode" wins; otherwise
    # default to prompt-xml whenever tools are supplied.
    toolify_mode_raw = body.get("toolify_mode")
    if isinstance(toolify_mode_raw, str) and toolify_mode_raw in {"off", "prompt-xml"}:
        toolify_mode = toolify_mode_raw
    elif tools:
        toolify_mode = "prompt-xml"
    else:
        toolify_mode = "off"

    # A boolean "toolify" (top-level or under "features") overrides the mode.
    # Guard against a non-dict "features" value: the previous
    # body.get("features", {}).get("toolify") raised AttributeError
    # (HTTP 500) on payloads like {"features": true}.
    features = body.get("features")
    features_toolify = features.get("toolify") if isinstance(features, dict) else None
    toolify_raw = body.get("toolify", features_toolify)
    if isinstance(toolify_raw, bool):
        toolify_mode = "prompt-xml" if toolify_raw else "off"

    # The upstream chat is seeded with the most recent user message.
    prompt = ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            prompt = _extract_text_from_content(msg.get("content", ""))
            break

    if not prompt:
        return JSONResponse(
            status_code=400,
            content={
                "error": {
                    "message": "No user message found in messages",
                    "type": "invalid_request_error",
                }
            },
        )

    model_messages = messages
    if toolify_mode == "prompt-xml" and tools:
        model_messages = _inject_prompt_xml_system(messages, tools)

    # Zai expects the whole conversation as one flattened user message.
    flat_messages = _flatten_messages_for_zai(model_messages)

    if stream:
        return StreamingResponse(
            _stream_response(flat_messages, model, prompt, tools, toolify_mode=toolify_mode),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    result = await _sync_response(flat_messages, model, prompt, tools, toolify_mode=toolify_mode)
    if "error" in result:
        return JSONResponse(status_code=502, content=result)
    return result
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Direct-run entry point: binds on all interfaces at port 30016.
    uvicorn.run(app, host="0.0.0.0", port=30016)
|
|