Spaces:
Sleeping
Sleeping
| import asyncio | |
| import types | |
| from dataclasses import dataclass | |
| from unittest.mock import AsyncMock | |
| import pytest | |
| from app.core import upstream as upstream_module | |
| from app.core.upstream import UpstreamClient | |
| from app.models.schemas import Message, OpenAIRequest | |
| from app.utils.guest_session_pool import GuestSession, GuestSessionPool | |
# Size constant for the authenticated token pool; no code visible in this
# part of the file references it.
AUTH_POOL_SIZE = 2
# pool_size given to the guest session pool in the upstream routing tests and
# the value their pool-status assertions compare against.
GUEST_POOL_SIZE = 2
# Number of concurrent chat requests issued by the auth-priority test.
AUTH_REQUEST_COUNT = 6
# Seconds the fake upstream handler sleeps per request in the auth-priority test.
MIXED_REQUEST_DELAY = 0.01
def _make_request() -> OpenAIRequest:
    """Build a minimal non-streaming chat request with a single user message."""
    ping_message = Message(role="user", content="ping")
    return OpenAIRequest(model="GLM-4.5", messages=[ping_message], stream=False)
def _make_guest_session(user_id: str) -> GuestSession:
    """Create a guest session whose token and username are derived from *user_id*."""
    session_fields = {
        "token": f"guest-token-{user_id}",
        "user_id": user_id,
        "username": f"Guest-{user_id}",
    }
    return GuestSession(**session_fields)
@dataclass
class StubTokenPool:
    """In-memory stand-in for the authenticated token pool used by the tests.

    The ``@dataclass`` decorator is required: it generates the ``__init__``
    that accepts ``tokens`` and triggers ``__post_init__``, which creates the
    two recording lists the tests assert on.
    """

    # Tokens handed out in list order by ``get_next_token``.
    tokens: list[str]

    def __post_init__(self):
        # Chronological records of the tokens reported as failed / successful.
        self.failure_tokens: list[str] = []
        self.success_tokens: list[str] = []

    def get_next_token(self, exclude_tokens=None):
        """Return the first token not in *exclude_tokens*, or ``None`` if none is left."""
        excluded = exclude_tokens or set()
        return next(
            (token for token in self.tokens if token not in excluded),
            None,
        )

    async def record_token_failure(self, token: str, error=None, dao=None):
        """Remember that *token* failed; *error* and *dao* are accepted but ignored."""
        self.failure_tokens.append(token)

    async def record_token_success(self, token: str, dao=None):
        """Remember that *token* succeeded; *dao* is accepted but ignored."""
        self.success_tokens.append(token)

    def get_pool_status(self):
        """Return a status dict reporting the number of configured tokens."""
        return {"available_tokens": len(self.tokens)}
class FakeResponse:
    """Minimal stand-in for an HTTP response: a status code plus a text body."""

    def __init__(self, status_code: int, text: str = "{}"):
        self.status_code = status_code
        self.text = text

    @property
    def is_success(self) -> bool:
        """``True`` for 2xx status codes.

        Exposed as a property because the fake response transformer in this
        file reads ``response.is_success`` without calling it and the tests
        compare the result with ``is True``.
        """
        return 200 <= self.status_code < 300
| def _build_fake_async_client(handler): | |
| class FakeAsyncClient: | |
| def __init__(self, *args, **kwargs): | |
| pass | |
| async def __aenter__(self): | |
| return self | |
| async def __aexit__(self, exc_type, exc, tb): | |
| return False | |
| async def post(self, url, headers=None, json=None): | |
| return await handler(headers or {}) | |
| return FakeAsyncClient | |
async def _build_guest_pool(
    monkeypatch,
    *,
    pool_size: int,
    user_ids: list[str],
) -> GuestSessionPool:
    """Create and initialize a guest session pool backed by stubbed internals.

    Sessions are produced from *user_ids* in order, one per call to the
    pool's ``_create_session``. The maintenance loop and chat deletion are
    replaced with ``AsyncMock`` objects.
    """
    pending_ids = iter(user_ids)

    async def fake_create_session() -> GuestSession:
        return _make_guest_session(next(pending_ids))

    guest_pool = GuestSessionPool(pool_size=pool_size)
    replacements = {
        "_create_session": fake_create_session,
        "_maintenance_loop": AsyncMock(return_value=None),
        "_delete_all_chats": AsyncMock(return_value=True),
    }
    for attribute, replacement in replacements.items():
        monkeypatch.setattr(guest_pool, attribute, replacement)

    await guest_pool.initialize()
    # Yield to the event loop once before handing the pool to the test
    # (presumably so work scheduled by initialize() gets a chance to start).
    await asyncio.sleep(0)
    return guest_pool
