# zai2api/tests/test_account_pool.py
# xiaoyi7894's picture
# Deploy zai2api to Hugging Face Space
# b65f9e0
from __future__ import annotations
import asyncio
from pathlib import Path
import httpx
from zai2api.account_pool import AccountPool
from zai2api.config import Settings
from zai2api.db import Database
from zai2api.zai_client import SessionState, UpstreamResult
class FakeClient:
    """In-memory stand-in for the upstream client used by these tests.

    ``name`` labels the fake session fields and ``answer`` is prefixed onto
    every prompt reply.  When ``fail_status`` is given, ``collect_prompt``
    raises ``httpx.HTTPStatusError`` carrying that status code instead of
    returning a result.
    """

    def __init__(self, *, name: str, answer: str, fail_status: int | None = None):
        self._name = name
        self._answer = answer
        self._fail_status = fail_status

    async def ensure_session(self, force_refresh: bool = False) -> SessionState:
        """Return a canned session built from the client's name.

        ``force_refresh`` is accepted for interface compatibility and ignored.
        """
        session_fields = {
            "token": f"session-{self._name}",
            "user_id": f"user-{self._name}",
            "name": self._name,
            "email": f"{self._name}@example.com",
            "role": "user",
        }
        return SessionState(**session_fields)

    async def verify_completion_version(self) -> int:
        """Return the fixed completion version ``2``."""
        return 2

    async def collect_prompt(
        self,
        *,
        prompt: str,
        model: str,
        enable_thinking: bool,
        auto_web_search: bool,
    ) -> UpstreamResult:
        """Answer ``prompt`` with ``"<answer>:<prompt>"`` or raise the configured failure.

        ``model``, ``enable_thinking`` and ``auto_web_search`` are accepted but
        do not influence the result.

        Raises:
            httpx.HTTPStatusError: when the client was built with ``fail_status``.
        """
        status = self._fail_status
        if status is None:
            # Happy path: echo the prompt behind this client's answer label.
            return UpstreamResult(
                answer_text=f"{self._answer}:{prompt}",
                reasoning_text="reasoning",
                usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
                finish_reason="stop",
            )

        # Failure path: fabricate a request/response pair so the raised error
        # carries the configured status code.
        request = httpx.Request("POST", "https://example.com")
        raise httpx.HTTPStatusError(
            "boom",
            request=request,
            response=httpx.Response(status, request=request),
        )

    async def stream_prompt(self, **_: object):
        """Async generator that accepts any keyword arguments and yields nothing."""
        return
        # Unreachable on purpose: the ``yield`` keeps this an async generator.
        yield None

    async def aclose(self) -> None:
        """No-op close; the fake holds no resources."""
def make_settings(tmp_path: Path, **overrides: object) -> Settings:
    """Build a ``Settings`` object for tests with its database under ``tmp_path``.

    Args:
        tmp_path: Directory that will hold the ``state.db`` database file.
        **overrides: Attribute values applied with ``setattr`` after the
            ``Settings`` object has been constructed from the defaults.

    Returns:
        The configured ``Settings`` instance.
    """
    defaults = {
        "host": "127.0.0.1",
        "port": 8000,
        "log_level": "info",
        "zai_base_url": "https://chat.z.ai",
        "zai_jwt": None,
        "zai_session_token": None,
        "default_model": "glm-5",
        "request_timeout": 120.0,
        "database_path": str(tmp_path / "state.db"),
        "panel_password_env": None,
        "api_password_env": None,
        "admin_cookie_name": "zai2api_admin_session",
        "admin_session_ttl_hours": 24,
        "admin_cookie_secure": False,
    }
    settings = Settings(**defaults)

    # Overrides are applied post-construction, exactly as attribute assignments.
    for attr_name, attr_value in overrides.items():
        setattr(settings, attr_name, attr_value)
    return settings
