Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import asyncio | |
| from pathlib import Path | |
| import httpx | |
| from zai2api.account_pool import AccountPool | |
| from zai2api.config import Settings | |
| from zai2api.db import Database | |
| from zai2api.zai_client import SessionState, UpstreamResult | |
class FakeClient:
    """Deterministic stand-in for the upstream chat client used by AccountPool.

    Each instance represents one account identified by ``name``. A successful
    ``collect_prompt`` echoes ``"<answer>:<prompt>"``. When ``fail_status`` is
    set, ``collect_prompt`` instead raises ``httpx.HTTPStatusError`` carrying
    that HTTP status code.
    """

    def __init__(self, *, name: str, answer: str, fail_status: int | None = None):
        self._name, self._answer, self._fail_status = name, answer, fail_status

    async def ensure_session(self, force_refresh: bool = False) -> SessionState:
        """Return a session whose fields are all derived from the account name.

        ``force_refresh`` is accepted for interface compatibility but ignored.
        """
        label = self._name
        return SessionState(
            token=f"session-{label}",
            user_id=f"user-{label}",
            name=label,
            email=f"{label}@example.com",
            role="user",
        )

    async def verify_completion_version(self) -> int:
        """Always report completion API version 2."""
        return 2

    async def collect_prompt(
        self,
        *,
        prompt: str,
        model: str,
        enable_thinking: bool,
        auto_web_search: bool,
    ) -> UpstreamResult:
        """Echo the prompt, or raise the configured HTTP failure.

        Raises:
            httpx.HTTPStatusError: when the client was built with ``fail_status``.
        """
        if self._fail_status is None:
            return UpstreamResult(
                answer_text=f"{self._answer}:{prompt}",
                reasoning_text="reasoning",
                usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
                finish_reason="stop",
            )
        fake_request = httpx.Request("POST", "https://example.com")
        raise httpx.HTTPStatusError(
            "boom",
            request=fake_request,
            response=httpx.Response(self._fail_status, request=fake_request),
        )

    async def stream_prompt(self, **_: object):
        """Async generator that finishes immediately without yielding."""
        return
        # Unreachable; its presence makes this an async generator function.
        yield

    async def aclose(self) -> None:
        """Nothing to release for the fake client."""
        return None
def make_settings(tmp_path: Path, **overrides: object) -> Settings:
    """Build a test ``Settings`` whose database file lives under *tmp_path*.

    No upstream credentials or panel/API passwords are configured. The admin
    cookie is marked non-secure.

    Args:
        tmp_path: Per-test temporary directory. The database path is
            ``tmp_path / "state.db"``.
        **overrides: Attribute values assigned onto the constructed settings
            object, e.g. ``default_model="glm-4"``.

    Returns:
        The settings object with every override applied.

    Raises:
        AttributeError: If an override names an attribute the settings object
            does not have. This prevents a misspelled key from silently creating
            a stray attribute that nothing reads.
    """
    settings = Settings(
        host="127.0.0.1",
        port=8000,
        log_level="info",
        zai_base_url="https://chat.z.ai",
        zai_jwt=None,
        zai_session_token=None,
        default_model="glm-5",
        request_timeout=120.0,
        database_path=str(tmp_path / "state.db"),
        panel_password_env=None,
        api_password_env=None,
        admin_cookie_name="zai2api_admin_session",
        admin_session_ttl_hours=24,
        admin_cookie_secure=False,
    )
    for key, value in overrides.items():
        # Fail fast on typos: a bare setattr would happily add a new attribute
        # and the test would silently run with the default value instead.
        if not hasattr(settings, key):
            raise AttributeError(f"Settings has no attribute {key!r}; cannot override it")
        setattr(settings, key, value)
    return settings
