Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
| from fastapi.testclient import TestClient | |
| from zai2api.config import Settings | |
| from zai2api.db import Database | |
| from zai2api.server import create_app | |
| from zai2api.zai_client import UpstreamChunk | |
class ErroringStreamPool:
    """Stand-in prompt pool whose streaming path always fails.

    ``mode`` picks how the failure shows up:

    * ``"chunk_error"`` -- ``stream_prompt`` yields exactly one
      ``UpstreamChunk`` whose ``error`` field is set, then finishes.
    * any other value -- ``stream_prompt`` raises ``RuntimeError`` before
      yielding anything.

    ``collect_prompt`` always raises, so a test fails loudly if the server
    takes the non-streaming path.
    """

    def __init__(self, *, mode: str):
        self.mode = mode

    async def collect_prompt(self, **_: object):
        # Streaming tests must never reach the buffered (non-streaming) path.
        raise AssertionError("collect_prompt should not be used in streaming tests")

    async def stream_prompt(self, **_: object):
        # The ``yield`` below makes this an async generator, so the raise
        # fires on the first iteration rather than at call time.
        if self.mode != "chunk_error":
            raise RuntimeError("upstream stream exploded")
        yield UpstreamChunk(phase=None, text="", error="upstream said no")
def make_settings(tmp_path: Path) -> Settings:
    """Return a ``Settings`` instance isolated to one test run.

    Upstream credentials (JWT, session token) and the panel/API password
    environment names are left unset. The database file lives under
    ``tmp_path``, so every test works against its own fresh state file.
    """
    options = {
        "host": "127.0.0.1",
        "port": 8000,
        "log_level": "info",
        "zai_base_url": "https://chat.z.ai",
        "zai_jwt": None,
        "zai_session_token": None,
        "default_model": "glm-5",
        "request_timeout": 120.0,
        "database_path": str(tmp_path / "state.db"),
        "panel_password_env": None,
        "api_password_env": None,
        "admin_cookie_name": "zai2api_admin_session",
        "admin_session_ttl_hours": 24,
        "admin_cookie_secure": False,
    }
    return Settings(**options)
def test_chat_completion_stream_reports_error_in_band(tmp_path: Path) -> None:
    """An upstream error chunk is reported inside the SSE stream.

    The HTTP status stays 200. The stream carries the upstream error type and
    message and still ends with ``data: [DONE]``. The failure must also appear
    among the log entries returned by ``Database.list_logs``.
    """
    settings = make_settings(tmp_path)
    pool = ErroringStreamPool(mode="chunk_error")
    app = create_app(settings, prompt_pool=pool)
    payload = {
        "model": "glm-5",
        "stream": True,
        "messages": [{"role": "user", "content": "hi"}],
    }

    with TestClient(app) as client:
        response = client.post("/v1/chat/completions", json=payload)

    assert response.status_code == 200
    body = response.text
    for fragment in (
        '"type": "upstream_error"',
        '"message": "upstream said no"',
        "data: [DONE]",
    ):
        assert fragment in body

    entries = Database(settings.database_path).list_logs(limit=20)
    messages = [entry.message for entry in entries]
    assert "Streaming chat completion request failed" in messages
def test_responses_stream_reports_runtime_error_in_band(tmp_path: Path) -> None:
    """A raised upstream exception becomes an in-band ``response.failed`` event.

    The HTTP status stays 200. The stream reports a failed status with the
    exception text and still ends with ``data: [DONE]``. The failure must also
    appear among the log entries returned by ``Database.list_logs``.
    """
    settings = make_settings(tmp_path)
    pool = ErroringStreamPool(mode="runtime_error")
    app = create_app(settings, prompt_pool=pool)
    payload = {"model": "glm-5", "stream": True, "input": "hi"}

    with TestClient(app) as client:
        response = client.post("/v1/responses", json=payload)

    assert response.status_code == 200
    body = response.text
    for fragment in (
        '"type": "response.failed"',
        '"status": "failed"',
        '"message": "upstream stream exploded"',
        "data: [DONE]",
    ):
        assert fragment in body

    entries = Database(settings.database_path).list_logs(limit=20)
    messages = [entry.message for entry in entries]
    assert "Streaming responses request failed" in messages