def test_account_pool_rotates_enabled_accounts(tmp_path: Path) -> None:
    """Two consecutive prompts are answered by two different accounts.

    Two accounts are stored; the first prompt must be answered by the client
    for ``token-a`` and the second by the client for ``token-b``.  Both
    accounts must still be present and ``active`` afterwards.
    """
    settings = make_settings(tmp_path)
    db = Database(settings.database_path)
    db.initialize()

    first = db.upsert_account(
        jwt="jwt-a",
        session_token="token-a",
        user_id="user-a",
        email="a@example.com",
        name="a",
    )
    second = db.upsert_account(
        jwt="jwt-b",
        session_token="token-b",
        user_id="user-b",
        email="b@example.com",
        name="b",
    )

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # Map each stored session token to a fake with a distinguishable answer.
        if session_token == "token-a":
            return FakeClient(name="a", answer="alpha")
        if session_token == "token-b":
            return FakeClient(name="b", answer="beta")
        raise AssertionError(f"unexpected session token: {session_token}")

    pool = AccountPool(settings, db, client_factory=client_factory)

    def ask() -> UpstreamResult:
        # Each call runs one identical prompt through the pool on a fresh
        # event loop, matching two independent requests.
        return asyncio.run(
            pool.collect_prompt(
                prompt="hello",
                model="glm-5",
                enable_thinking=True,
                auto_web_search=False,
            )
        )

    first_result = ask()
    second_result = ask()

    assert first_result.answer_text == "alpha:hello"
    assert second_result.answer_text == "beta:hello"

    # Guard against a missing row so a lookup failure is reported as a clear
    # assertion failure rather than an AttributeError on ``None``.
    first_account = db.get_account(first.id)
    assert first_account is not None
    assert first_account.status == "active"

    second_account = db.get_account(second.id)
    assert second_account is not None
    assert second_account.status == "active"
def test_account_pool_disables_unauthorized_account(tmp_path: Path) -> None:
    """An account whose client fails with HTTP 401 ends up disabled and invalid.

    The prompt must still be answered by the other account, and the failing
    account's row must have ``enabled`` set to ``False`` and ``status`` set to
    ``"invalid"``.
    """
    config = make_settings(tmp_path)
    store = Database(config.database_path)
    store.initialize()

    bad = store.upsert_account(
        jwt="jwt-bad",
        session_token="token-bad",
        user_id="user-bad",
        email="bad@example.com",
        name="bad",
    )
    store.upsert_account(
        jwt="jwt-good",
        session_token="token-good",
        user_id="user-good",
        email="good@example.com",
        name="good",
    )

    # Constructor arguments for the fake client, keyed by stored session token.
    fake_specs = {
        "token-bad": {"name": "bad", "answer": "bad", "fail_status": 401},
        "token-good": {"name": "good", "answer": "good"},
    }

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        spec = fake_specs.get(session_token)
        if spec is None:
            raise AssertionError(f"unexpected session token: {session_token}")
        return FakeClient(**spec)

    pool = AccountPool(config, store, client_factory=client_factory)
    pending = pool.collect_prompt(
        prompt="hello",
        model="glm-5",
        enable_thinking=True,
        auto_web_search=False,
    )
    result = asyncio.run(pending)

    bad_account = store.get_account(bad.id)
    assert result.answer_text == "good:hello"
    assert bad_account is not None
    assert bad_account.enabled is False
    assert bad_account.status == "invalid"
def test_register_jwt_persists_account(tmp_path: Path) -> None:
    """Registering a JWT returns an account and leaves exactly one row stored.

    The expected ``user_id`` and ``session_token`` are the values that
    ``FakeClient.ensure_session`` produces for a client named ``"fresh"``.
    """
    config = make_settings(tmp_path)
    store = Database(config.database_path)
    store.initialize()

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # The pool must hand the freshly registered JWT to the factory.
        assert jwt == "fresh-jwt"
        return FakeClient(name="fresh", answer="fresh")

    pool = AccountPool(config, store, client_factory=client_factory)
    registration = pool.register_jwt("fresh-jwt")
    account = asyncio.run(registration)

    assert account.user_id == "user-fresh"
    assert account.session_token == "session-fresh"
    assert store.count_accounts() == 1
def test_check_account_can_reenable_account(tmp_path: Path) -> None:
    """Checking a disabled, invalid account with a working client re-enables it.

    The account starts disabled with status ``"invalid"`` and two recorded
    failures; after ``check_account`` it must be enabled, ``"active"`` and have
    a failure count of zero.
    """
    config = make_settings(tmp_path)
    store = Database(config.database_path)
    store.initialize()

    # Seed an account in the state a previous HTTP 401 would have left behind.
    seeded = {
        "jwt": "jwt-a",
        "session_token": "token-a",
        "user_id": "user-a",
        "email": "a@example.com",
        "name": "a",
        "enabled": False,
        "status": "invalid",
        "last_error": "HTTP 401",
        "failure_count": 2,
    }
    account = store.upsert_account(**seeded)

    def client_factory(jwt: str | None, session_token: str | None) -> FakeClient:
        # Every credential pair maps to the same healthy fake client.
        return FakeClient(name="a", answer="alpha")

    pool = AccountPool(config, store, client_factory=client_factory)
    check = pool.check_account(account.id)
    updated = asyncio.run(check)

    assert updated.enabled is True
    assert updated.status == "active"
    assert updated.failure_count == 0