| """OpenAI-compatible proxy server for chat.z.ai.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import os |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
|
|
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| from main import ZaiClient |
|
|
| |
|
|
|
|
| def _env_float(name: str, default: float) -> float: |
| raw = os.getenv(name) |
| if raw is None: |
| return default |
| try: |
| return float(raw) |
| except ValueError: |
| return default |
|
|
|
|
# Minimum seconds between guest-token refreshes; throttles refresh stampedes
# when many concurrent requests fail at once. Overridable via the
# ZAI_AUTH_REFRESH_MIN_INTERVAL_SECONDS environment variable.
AUTH_REFRESH_MIN_INTERVAL_SECONDS = _env_float(
    "ZAI_AUTH_REFRESH_MIN_INTERVAL_SECONDS", 2.0
)
|
|
|
|
class SessionPool:
    """Owns one shared ZaiClient and keeps its guest authentication fresh.

    All auth transitions run under an asyncio lock, so concurrent requests
    trigger at most one upstream authentication (single-flight), and
    refreshes are additionally throttled by a minimum interval.
    """

    def __init__(self) -> None:
        self._client = ZaiClient()
        self._lock = asyncio.Lock()
        self._authed = False
        self._last_auth_refresh_at = 0.0
        # Clamp so a negative env override cannot disable throttling oddly.
        self._refresh_min_interval = max(0.0, AUTH_REFRESH_MIN_INTERVAL_SECONDS)

    async def close(self) -> None:
        """Release the underlying client and its connections."""
        await self._client.close()

    def _recently_refreshed(self, now: float) -> bool:
        # True while authenticated and still inside the throttle window.
        return (
            self._authed
            and now - self._last_auth_refresh_at < self._refresh_min_interval
        )

    async def _authenticate(self) -> None:
        # Caller must hold self._lock.
        await self._client.auth_as_guest()
        self._authed = True
        self._last_auth_refresh_at = time.monotonic()

    async def ensure_auth(self) -> None:
        """Authenticate if not already done."""
        if self._authed:
            return
        async with self._lock:
            # Re-check under the lock: another task may have won the race.
            if not self._authed:
                await self._authenticate()

    async def refresh_auth(self, *, force: bool = False) -> None:
        """Refresh the guest token with single-flight behavior.

        Unless ``force`` is set, a refresh inside the minimum interval
        since the last one is silently skipped.
        """
        if not force and self._recently_refreshed(time.monotonic()):
            return
        async with self._lock:
            # Double-check after acquiring the lock (single-flight).
            if not force and self._recently_refreshed(time.monotonic()):
                return
            await self._authenticate()

    async def get_models(self) -> list | dict:
        """Fetch the upstream model list, authenticating first if needed."""
        await self.ensure_auth()
        return await self._client.get_models()

    async def create_chat(self, user_message: str, model: str) -> dict:
        """Create a new upstream chat session, authenticating first if needed."""
        await self.ensure_auth()
        return await self._client.create_chat(user_message, model)

    def chat_completions(
        self,
        chat_id: str,
        messages: list[dict],
        prompt: str,
        *,
        model: str,
        tools: list[dict] | None = None,
    ):
        """Pass straight through to the client's streaming completion call."""
        return self._client.chat_completions(
            chat_id=chat_id,
            messages=messages,
            prompt=prompt,
            model=model,
            tools=tools,
        )
