# z.ai / tests / test_upstream_dual_pool.py
# Gemini CLI
# Configure for Hugging Face Spaces
# 7864524
import asyncio
import types
from dataclasses import dataclass
from unittest.mock import AsyncMock
import pytest
from app.core import upstream as upstream_module
from app.core.upstream import UpstreamClient
from app.models.schemas import Message, OpenAIRequest
from app.utils.guest_session_pool import GuestSession, GuestSessionPool
# Not referenced by any code visible in this file.
AUTH_POOL_SIZE = 2
# pool_size passed to _build_guest_pool by the dual-pool tests; the same value
# is what their pool-status assertions expect to find available/valid.
GUEST_POOL_SIZE = 2
# Number of concurrent chat requests issued by
# test_authenticated_tokens_are_used_before_guest_pool.
AUTH_REQUEST_COUNT = 6
# Seconds the fake upstream handler of that same test sleeps before answering.
MIXED_REQUEST_DELAY = 0.01
def _make_request() -> OpenAIRequest:
    """Build the non-streaming GLM-4.5 request with a single user "ping" message."""
    ping = Message(role="user", content="ping")
    return OpenAIRequest(model="GLM-4.5", messages=[ping], stream=False)
def _make_guest_session(user_id: str) -> GuestSession:
    """Create a GuestSession whose token and username are derived from *user_id*."""
    session_fields = {
        "token": f"guest-token-{user_id}",
        "user_id": user_id,
        "username": f"Guest-{user_id}",
    }
    return GuestSession(**session_fields)
@dataclass
class StubTokenPool:
    """In-memory token pool stub.

    Tokens are offered in list order, and every reported failure or success
    is appended to a log list so tests can assert on the exact sequence.
    """

    tokens: list[str]

    def __post_init__(self):
        # Call logs in call order; plain attributes, not dataclass fields.
        self.failure_tokens: list[str] = []
        self.success_tokens: list[str] = []

    def get_next_token(self, exclude_tokens=None):
        """Return the first token not in *exclude_tokens*, or None if none is left."""
        skipped = exclude_tokens if exclude_tokens else set()
        return next((tok for tok in self.tokens if tok not in skipped), None)

    async def record_token_failure(self, token: str, error=None, dao=None):
        """Log *token* as failed; *error* and *dao* are accepted but ignored."""
        self.failure_tokens.append(token)

    async def record_token_success(self, token: str, dao=None):
        """Log *token* as successful; *dao* is accepted but ignored."""
        self.success_tokens.append(token)

    def get_pool_status(self):
        """Return a status dict holding only the number of configured tokens."""
        return dict(available_tokens=len(self.tokens))
class FakeResponse:
    """Minimal HTTP response double exposing status_code, text and is_success."""

    def __init__(self, status_code: int, text: str = "{}"):
        self.text = text
        self.status_code = status_code

    @property
    def is_success(self) -> bool:
        """True exactly when the status code is in the 2xx range."""
        code = self.status_code
        return not (code < 200 or code >= 300)