def _patch_upstream_dependencies(
    monkeypatch,
    *,
    token_pool,
    guest_pool,
    async_client_cls,
):
    """Point the upstream module at the supplied test doubles.

    Replaces the token-pool and guest-pool accessors, turns on
    ``ANONYMOUS_MODE``, sets ``GUEST_POOL_SIZE`` from the guest pool (1 when
    *guest_pool* is falsy) and swaps ``httpx.AsyncClient`` as seen through the
    upstream module for *async_client_cls*.
    """

    def provide_token_pool():
        return token_pool

    def provide_guest_pool():
        return guest_pool

    patch = monkeypatch.setattr
    patch(upstream_module, "get_token_pool", provide_token_pool)
    patch(upstream_module, "get_guest_session_pool", provide_guest_pool)
    patch(upstream_module.settings, "ANONYMOUS_MODE", True)
    configured_size = guest_pool.pool_size if guest_pool else 1
    patch(upstream_module.settings, "GUEST_POOL_SIZE", configured_size)
    patch(upstream_module.httpx, "AsyncClient", async_client_cls)
def _bind_minimal_request_flow(client: UpstreamClient, captures: list[dict]):
    """Replace the client's request/response transformers with lightweight fakes.

    The fake request transformer still calls ``client.get_auth_info`` and
    appends a copy of every auth-info result to *captures*, so tests can
    inspect which credentials were chosen for each attempt. The fake response
    transformer reduces the response to a small summary dict.
    """

    async def fake_transform_request(
        self,
        request,
        excluded_tokens=None,
        excluded_guest_user_ids=None,
    ):
        auth_info = await self.get_auth_info(
            excluded_tokens=excluded_tokens,
            excluded_guest_user_ids=excluded_guest_user_ids,
        )
        # Store a snapshot so later changes to auth_info cannot alter the record.
        snapshot = dict(auth_info)
        captures.append(snapshot)

        # The fake HTTP handlers in the tests branch on these header values.
        headers = {
            "x-token": str(auth_info["token"]),
            "x-token-source": str(auth_info["token_source"]),
            "x-guest-user-id": str(auth_info.get("guest_user_id") or ""),
        }
        transformed = {
            "url": "https://upstream.test/chat",
            "headers": headers,
            "body": {"model": request.model},
            "token": auth_info["token"],
            "chat_id": "chat-id",
            "model": request.model,
        }
        for key in ("user_id", "auth_mode", "token_source", "guest_user_id"):
            transformed[key] = auth_info[key]
        return transformed

    async def fake_transform_response(self, response, request, transformed):
        summary = {"ok": response.is_success}
        for key in ("token_source", "token", "guest_user_id"):
            summary[key] = transformed[key]
        return summary

    fakes = (
        ("transform_request", fake_transform_request),
        ("transform_response", fake_transform_response),
    )
    for attribute, fake in fakes:
        # Bind as a method so the fake receives the client as ``self``.
        setattr(client, attribute, types.MethodType(fake, client))
async def _run_chat_requests(client: UpstreamClient, count: int) -> list[dict]:
    """Run *count* identical chat completions concurrently; results keep submission order."""
    pending = (client.chat_completion(_make_request()) for _ in range(count))
    return await asyncio.gather(*pending)