|
|
|
|
# Single process-wide session shared by all request handlers.
pool = SessionPool()
|
|
| |
|
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan: authenticate once at startup, close the client on shutdown."""
    await pool.ensure_auth()
    yield
    await pool.close()
|
|
|
|
# Application instance; lifespan handles startup auth and shutdown cleanup.
app = FastAPI(lifespan=lifespan)
|
|
| |
|
|
|
|
| def _make_id() -> str: |
| return f"chatcmpl-{uuid.uuid4().hex[:29]}" |
|
|
|
|
| def _openai_chunk( |
| completion_id: str, |
| model: str, |
| *, |
| content: str | None = None, |
| reasoning_content: str | None = None, |
| finish_reason: str | None = None, |
| ) -> dict: |
| delta: dict = {} |
| if content is not None: |
| delta["content"] = content |
| if reasoning_content is not None: |
| delta["reasoning_content"] = reasoning_content |
| return { |
| "id": completion_id, |
| "object": "chat.completion.chunk", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [ |
| { |
| "index": 0, |
| "delta": delta, |
| "finish_reason": finish_reason, |
| } |
| ], |
| } |
|
|
|
|
| def _openai_completion( |
| completion_id: str, |
| model: str, |
| content: str, |
| reasoning_content: str, |
| ) -> dict: |
| message: dict = {"role": "assistant", "content": content} |
| if reasoning_content: |
| message["reasoning_content"] = reasoning_content |
| return { |
| "id": completion_id, |
| "object": "chat.completion", |
| "created": int(time.time()), |
| "model": model, |
| "choices": [ |
| { |
| "index": 0, |
| "message": message, |
| "finish_reason": "stop", |
| } |
| ], |
| "usage": { |
| "prompt_tokens": 0, |
| "completion_tokens": 0, |
| "total_tokens": 0, |
| }, |
| } |
|
|
|
|
| |
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Expose the upstream model catalog in OpenAI ``/v1/models`` format."""
    upstream = await pool.get_models()

    # Upstream may answer either {"data": [...]} or a bare list.
    if isinstance(upstream, dict) and "data" in upstream:
        raw_models = upstream["data"]
    elif isinstance(upstream, list):
        raw_models = upstream
    else:
        raw_models = []

    data = [
        {
            "id": m.get("id") or m.get("name", "unknown"),
            "object": "model",
            "created": 0,
            "owned_by": "z.ai",
        }
        for m in raw_models
    ]
    return {"object": "list", "data": data}
|
|
|
|
| |
|
|
|
|
async def _do_request(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
):
    """Create a new chat and return (chat_id, async generator).

    Raises on Zai errors so the caller can retry.
    """
    chat_meta = await pool.create_chat(prompt, model)
    stream = pool.chat_completions(
        chat_id=chat_meta["id"],
        messages=messages,
        prompt=prompt,
        model=model,
        tools=tools,
    )
    return chat_meta["id"], stream
|
|
|
|
async def _stream_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
):
    """SSE generator with one retry on error.

    Yields OpenAI-compatible ``data: ...`` server-sent-event lines: a role
    delta first, then content / reasoning / tool-call deltas as the
    upstream produces them, then a finish chunk and the ``[DONE]``
    sentinel.  Any exception triggers one full retry after an auth
    refresh; NOTE(review): a failure mid-stream restarts the request from
    scratch, so a client may see duplicated deltas in that case.
    """
    completion_id = _make_id()
    retried = False

    while True:
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)

            # First chunk carries only the assistant role, per OpenAI's
            # streaming convention (built inline because _openai_chunk
            # has no role field).
            role_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant"},
                        "finish_reason": None,
                    }
                ],
            }
            yield f"data: {json.dumps(role_chunk, ensure_ascii=False)}\n\n"

            tool_call_idx = 0
            async for data in gen:
                phase = data.get("phase", "")
                delta = data.get("delta_content", "")

                # Tool calls take precedence over the text phases.
                if data.get("tool_calls"):
                    for tc in data["tool_calls"]:
                        tc_chunk = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {
                                        "tool_calls": [
                                            {
                                                "index": tool_call_idx,
                                                # Synthesize an id when upstream omits one.
                                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                                "type": "function",
                                                "function": {
                                                    "name": tc.get("function", {}).get("name", ""),
                                                    "arguments": tc.get("function", {}).get("arguments", ""),
                                                },
                                            }
                                        ]
                                    },
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n"
                        tool_call_idx += 1
                elif phase == "thinking" and delta:
                    # Reasoning tokens surface via the non-standard
                    # ``reasoning_content`` delta field.
                    chunk = _openai_chunk(
                        completion_id, model, reasoning_content=delta
                    )
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "answer" and delta:
                    chunk = _openai_chunk(completion_id, model, content=delta)
                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                elif phase == "done":
                    break

            # Report "tool_calls" as the finish reason if any were emitted.
            finish_reason = "tool_calls" if tool_call_idx > 0 else "stop"
            finish_chunk = _openai_chunk(
                completion_id, model, finish_reason=finish_reason
            )
            yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"
            return

        except Exception:
            if retried:
                # Second failure: report the error in-band (the SSE status
                # code was already sent) and terminate the stream.
                error = {
                    "error": {
                        "message": "Upstream Zai error after retry",
                        "type": "server_error",
                    }
                }
                yield f"data: {json.dumps(error)}\n\n"
                yield "data: [DONE]\n\n"
                return
            retried = True
            # NOTE(review): not forced, so the refresh may be skipped by the
            # min-interval throttle — presumably deliberate to avoid refresh
            # stampedes; confirm intent.
            await pool.refresh_auth()