def _build_fake_async_client(handler):
    """Return an async-context-manager client class whose post() delegates to *handler*.

    *handler* is an async callable that receives the request headers (an
    empty dict when none were given); whatever it returns is the result of
    ``post``. The URL and JSON body passed to ``post`` are ignored.
    """

    class FakeAsyncClient:
        def __init__(self, *args, **kwargs):
            """Accept and discard any constructor arguments."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc_value, traceback):
            # Falsy result: exceptions raised inside the `async with` propagate.
            return False

        async def post(self, url, headers=None, json=None):
            request_headers = headers if headers else {}
            return await handler(request_headers)

    return FakeAsyncClient
async def _build_guest_pool(
    monkeypatch,
    *,
    pool_size: int,
    user_ids: list[str],
) -> GuestSessionPool:
    """Create and initialize a GuestSessionPool that is backed by in-memory fakes.

    Before ``initialize()`` is awaited, three pool attributes are patched:

    * ``_create_session`` returns sessions built from *user_ids*, one id per
      call, in order.
    * ``_maintenance_loop`` is an AsyncMock returning None.
    * ``_delete_all_chats`` is an AsyncMock returning True.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used for the patches above.
        pool_size: forwarded to ``GuestSessionPool(pool_size=...)``.
        user_ids: ids for the sessions the fake ``_create_session`` hands out.

    Returns:
        The initialized pool. The caller is responsible for closing it.

    Raises:
        RuntimeError: from the fake ``_create_session`` when the pool requests
            more sessions than *user_ids* provides.
    """
    pool = GuestSessionPool(pool_size=pool_size)
    remaining_ids = iter(user_ids)
    exhausted = object()

    async def fake_create_session() -> GuestSession:
        user_id = next(remaining_ids, exhausted)
        if user_id is exhausted:
            # A bare next() would raise StopIteration, which Python re-raises
            # out of a coroutine as an opaque
            # RuntimeError("coroutine raised StopIteration"). Keep the same
            # exception type but say what actually went wrong.
            raise RuntimeError(
                "fake _create_session ran out of user_ids "
                f"(supplied {len(user_ids)}: {user_ids!r})"
            )
        return _make_guest_session(user_id)

    monkeypatch.setattr(pool, "_create_session", fake_create_session)
    monkeypatch.setattr(pool, "_maintenance_loop", AsyncMock(return_value=None))
    monkeypatch.setattr(pool, "_delete_all_chats", AsyncMock(return_value=True))
    await pool.initialize()
    # Yield to the event loop once; presumably lets work scheduled by
    # initialize() start before the pool is handed to the test (TODO confirm).
    await asyncio.sleep(0)
    return pool
def _patch_upstream_dependencies(
    monkeypatch,
    *,
    token_pool,
    guest_pool,
    async_client_cls,
):
    """Point the upstream module at the supplied pools and fake HTTP client.

    Patches, in this order: ``get_token_pool``, ``get_guest_session_pool``,
    ``settings.ANONYMOUS_MODE`` (forced to True), ``settings.GUEST_POOL_SIZE``
    (the guest pool's ``pool_size``, or 1 when *guest_pool* is falsy) and
    ``httpx.AsyncClient``.
    """

    def provide_token_pool():
        return token_pool

    def provide_guest_pool():
        return guest_pool

    monkeypatch.setattr(upstream_module, "get_token_pool", provide_token_pool)
    monkeypatch.setattr(upstream_module, "get_guest_session_pool", provide_guest_pool)
    monkeypatch.setattr(upstream_module.settings, "ANONYMOUS_MODE", True)

    if guest_pool:
        configured_guest_size = guest_pool.pool_size
    else:
        configured_guest_size = 1
    monkeypatch.setattr(
        upstream_module.settings, "GUEST_POOL_SIZE", configured_guest_size
    )

    monkeypatch.setattr(upstream_module.httpx, "AsyncClient", async_client_cls)
def _bind_minimal_request_flow(client: UpstreamClient, captures: list[dict]):
    """Replace the client's request/response transforms with thin fakes.

    The fake ``transform_request`` still calls the client's real
    ``get_auth_info`` (forwarding both exclusion arguments), appends a copy of
    the returned auth info to *captures*, and mirrors the token, token source
    and guest user id into ``x-*`` headers so a fake HTTP handler can branch
    on them. The fake ``transform_response`` returns a small dict built from
    the response's ``is_success`` flag and the transformed request.
    """

    async def fake_transform_request(
        self,
        request,
        excluded_tokens=None,
        excluded_guest_user_ids=None,
    ):
        auth = await self.get_auth_info(
            excluded_tokens=excluded_tokens,
            excluded_guest_user_ids=excluded_guest_user_ids,
        )
        # Store a copy so later changes to the auth info cannot alter the record.
        captures.append(dict(auth))

        headers = {
            "x-token": str(auth["token"]),
            "x-token-source": str(auth["token_source"]),
            # A missing or falsy guest id is sent as an empty header value.
            "x-guest-user-id": str(auth.get("guest_user_id") or ""),
        }
        return {
            "url": "https://upstream.test/chat",
            "headers": headers,
            "body": {"model": request.model},
            "token": auth["token"],
            "chat_id": "chat-id",
            "model": request.model,
            "user_id": auth["user_id"],
            "auth_mode": auth["auth_mode"],
            "token_source": auth["token_source"],
            "guest_user_id": auth["guest_user_id"],
        }

    async def fake_transform_response(self, response, request, transformed):
        summary = {"ok": response.is_success}
        for key in ("token_source", "token", "guest_user_id"):
            summary[key] = transformed[key]
        return summary

    # Bind as methods on this one instance so `self` is the client under test.
    for attr_name, fake in (
        ("transform_request", fake_transform_request),
        ("transform_response", fake_transform_response),
    ):
        setattr(client, attr_name, types.MethodType(fake, client))
async def _run_chat_requests(client: UpstreamClient, count: int) -> list[dict]:
    """Run *count* chat_completion calls concurrently and return their results."""
    pending = (client.chat_completion(_make_request()) for _ in range(count))
    return await asyncio.gather(*pending)
@pytest.mark.asyncio
async def test_authenticated_tokens_are_used_before_guest_pool(monkeypatch):
    """With a working auth token, concurrent requests never acquire a guest session."""
    auth_tokens = StubTokenPool(["auth-1"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2"],
    )
    seen_auth: list[dict] = []
    guest_acquires = 0
    real_acquire = guest_pool.acquire

    async def counting_acquire(*args, **kwargs):
        # Pass-through wrapper that only counts how often acquire() is called.
        nonlocal guest_acquires
        guest_acquires += 1
        return await real_acquire(*args, **kwargs)

    async def slow_ok(headers):
        # The short delay keeps the concurrent requests in flight together.
        await asyncio.sleep(MIXED_REQUEST_DELAY)
        return FakeResponse(200)

    client = UpstreamClient()
    monkeypatch.setattr(guest_pool, "acquire", counting_acquire)
    _bind_minimal_request_flow(client, seen_auth)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=auth_tokens,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(slow_ok),
    )

    try:
        results = await _run_chat_requests(client, AUTH_REQUEST_COUNT)
        status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    assert all(outcome["ok"] is True for outcome in results)
    assert all(info["token_source"] == "auth_pool" for info in seen_auth)
    assert guest_acquires == 0
    assert auth_tokens.success_tokens == ["auth-1"] * AUTH_REQUEST_COUNT
    assert auth_tokens.failure_tokens == []
    assert status["busy_sessions"] == 0
    assert status["available_sessions"] == GUEST_POOL_SIZE
@pytest.mark.asyncio
async def test_authenticated_401_retries_next_token_before_guest_fallback(monkeypatch):
    """A 401 on the first auth token is retried on the second, not on a guest session."""
    auth_tokens = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2"],
    )
    seen_auth: list[dict] = []
    guest_acquires = 0
    real_acquire = guest_pool.acquire

    async def counting_acquire(*args, **kwargs):
        # Pass-through wrapper that only counts how often acquire() is called.
        nonlocal guest_acquires
        guest_acquires += 1
        return await real_acquire(*args, **kwargs)

    async def reject_first_token(headers):
        # Only "auth-1" is treated as expired; every other token succeeds.
        if headers["x-token"] != "auth-1":
            return FakeResponse(200)
        return FakeResponse(401, '{"message":"expired"}')

    client = UpstreamClient()
    monkeypatch.setattr(guest_pool, "acquire", counting_acquire)
    _bind_minimal_request_flow(client, seen_auth)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=auth_tokens,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(reject_first_token),
    )

    try:
        outcome = await client.chat_completion(_make_request())
    finally:
        await guest_pool.close()

    assert outcome["ok"] is True
    assert [info["token"] for info in seen_auth] == ["auth-1", "auth-2"]
    assert [info["token_source"] for info in seen_auth] == ["auth_pool", "auth_pool"]
    assert auth_tokens.failure_tokens == ["auth-1"]
    assert auth_tokens.success_tokens == ["auth-2"]
    assert guest_acquires == 0
@pytest.mark.asyncio
async def test_authenticated_pool_exhaustion_falls_back_to_guest(monkeypatch):
    """When every auth token gets a 401, the request succeeds on a guest session."""
    auth_tokens = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2", "guest-3"],
    )
    seen_auth: list[dict] = []

    async def reject_auth_pool(headers):
        # Anything sent with an auth-pool token is rejected; guests succeed.
        if headers["x-token-source"] != "auth_pool":
            return FakeResponse(200)
        return FakeResponse(401, '{"message":"expired"}')

    client = UpstreamClient()
    _bind_minimal_request_flow(client, seen_auth)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=auth_tokens,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(reject_auth_pool),
    )

    try:
        outcome = await client.chat_completion(_make_request())
        status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    expected_sources = ["auth_pool", "auth_pool", "guest_pool"]
    assert outcome["ok"] is True
    assert [info["token_source"] for info in seen_auth] == expected_sources
    assert auth_tokens.failure_tokens == ["auth-1", "auth-2"]
    assert auth_tokens.success_tokens == []
    assert outcome["guest_user_id"]
    assert status["busy_sessions"] == 0
@pytest.mark.asyncio
async def test_guest_retry_is_isolated_and_does_not_pollute_auth_stats(monkeypatch):
    """A failing guest session is retried on another guest without touching auth stats."""
    auth_tokens = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
    )
    seen_auth: list[dict] = []

    async def reject_auth_and_first_guest(headers):
        token_source = headers["x-token-source"]
        guest_id = headers["x-guest-user-id"]
        # Auth-pool tokens and the "guest-1" session are all treated as expired.
        if token_source == "auth_pool" or guest_id == "guest-1":
            return FakeResponse(401, '{"message":"expired"}')
        return FakeResponse(200)

    client = UpstreamClient()
    _bind_minimal_request_flow(client, seen_auth)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=auth_tokens,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(reject_auth_and_first_guest),
    )

    try:
        outcome = await client.chat_completion(_make_request())
        status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    guest_attempts = [
        info["guest_user_id"]
        for info in seen_auth
        if info["token_source"] == "guest_pool"
    ]

    assert outcome["ok"] is True
    assert [info["token"] for info in seen_auth[:2]] == ["auth-1", "auth-2"]
    assert auth_tokens.failure_tokens == ["auth-1", "auth-2"]
    assert auth_tokens.success_tokens == []
    assert guest_attempts[0] == "guest-1"
    assert guest_attempts[1] != "guest-1"
    assert status["busy_sessions"] == 0
    assert status["valid_sessions"] == GUEST_POOL_SIZE
@pytest.mark.asyncio
async def test_cleanup_idle_chats_only_touches_idle_valid_sessions(monkeypatch):
    """cleanup_idle_chats skips the session that has an active request."""
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=3,
        user_ids=["guest-1", "guest-2", "guest-3"],
    )
    cleaned_user_ids: list[str] = []

    async def recording_delete_all_chats(session: GuestSession):
        cleaned_user_ids.append(session.user_id)
        return True

    monkeypatch.setattr(guest_pool, "_delete_all_chats", recording_delete_all_chats)
    # Mark guest-2 as busy so only guest-1 and guest-3 are idle.
    guest_pool._sessions["guest-2"].active_requests = 1

    try:
        await guest_pool.cleanup_idle_chats()
        # Snapshot before close() runs, so close() cannot add to the record.
        cleaned_before_close = [*cleaned_user_ids]
    finally:
        await guest_pool.close()

    assert cleaned_before_close == ["guest-1", "guest-3"]
@pytest.mark.asyncio
async def test_report_failure_only_retires_target_guest_session(monkeypatch):
    """report_failure drops only the named session; the others stay in the pool."""
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=3,
        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
    )
    cleaned_user_ids: list[str] = []

    async def recording_delete_all_chats(session: GuestSession):
        cleaned_user_ids.append(session.user_id)
        return True

    monkeypatch.setattr(guest_pool, "_delete_all_chats", recording_delete_all_chats)

    try:
        await guest_pool.report_failure("guest-1")
        # One event-loop yield; presumably lets follow-up work scheduled by
        # report_failure run before the snapshots are taken (TODO confirm).
        await asyncio.sleep(0)
        remaining_user_ids = set(guest_pool._sessions)
        cleaned_before_close = [*cleaned_user_ids]
    finally:
        await guest_pool.close()

    assert "guest-1" not in remaining_user_ids
    assert "guest-2" in remaining_user_ids
    assert "guest-3" in remaining_user_ids
    assert "guest-4" in remaining_user_ids
    assert cleaned_before_close == ["guest-1"]