# Spaces: Running
# File size: 5,053 Bytes
# c745a99
"""Regression tests for scripts/grpo_pool.py.
Focus: `GrpoPool.connect()` must be all-or-nothing. If any single WebSocket
handshake fails, every session that DID connect must be closed before the
error propagates, so the server never ends up with leaked pool slots.
No pytest-asyncio dependency — each test drives the loop via asyncio.run().
"""
from __future__ import annotations
import asyncio
import pytest
from scripts.grpo_pool import GrpoPool
class _FakeEnv:
"""Minimal stand-in for AwsRlEnv. Tracks connect/close lifecycle."""
connect_calls = 0 # class-level so the factory can index envs in order
def __init__(self, *, should_fail_on_index: int | None = None) -> None:
self.connected = False
self.close_called = False
self._index = _FakeEnv.connect_calls
_FakeEnv.connect_calls += 1
self._should_fail = (
should_fail_on_index is not None and self._index == should_fail_on_index
)
async def connect(self) -> None:
if self._should_fail:
raise ConnectionError(f"fake failure on env#{self._index}")
await asyncio.sleep(0) # yield so sibling connects can interleave
self.connected = True
async def close(self) -> None:
self.close_called = True
def _install_fake_env(monkeypatch, fail_on_index: int | None) -> list[_FakeEnv]:
    """Swap scripts.grpo_pool.AwsRlEnv for a factory that hands out _FakeEnvs.

    The instance counter on _FakeEnv is reset first, so indices start at 0
    for every installation. Every env the factory builds is appended to the
    returned list, which the test can inspect once connect() has run.
    """
    registry: list[_FakeEnv] = []
    _FakeEnv.connect_calls = 0

    def _build(*_args, **_kwargs) -> _FakeEnv:
        # Constructor arguments are ignored; only the failure index matters.
        registry.append(_FakeEnv(should_fail_on_index=fail_on_index))
        return registry[-1]

    monkeypatch.setattr("scripts.grpo_pool.AwsRlEnv", _build)
    return registry
# ---------------------------------------------------------------------------
# Happy path — sanity check the fake harness before running the failure cases
# ---------------------------------------------------------------------------
class TestConnectHappyPath:
    """Sanity check of the fake harness when no failure is injected."""

    def test_all_sessions_connect_and_land_on_pool(self, monkeypatch) -> None:
        """With no injected failure, every fake connects and none is closed."""
        created = _install_fake_env(monkeypatch, fail_on_index=None)
        pool = GrpoPool(base_url="http://x", size=4)
        asyncio.run(pool.connect())

        # all()/any() are vacuously satisfied by an empty list, so first make
        # sure the patched factory was really used to build the sessions.
        assert created, "fake AwsRlEnv factory was never called"

        assert len(pool.envs) == 4
        assert all(e.connected for e in created)
        assert not any(e.close_called for e in created)
# ---------------------------------------------------------------------------
# The review: partial failure must roll back
# ---------------------------------------------------------------------------
class TestConnectRollbackOnPartialFailure:
    """connect() must be all-or-nothing when one session fails to connect."""

    def test_failure_closes_every_env_including_successful_ones(
        self, monkeypatch
    ) -> None:
        """A failure on env #2 must result in close() on every created env."""
        created = _install_fake_env(monkeypatch, fail_on_index=2)
        pool = GrpoPool(base_url="http://x", size=4)
        with pytest.raises(ConnectionError):
            asyncio.run(pool.connect())

        # The injected failure belongs to fake env #2, so at least three fakes
        # must exist. Without this guard the all() below would pass vacuously
        # if the error had come from somewhere other than the fake harness.
        assert len(created) >= 3, "ConnectionError did not come from the fake env"

        # Every FakeEnv must have had close() called — successful ones so
        # server slots are released; the failing one as a harmless no-op.
        assert all(e.close_called for e in created), (
            "Regression: successful sessions leaked after partial connect failure"
        )

    def test_pool_envs_stays_empty_on_failure(self, monkeypatch) -> None:
        """A failed connect() leaves pool.envs empty."""
        _install_fake_env(monkeypatch, fail_on_index=1)
        pool = GrpoPool(base_url="http://x", size=3)
        with pytest.raises(ConnectionError):
            asyncio.run(pool.connect())

        # connect() must NOT leave a half-initialised pool visible to callers.
        assert pool.envs == []

    def test_failure_does_not_block_retry(self, monkeypatch) -> None:
        """After a failed connect(), the caller can fix the root cause and
        call connect() again. pool.envs should be fresh."""
        _install_fake_env(monkeypatch, fail_on_index=0)
        pool = GrpoPool(base_url="http://x", size=2)
        with pytest.raises(ConnectionError):
            asyncio.run(pool.connect())

        # Second attempt with no injected failure should succeed.
        _install_fake_env(monkeypatch, fail_on_index=None)
        asyncio.run(pool.connect())
        assert len(pool.envs) == 2
        assert all(e.connected for e in pool.envs)

    def test_async_context_manager_cleans_up_when_enter_fails(
        self, monkeypatch
    ) -> None:
        """If `async with GrpoPool(...)` raises during __aenter__,
        __aexit__ is NOT called — so rollback must live inside connect()
        itself. This test exercises exactly that scenario.
        """
        created = _install_fake_env(monkeypatch, fail_on_index=2)

        async def enter_and_fail() -> None:
            async with GrpoPool(base_url="http://x", size=4):
                pytest.fail("should never enter the body")

        with pytest.raises(ConnectionError):
            asyncio.run(enter_and_fail())

        # Same vacuity guard as above: the failure is injected on fake env #2,
        # so fewer than three fakes means the harness was not exercised.
        assert len(created) >= 3, "ConnectionError did not come from the fake env"
        assert all(e.close_called for e in created)
|