| |
|
|
|
|
async def _sync_response(
    messages: list[dict],
    model: str,
    prompt: str,
    tools: list[dict] | None = None,
) -> dict:
    """Non-streaming response with one retry on error.

    Drains the upstream stream, accumulating answer text, reasoning text
    and tool calls, then returns a complete OpenAI chat-completion dict.
    On a second consecutive failure an ``{"error": ...}`` dict is returned
    for the caller to map onto an HTTP status.
    """
    completion_id = _make_id()

    for attempt in range(2):
        try:
            _chat_id, gen = await _do_request(messages, model, prompt, tools)

            answer_chunks: list[str] = []
            thinking_chunks: list[str] = []
            calls: list[dict] = []

            async for event in gen:
                phase = event.get("phase", "")
                piece = event.get("delta_content", "")

                # Tool calls take precedence over the text phases.
                if event.get("tool_calls"):
                    for tc in event["tool_calls"]:
                        fn = tc.get("function", {})
                        calls.append(
                            {
                                # Synthesize an id when upstream omits one.
                                "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"),
                                "type": "function",
                                "function": {
                                    "name": fn.get("name", ""),
                                    "arguments": fn.get("arguments", ""),
                                },
                            }
                        )
                elif phase == "thinking" and piece:
                    thinking_chunks.append(piece)
                elif phase == "answer" and piece:
                    answer_chunks.append(piece)
                elif phase == "done":
                    break

            if calls:
                # Tool-call responses carry a null content field, per the
                # OpenAI chat-completions schema.
                message: dict = {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": calls,
                }
                if thinking_chunks:
                    message["reasoning_content"] = "".join(thinking_chunks)
                return {
                    "id": completion_id,
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "message": message,
                            "finish_reason": "tool_calls",
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": 0,
                    },
                }

            return _openai_completion(
                completion_id,
                model,
                "".join(answer_chunks),
                "".join(thinking_chunks),
            )

        except Exception:
            if attempt == 0:
                # Refresh auth once and retry the whole request.
                await pool.refresh_auth()
                continue
            return {
                "error": {
                    "message": "Upstream Zai error after retry",
                    "type": "server_error",
                }
            }

    # Unreachable in practice; kept as a defensive fallback.
    return {"error": {"message": "Unexpected error", "type": "server_error"}}
|
|
|
|
def _content_to_text(content) -> str:
    """Flatten an OpenAI message ``content`` field to plain text.

    Handles both the string form and the multimodal list-of-parts form
    (only ``{"type": "text"}`` parts contribute; image parts are dropped).
    Any other value (e.g. None) becomes the empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return ""


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint.

    Extracts the latest user message as the upstream prompt, flattens the
    whole conversation into a single pseudo-XML-tagged user message (the
    upstream API accepts one message), then dispatches to the streaming
    or non-streaming path.  Returns 400 when no user prompt is present and
    502 when the upstream fails after a retry.
    """
    body = await request.json()

    model: str = body.get("model", "glm-5")
    messages: list[dict] = body.get("messages", [])
    stream: bool = body.get("stream", False)
    tools: list[dict] | None = body.get("tools")

    # The most recent user message becomes the upstream "prompt".
    prompt = ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            prompt = _content_to_text(msg.get("content", ""))
            break

    # Flatten the full conversation into one tagged user message.
    # Bug fix: list-form (multimodal) content used to be embedded via its
    # Python repr; its text parts are now extracted just like the prompt's.
    parts: list[str] = []
    for msg in messages:
        role = msg.get("role", "user").upper()
        text = _content_to_text(msg.get("content", ""))
        parts.append(f"<{role}>{text}</{role}>")
    messages = [{"role": "user", "content": "\n".join(parts)}]

    if not prompt:
        return JSONResponse(
            status_code=400,
            content={
                "error": {
                    "message": "No user message found in messages",
                    "type": "invalid_request_error",
                }
            },
        )

    if stream:
        return StreamingResponse(
            _stream_response(messages, model, prompt, tools),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable nginx proxy buffering so chunks flush immediately.
                "X-Accel-Buffering": "no",
            },
        )
    result = await _sync_response(messages, model, prompt, tools)
    if "error" in result:
        return JSONResponse(status_code=502, content=result)
    return result
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Direct invocation: serve the proxy on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|