@pytest.mark.asyncio
async def test_authenticated_tokens_are_used_before_guest_pool(monkeypatch):
    """Concurrent requests use the authenticated token and never touch the guest pool.

    With one authenticated token that always succeeds, every request must be
    served from ``auth_pool``, ``guest_pool.acquire`` must never be called and
    the guest pool must end up with no busy sessions.
    """
    token_pool = StubTokenPool(["auth-1"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2"],
    )
    captures: list[dict] = []
    acquire_calls = 0
    original_acquire = guest_pool.acquire

    async def counted_acquire(*args, **kwargs):
        # Count guest-pool acquisitions while delegating to the real method.
        nonlocal acquire_calls
        acquire_calls += 1
        return await original_acquire(*args, **kwargs)

    async def handler(headers):
        # Sleep briefly so the gathered requests are in flight together.
        await asyncio.sleep(MIXED_REQUEST_DELAY)
        return FakeResponse(200)

    client = UpstreamClient()
    monkeypatch.setattr(guest_pool, "acquire", counted_acquire)
    _bind_minimal_request_flow(client, captures)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=token_pool,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(handler),
    )

    try:
        results = await _run_chat_requests(client, AUTH_REQUEST_COUNT)
        # Read the status before close() so it reflects the live pool.
        pool_status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    assert all(result["ok"] is True for result in results)
    assert all(item["token_source"] == "auth_pool" for item in captures)
    assert acquire_calls == 0
    assert token_pool.success_tokens == ["auth-1"] * AUTH_REQUEST_COUNT
    assert token_pool.failure_tokens == []
    assert pool_status["busy_sessions"] == 0
    assert pool_status["available_sessions"] == GUEST_POOL_SIZE
@pytest.mark.asyncio
async def test_authenticated_401_retries_next_token_before_guest_fallback(monkeypatch):
    """A 401 on the first authenticated token retries with the next one, not a guest.

    ``auth-1`` is rejected with 401 and ``auth-2`` succeeds; the failure and
    success must be recorded against the right tokens and the guest pool must
    never be asked for a session.
    """
    token_pool = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2"],
    )
    captures: list[dict] = []
    acquire_calls = 0
    original_acquire = guest_pool.acquire

    async def counted_acquire(*args, **kwargs):
        # Count guest-pool acquisitions while delegating to the real method.
        nonlocal acquire_calls
        acquire_calls += 1
        return await original_acquire(*args, **kwargs)

    async def handler(headers):
        # Only the first authenticated token is treated as expired.
        if headers["x-token"] == "auth-1":
            return FakeResponse(401, '{"message":"expired"}')
        return FakeResponse(200)

    client = UpstreamClient()
    monkeypatch.setattr(guest_pool, "acquire", counted_acquire)
    _bind_minimal_request_flow(client, captures)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=token_pool,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(handler),
    )

    try:
        result = await client.chat_completion(_make_request())
    finally:
        await guest_pool.close()

    assert result["ok"] is True
    assert [item["token"] for item in captures] == ["auth-1", "auth-2"]
    assert [item["token_source"] for item in captures] == ["auth_pool", "auth_pool"]
    assert token_pool.failure_tokens == ["auth-1"]
    assert token_pool.success_tokens == ["auth-2"]
    assert acquire_calls == 0
@pytest.mark.asyncio
async def test_authenticated_pool_exhaustion_falls_back_to_guest(monkeypatch):
    """When every authenticated token gets a 401, the request succeeds via a guest session.

    Both authenticated tokens must be tried (and recorded as failures) before
    a single guest-pool attempt, and the guest session must be released
    afterwards.
    """
    token_pool = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2", "guest-3"],
    )
    captures: list[dict] = []

    async def handler(headers):
        # Every authenticated attempt is rejected; guest attempts succeed.
        if headers["x-token-source"] == "auth_pool":
            return FakeResponse(401, '{"message":"expired"}')
        return FakeResponse(200)

    client = UpstreamClient()
    _bind_minimal_request_flow(client, captures)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=token_pool,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(handler),
    )

    try:
        result = await client.chat_completion(_make_request())
        # Read the status before close() so it reflects the live pool.
        pool_status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    assert result["ok"] is True
    assert [item["token_source"] for item in captures] == [
        "auth_pool",
        "auth_pool",
        "guest_pool",
    ]
    assert token_pool.failure_tokens == ["auth-1", "auth-2"]
    assert token_pool.success_tokens == []
    assert result["guest_user_id"]
    assert pool_status["busy_sessions"] == 0