def test_account_pool_rotates_enabled_accounts(tmp_path: Path) -> None:
    """Consecutive requests are served by different enabled accounts in turn.

    Two healthy accounts are stored. The first request is expected to be
    answered by account "a" and the second by account "b". Both accounts must
    still have status ``"active"`` afterwards.
    """
    settings = make_settings(tmp_path)
    db = Database(settings.database_path)
    db.initialize()
    first = db.upsert_account(
        jwt="jwt-a",
        session_token="token-a",
        user_id="user-a",
        email="a@example.com",
        name="a",
    )
    second = db.upsert_account(
        jwt="jwt-b",
        session_token="token-b",
        user_id="user-b",
        email="b@example.com",
        name="b",
    )

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # Each stored session token maps to a fake client with a distinct answer,
        # so the response text reveals which account served the request.
        if session_token == "token-a":
            return FakeClient(name="a", answer="alpha")
        if session_token == "token-b":
            return FakeClient(name="b", answer="beta")
        raise AssertionError(f"unexpected session token: {session_token}")

    pool = AccountPool(settings, db, client_factory=client_factory)

    def ask() -> UpstreamResult:
        # Issue the same request through the pool, one event loop per call.
        return asyncio.run(
            pool.collect_prompt(
                prompt="hello",
                model="glm-5",
                enable_thinking=True,
                auto_web_search=False,
            )
        )

    first_result = ask()
    second_result = ask()

    assert first_result.answer_text == "alpha:hello"
    assert second_result.answer_text == "beta:hello"
    # NOTE(review): get_account presumably returns None for an unknown id. Assert
    # presence first so a missing row fails as an assertion, not as an
    # AttributeError on ``None.status``.
    first_account = db.get_account(first.id)
    second_account = db.get_account(second.id)
    assert first_account is not None
    assert second_account is not None
    assert first_account.status == "active"
    assert second_account.status == "active"
def test_account_pool_disables_unauthorized_account(tmp_path: Path) -> None:
    """A 401 from upstream disables that account and the request is still served.

    The "bad" account is stored first and its client always fails with HTTP 401.
    The pool must answer using the "good" account. The bad account must end up
    disabled with status ``"invalid"``.
    """
    settings = make_settings(tmp_path)
    db = Database(settings.database_path)
    db.initialize()

    rejected = db.upsert_account(
        jwt="jwt-bad",
        session_token="token-bad",
        user_id="user-bad",
        email="bad@example.com",
        name="bad",
    )
    db.upsert_account(
        jwt="jwt-good",
        session_token="token-good",
        user_id="user-good",
        email="good@example.com",
        name="good",
    )

    # Dispatch table: session token -> builder for the matching fake client.
    builders = {
        "token-bad": lambda: FakeClient(name="bad", answer="bad", fail_status=401),
        "token-good": lambda: FakeClient(name="good", answer="good"),
    }

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        build = builders.get(session_token)
        if build is None:
            raise AssertionError(f"unexpected session token: {session_token}")
        return build()

    pool = AccountPool(settings, db, client_factory=client_factory)
    pending = pool.collect_prompt(
        prompt="hello",
        model="glm-5",
        enable_thinking=True,
        auto_web_search=False,
    )
    result = asyncio.run(pending)

    reloaded = db.get_account(rejected.id)
    assert result.answer_text == "good:hello"
    assert reloaded is not None
    assert reloaded.enabled is False
    assert reloaded.status == "invalid"
def test_register_jwt_persists_account(tmp_path: Path) -> None:
    """Registering a raw JWT stores exactly one account built from its session."""
    settings = make_settings(tmp_path)
    db = Database(settings.database_path)
    db.initialize()
    raw_jwt = "fresh-jwt"

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # Only the JWT handed to register_jwt should ever reach the factory.
        assert jwt == raw_jwt
        return FakeClient(name="fresh", answer="fresh")

    pool = AccountPool(settings, db, client_factory=client_factory)
    registered = asyncio.run(pool.register_jwt(raw_jwt))

    # FakeClient derives user_id and the session token from its name ("fresh").
    assert registered.user_id == "user-fresh"
    assert registered.session_token == "session-fresh"
    assert db.count_accounts() == 1
def test_check_account_can_reenable_account(tmp_path: Path) -> None:
    """A successful check restores an account previously marked invalid.

    The stored account starts disabled, status ``"invalid"``, with two recorded
    failures. After ``check_account`` against a healthy client it must be
    enabled, ``"active"``, and have its failure count reset to zero.
    """
    settings = make_settings(tmp_path)
    db = Database(settings.database_path)
    db.initialize()
    stale = db.upsert_account(
        jwt="jwt-a",
        session_token="token-a",
        user_id="user-a",
        email="a@example.com",
        name="a",
        enabled=False,
        status="invalid",
        last_error="HTTP 401",
        failure_count=2,
    )

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # Every credential maps to a healthy fake client.
        return FakeClient(name="a", answer="alpha")

    pool = AccountPool(settings, db, client_factory=client_factory)
    refreshed = asyncio.run(pool.check_account(stale.id))

    assert refreshed.enabled is True
    assert refreshed.status == "active"
    assert refreshed.failure_count == 0