"""OpenAI-compatible proxy server for chat.z.ai.""" from __future__ import annotations import asyncio import json import os import time import uuid from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse from main import ZaiClient # ── Session Pool ───────────────────────────────────────────────────── def _env_float(name: str, default: float) -> float: raw = os.getenv(name) if raw is None: return default try: return float(raw) except ValueError: return default AUTH_REFRESH_MIN_INTERVAL_SECONDS = _env_float( "ZAI_AUTH_REFRESH_MIN_INTERVAL_SECONDS", 2.0 ) class SessionPool: """Manages a single ZaiClient instance with automatic auth refresh.""" def __init__(self) -> None: self._client = ZaiClient() self._lock = asyncio.Lock() self._authed = False self._last_auth_refresh_at = 0.0 self._refresh_min_interval = max(0.0, AUTH_REFRESH_MIN_INTERVAL_SECONDS) async def close(self) -> None: await self._client.close() async def ensure_auth(self) -> None: """Authenticate if not already done.""" if self._authed: return async with self._lock: if self._authed: return await self._client.auth_as_guest() self._authed = True self._last_auth_refresh_at = time.monotonic() async def refresh_auth(self, *, force: bool = False) -> None: """Refresh the guest token with single-flight behavior.""" now = time.monotonic() if ( not force and self._authed and now - self._last_auth_refresh_at < self._refresh_min_interval ): return async with self._lock: now = time.monotonic() if ( not force and self._authed and now - self._last_auth_refresh_at < self._refresh_min_interval ): return await self._client.auth_as_guest() self._authed = True self._last_auth_refresh_at = time.monotonic() async def get_models(self) -> list | dict: await self.ensure_auth() return await self._client.get_models() async def create_chat(self, user_message: str, model: str) -> dict: await self.ensure_auth() return await self._client.create_chat(user_message, model) def chat_completions( self, chat_id: str, messages: list[dict], prompt: str, *, model: str, tools: list[dict] | None = None, ): return self._client.chat_completions( chat_id=chat_id, messages=messages, prompt=prompt, model=model, tools=tools, ) pool = SessionPool() # ── FastAPI app ────────────────────────────────────────────────────── @asynccontextmanager async def lifespan(_app: FastAPI): await pool.ensure_auth() yield await pool.close() app = FastAPI(lifespan=lifespan) # ── Helpers ────────────────────────────────────────────────────────── def _make_id() -> str: return f"chatcmpl-{uuid.uuid4().hex[:29]}" def _openai_chunk( completion_id: str, model: str, *, content: str | None = None, reasoning_content: str | None = None, finish_reason: str | None = None, ) -> dict: delta: dict = {} if content is not None: delta["content"] = content if reasoning_content is not None: delta["reasoning_content"] = reasoning_content return { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [ { "index": 0, "delta": delta, "finish_reason": finish_reason, } ], } def _openai_completion( completion_id: str, model: str, content: str, reasoning_content: str, ) -> dict: message: dict = {"role": "assistant", "content": content} if reasoning_content: message["reasoning_content"] = reasoning_content return { "id": completion_id, "object": "chat.completion", "created": int(time.time()), "model": model, "choices": [ { "index": 0, "message": message, "finish_reason": "stop", } ], "usage": { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, }, } # ── /v1/models ─────────────────────────────────────────────────────── @app.get("/v1/models") async def list_models(): models_resp = await pool.get_models() # Normalize to list if isinstance(models_resp, dict) and "data" in models_resp: models_list = models_resp["data"] elif isinstance(models_resp, list): models_list = models_resp else: models_list = [] data = [] for m in models_list: mid = m.get("id") or m.get("name", "unknown") data.append( { "id": mid, "object": "model", "created": 0, "owned_by": "z.ai", } ) return {"object": "list", "data": data} # ── /v1/chat/completions ──────────────────────────────────────────── async def _do_request( messages: list[dict], model: str, prompt: str, tools: list[dict] | None = None, ): """Create a new chat and return (chat_id, async generator). Raises on Zai errors so the caller can retry. """ chat = await pool.create_chat(prompt, model) chat_id = chat["id"] gen = pool.chat_completions( chat_id=chat_id, messages=messages, prompt=prompt, model=model, tools=tools, ) return chat_id, gen async def _stream_response( messages: list[dict], model: str, prompt: str, tools: list[dict] | None = None, ): """SSE generator with one retry on error.""" completion_id = _make_id() retried = False while True: try: _chat_id, gen = await _do_request(messages, model, prompt, tools) # Send initial role chunk role_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [ { "index": 0, "delta": {"role": "assistant"}, "finish_reason": None, } ], } yield f"data: {json.dumps(role_chunk, ensure_ascii=False)}\n\n" tool_call_idx = 0 async for data in gen: phase = data.get("phase", "") delta = data.get("delta_content", "") # Tool call events from Zai if data.get("tool_calls"): for tc in data["tool_calls"]: tc_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [ { "index": 0, "delta": { "tool_calls": [ { "index": tool_call_idx, "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"), "type": "function", "function": { "name": tc.get("function", {}).get("name", ""), "arguments": tc.get("function", {}).get("arguments", ""), }, } ] }, "finish_reason": None, } ], } yield f"data: {json.dumps(tc_chunk, ensure_ascii=False)}\n\n" tool_call_idx += 1 elif phase == "thinking" and delta: chunk = _openai_chunk( completion_id, model, reasoning_content=delta ) yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" elif phase == "answer" and delta: chunk = _openai_chunk(completion_id, model, content=delta) yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" elif phase == "done": break # Send finish chunk finish_reason = "tool_calls" if tool_call_idx > 0 else "stop" finish_chunk = _openai_chunk( completion_id, model, finish_reason=finish_reason ) yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" return except Exception: if retried: # Already retried once — yield error and stop error = { "error": { "message": "Upstream Zai error after retry", "type": "server_error", } } yield f"data: {json.dumps(error)}\n\n" yield "data: [DONE]\n\n" return retried = True await pool.refresh_auth() # Loop back and retry async def _sync_response( messages: list[dict], model: str, prompt: str, tools: list[dict] | None = None, ) -> dict: """Non-streaming response with one retry on error.""" completion_id = _make_id() for attempt in range(2): try: _chat_id, gen = await _do_request(messages, model, prompt, tools) content_parts: list[str] = [] reasoning_parts: list[str] = [] tool_calls: list[dict] = [] async for data in gen: phase = data.get("phase", "") delta = data.get("delta_content", "") if data.get("tool_calls"): for tc in data["tool_calls"]: tool_calls.append( { "id": tc.get("id", f"call_{uuid.uuid4().hex[:24]}"), "type": "function", "function": { "name": tc.get("function", {}).get("name", ""), "arguments": tc.get("function", {}).get("arguments", ""), }, } ) elif phase == "thinking" and delta: reasoning_parts.append(delta) elif phase == "answer" and delta: content_parts.append(delta) elif phase == "done": break if tool_calls: message: dict = {"role": "assistant", "content": None, "tool_calls": tool_calls} if reasoning_parts: message["reasoning_content"] = "".join(reasoning_parts) return { "id": completion_id, "object": "chat.completion", "created": int(time.time()), "model": model, "choices": [ { "index": 0, "message": message, "finish_reason": "tool_calls", } ], "usage": { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, }, } return _openai_completion( completion_id, model, "".join(content_parts), "".join(reasoning_parts), ) except Exception: if attempt == 0: await pool.refresh_auth() continue return { "error": { "message": "Upstream Zai error after retry", "type": "server_error", } } # Unreachable, but satisfy type checker return {"error": {"message": "Unexpected error", "type": "server_error"}} @app.post("/v1/chat/completions") async def chat_completions(request: Request): body = await request.json() model: str = body.get("model", "glm-5") messages: list[dict] = body.get("messages", []) stream: bool = body.get("stream", False) tools: list[dict] | None = body.get("tools") # Extract the last user message as the prompt for signature prompt = "" for msg in reversed(messages): if msg.get("role") == "user": content = msg.get("content", "") if isinstance(content, str): prompt = content elif isinstance(content, list): # Handle multimodal content array prompt = " ".join( p.get("text", "") for p in content if isinstance(p, dict) and p.get("type") == "text" ) break # Zai ignores multi-turn context — flatten all messages into a single # user message with tags so the model sees the full conversation. parts: list[str] = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") or "" parts.append(f"<{role.upper()}>{content}") flat_content = "\n".join(parts) messages = [{"role": "user", "content": flat_content}] if not prompt: return JSONResponse( status_code=400, content={ "error": { "message": "No user message found in messages", "type": "invalid_request_error", } }, ) if stream: return StreamingResponse( _stream_response(messages, model, prompt, tools), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) else: result = await _sync_response(messages, model, prompt, tools) if "error" in result: return JSONResponse(status_code=502, content=result) return result # ── Entry point ────────────────────────────────────────────────────── if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)