@pytest.mark.asyncio
async def test_guest_retry_is_isolated_and_does_not_pollute_auth_stats(monkeypatch):
    """A failing guest session is retried with a different guest, leaving auth stats untouched.

    All authenticated attempts and the ``guest-1`` attempt are rejected with
    401. The retry must use another guest session, the token pool must record
    only the two authenticated failures (no successes), and the guest pool
    must be back to ``GUEST_POOL_SIZE`` valid sessions with none busy.
    """
    token_pool = StubTokenPool(["auth-1", "auth-2"])
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=GUEST_POOL_SIZE,
        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
    )
    captures: list[dict] = []

    async def handler(headers):
        # Reject every authenticated attempt and the first guest session only.
        if headers["x-token-source"] == "auth_pool":
            return FakeResponse(401, '{"message":"expired"}')
        if headers["x-guest-user-id"] == "guest-1":
            return FakeResponse(401, '{"message":"expired"}')
        return FakeResponse(200)

    client = UpstreamClient()
    _bind_minimal_request_flow(client, captures)
    _patch_upstream_dependencies(
        monkeypatch,
        token_pool=token_pool,
        guest_pool=guest_pool,
        async_client_cls=_build_fake_async_client(handler),
    )

    try:
        result = await client.chat_completion(_make_request())
        # Read the status before close() so it reflects the live pool.
        pool_status = guest_pool.get_pool_status()
    finally:
        await guest_pool.close()

    guest_ids = [
        item["guest_user_id"]
        for item in captures
        if item["token_source"] == "guest_pool"
    ]
    assert result["ok"] is True
    assert [item["token"] for item in captures[:2]] == ["auth-1", "auth-2"]
    assert token_pool.failure_tokens == ["auth-1", "auth-2"]
    assert token_pool.success_tokens == []
    assert guest_ids[0] == "guest-1"
    assert guest_ids[1] != "guest-1"
    assert pool_status["busy_sessions"] == 0
    assert pool_status["valid_sessions"] == GUEST_POOL_SIZE
@pytest.mark.asyncio
async def test_cleanup_idle_chats_only_touches_idle_valid_sessions(monkeypatch):
    """``cleanup_idle_chats`` skips a session that has an active request.

    ``guest-2`` is marked busy, so chat deletion must be invoked only for
    ``guest-1`` and ``guest-3``.
    """
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=3,
        user_ids=["guest-1", "guest-2", "guest-3"],
    )
    deleted_user_ids: list[str] = []

    async def fake_delete_all_chats(session: GuestSession):
        # Record which sessions had their chats deleted.
        deleted_user_ids.append(session.user_id)
        return True

    monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats)
    # Simulate an in-flight request on guest-2.
    guest_pool._sessions["guest-2"].active_requests = 1

    try:
        await guest_pool.cleanup_idle_chats()
        # Snapshot now so the assertion only sees deletions made before close().
        deleted_before_close = list(deleted_user_ids)
    finally:
        await guest_pool.close()

    assert deleted_before_close == ["guest-1", "guest-3"]
@pytest.mark.asyncio
async def test_report_failure_only_retires_target_guest_session(monkeypatch):
    """Reporting a failure for one guest removes only that session from the pool.

    After ``report_failure("guest-1")`` the pool must no longer hold
    ``guest-1``, must still hold ``guest-2`` and ``guest-3``, must contain
    ``guest-4``, and chat deletion must have run for ``guest-1`` only.
    """
    guest_pool = await _build_guest_pool(
        monkeypatch,
        pool_size=3,
        user_ids=["guest-1", "guest-2", "guest-3", "guest-4"],
    )
    deleted_user_ids: list[str] = []

    async def fake_delete_all_chats(session: GuestSession):
        # Record which sessions had their chats deleted.
        deleted_user_ids.append(session.user_id)
        return True

    monkeypatch.setattr(guest_pool, "_delete_all_chats", fake_delete_all_chats)

    try:
        await guest_pool.report_failure("guest-1")
        # Yield to the event loop once before inspecting the pool.
        await asyncio.sleep(0)
        # Snapshot now so the assertions only see the state before close().
        current_user_ids = set(guest_pool._sessions)
        deleted_before_close = list(deleted_user_ids)
    finally:
        await guest_pool.close()

    assert "guest-1" not in current_user_ids
    assert "guest-2" in current_user_ids
    assert "guest-3" in current_user_ids
    assert "guest-4" in current_user_ids
    assert deleted_before_close == ["guest-1"]