File size: 12,921 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
"""Client module tests — TRN 13.

Tests cover ReplicaLabClient with both REST and WebSocket transports
against the real FastAPI test server.
"""

from __future__ import annotations

import contextlib
import json
import threading
import time

import pytest
import uvicorn

from replicalab.client import ReplicaLabClient
from replicalab.models import (
    Observation,
    ScientistAction,
    StepResult,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _propose_action(obs: Observation) -> ScientistAction:
    """Return a ``propose_protocol`` action built for the seed-42 math scenario.

    NOTE: ``obs`` is accepted for call-site symmetry but is never read. The
    action is always derived from
    ``generate_scenario(seed=42, template="math_reasoning", difficulty="easy")``,
    whatever episode the caller actually reset.
    """
    from replicalab.scenarios import generate_scenario

    scenario = generate_scenario(seed=42, template="math_reasoning", difficulty="easy")
    lab_view = scenario.lab_manager_observation
    reference = scenario.hidden_reference_spec

    # Fall back to a generic technique label when the spec has no summary.
    technique = "replication_plan"
    if reference.summary:
        technique = reference.summary[:60]

    # Request at most one item of each resource the lab actually has.
    if lab_view.equipment_available:
        equipment = list(lab_view.equipment_available[:1])
    else:
        equipment = []
    if lab_view.reagents_in_stock:
        reagents = list(lab_view.reagents_in_stock[:1])
    else:
        reagents = []

    # Clamp the duration to 1..2 days, and never above the lab's time limit
    # unless that limit is below 1.
    duration = max(1, min(2, lab_view.time_limit_days))

    focus = ", ".join(reference.required_elements[:2])
    rationale = (
        f"Plan addresses: {focus}. "
        f"Target metric: {reference.target_metric}. "
        f"Target value: {reference.target_value}. "
        "Stay within budget and schedule."
    )

    return ScientistAction(
        action_type="propose_protocol",
        sample_size=10,
        controls=["baseline", "ablation"],
        technique=technique,
        duration_days=duration,
        required_equipment=equipment,
        required_reagents=reagents,
        questions=[],
        rationale=rationale,
    )


def _accept_action() -> ScientistAction:
    """Return an ``accept`` action whose protocol fields are all empty or zero."""
    blank_protocol = dict(
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=[],
        rationale="",
    )
    return ScientistAction(action_type="accept", **blank_protocol)


# ---------------------------------------------------------------------------
# Live server fixture (shared by the REST and WebSocket transport tests)
# ---------------------------------------------------------------------------

# Both transports run against a real uvicorn server (port selection is governed
# by ``_TEST_PORT``) so the tests exercise the actual HTTP/WS code paths.

# Port the live test server binds to. ``0`` asks the OS for a free ephemeral
# port, which avoids collisions with other processes or parallel pytest
# workers. Pin a fixed value here only for local debugging.
_TEST_PORT = 0


def _wait_until_healthy(base_url: str, thread: threading.Thread) -> None:
    """Poll ``{base_url}/health`` until it answers 200, or fail the test.

    Makes up to 50 attempts, sleeping 0.1 s between them. Fails immediately
    if the server thread has exited, e.g. because startup raised.
    """
    import httpx

    for _ in range(50):
        if not thread.is_alive():
            pytest.fail("Live server thread exited during startup")
        try:
            resp = httpx.get(f"{base_url}/health", timeout=1.0)
        except httpx.HTTPError:
            pass  # Not accepting connections yet; retry after a short sleep.
        else:
            if resp.status_code == 200:
                return
        time.sleep(0.1)
    pytest.fail("Live server did not start in time")


@pytest.fixture(scope="module")
def live_server():
    """Start a live uvicorn server in a background thread for this module.

    Yields:
        The server's base URL, e.g. ``"http://127.0.0.1:54321"``.

    The listening socket is bound here and handed to uvicorn. The port is
    therefore known before the server starts, and a health check cannot be
    answered by some other process that already owns a fixed port. The
    server is told to exit on teardown, including when startup fails.
    """
    import socket

    from server.app import app

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", _TEST_PORT))
        port = sock.getsockname()[1]
        base_url = f"http://127.0.0.1:{port}"

        config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error")
        server = uvicorn.Server(config)
        # Pass the pre-bound socket so uvicorn serves on exactly this port.
        thread = threading.Thread(
            target=server.run, kwargs={"sockets": [sock]}, daemon=True
        )
        thread.start()

        try:
            _wait_until_healthy(base_url, thread)
            yield base_url
        finally:
            # Runs on normal teardown and on startup failure alike.
            server.should_exit = True
            thread.join(timeout=5)


# ---------------------------------------------------------------------------
# REST transport
# ---------------------------------------------------------------------------


class TestRestConnect:
    """connect() over REST verifies server health."""

    def test_connect_succeeds(self, live_server: str) -> None:
        """Connecting to the live server succeeds and sets ``connected``."""
        client = ReplicaLabClient(live_server, transport="rest")
        try:
            client.connect()
            assert client.connected
        finally:
            # Release the connection even if the assertion above fails.
            client.close()

    def test_connect_bad_url_raises(self) -> None:
        """Connecting to a port with (presumably) no listener raises."""
        client = ReplicaLabClient("http://127.0.0.1:19999", transport="rest", timeout=1.0)
        with pytest.raises(Exception):
            client.connect()


class TestRestReset:
    """Behaviour of reset() on the REST transport."""

    def test_reset_returns_observation(self, live_server: str) -> None:
        """reset() yields an Observation populated for both roles."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            observation = client.reset(seed=42, scenario="math_reasoning", difficulty="easy")
            assert isinstance(observation, Observation)

            scientist_view = observation.scientist
            assert scientist_view is not None
            assert scientist_view.paper_title

            lab_view = observation.lab_manager
            assert lab_view is not None
            assert lab_view.budget_total > 0

    def test_reset_sets_session_and_episode_id(self, live_server: str) -> None:
        """After reset() the client exposes both a session id and an episode id."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            client.reset(seed=1)
            assert client.session_id is not None
            assert client.episode_id is not None

    def test_reset_reuses_session(self, live_server: str) -> None:
        """A second reset() keeps the session but starts a fresh episode."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            client.reset(seed=1)
            first_session, first_episode = client.session_id, client.episode_id

            client.reset(seed=2)
            assert client.session_id == first_session
            assert client.episode_id != first_episode


class TestRestStep:
    """Behaviour of step() on the REST transport."""

    def test_step_returns_step_result(self, live_server: str) -> None:
        """A proposal step returns a non-terminal StepResult with an observation."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            step_result = client.step(_propose_action(client.reset(seed=42)))
            assert isinstance(step_result, StepResult)
            assert step_result.done is False
            assert step_result.observation is not None

    def test_step_before_reset_raises(self, live_server: str) -> None:
        """Stepping without a prior reset() raises a RuntimeError mentioning reset."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            accept = _accept_action()
            with pytest.raises(RuntimeError, match="reset"):
                client.step(accept)

    def test_full_episode_propose_accept(self, live_server: str) -> None:
        """Propose then accept: the episode ends in a rewarded agreement."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            opening = client.step(_propose_action(client.reset(seed=42)))
            assert opening.done is False

            closing = client.step(_accept_action())
            assert closing.done is True
            assert closing.reward > 0.0

            info = closing.info
            assert info.agreement_reached is True
            assert info.verdict == "accept"
            breakdown = info.reward_breakdown
            assert breakdown is not None
            assert 0.0 <= breakdown.rigor <= 1.0


class TestRestReplay:
    """Behaviour of replay() on the REST transport."""

    def test_replay_after_episode(self, live_server: str) -> None:
        """A finished propose/accept episode can be fetched back by its id."""
        with ReplicaLabClient(live_server, transport="rest") as client:
            initial = client.reset(seed=42)
            client.step(_propose_action(initial))
            client.step(_accept_action())

            finished_id = client.episode_id
            assert finished_id is not None

            record = client.replay(finished_id)
            assert record.agreement_reached is True
            assert record.total_reward > 0.0
            assert record.verdict == "accept"


class TestRestContextManager:
    """Using the REST client as a context manager disconnects it on exit."""

    def test_context_manager_closes(self, live_server: str) -> None:
        """Inside the ``with`` block the client is connected; afterwards it is not."""
        rest_client = ReplicaLabClient(live_server, transport="rest")
        with rest_client:
            assert rest_client.connected
            rest_client.reset(seed=1)
        assert not rest_client.connected


# ---------------------------------------------------------------------------
# WebSocket transport
# ---------------------------------------------------------------------------


class TestWsConnect:
    """connect() over WebSocket."""

    def test_connect_succeeds(self, live_server: str) -> None:
        """Connecting to the live server over WebSocket sets ``connected``."""
        client = ReplicaLabClient(live_server, transport="websocket")
        try:
            client.connect()
            assert client.connected
        finally:
            # Release the socket even if the assertion above fails.
            client.close()

    def test_connect_bad_url_raises(self) -> None:
        """Connecting to a port with (presumably) no listener raises."""
        client = ReplicaLabClient("http://127.0.0.1:19999", transport="websocket", timeout=1.0)
        with pytest.raises(Exception):
            client.connect()


class TestWsReset:
    """Behaviour of reset() on the WebSocket transport."""

    def test_reset_returns_observation(self, live_server: str) -> None:
        """reset() yields an Observation populated for both roles."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            observation = client.reset(seed=42, scenario="math_reasoning", difficulty="easy")
            assert isinstance(observation, Observation)

            scientist_view = observation.scientist
            assert scientist_view is not None
            assert scientist_view.paper_title

            lab_view = observation.lab_manager
            assert lab_view is not None
            assert lab_view.budget_total > 0

    def test_reset_sets_episode_id(self, live_server: str) -> None:
        """After reset() the client exposes an episode id."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            client.reset(seed=42)
            assert client.episode_id is not None

    def test_ws_session_id_is_none(self, live_server: str) -> None:
        """The WebSocket transport leaves ``session_id`` unset even after reset()."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            client.reset(seed=42)
            assert client.session_id is None


class TestWsStep:
    """Behaviour of step() on the WebSocket transport."""

    def test_step_returns_step_result(self, live_server: str) -> None:
        """A proposal step returns a non-terminal StepResult with an observation."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            initial = client.reset(seed=42)
            step_result = client.step(_propose_action(initial))
            assert isinstance(step_result, StepResult)
            assert step_result.done is False
            assert step_result.observation is not None

    def test_full_episode_propose_accept(self, live_server: str) -> None:
        """Propose then accept: the episode ends in a rewarded agreement."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            initial = client.reset(seed=42)

            after_proposal = client.step(_propose_action(initial))
            assert after_proposal.done is False

            after_accept = client.step(_accept_action())
            assert after_accept.done is True
            assert after_accept.reward > 0.0

            info = after_accept.info
            assert info.agreement_reached is True
            assert info.verdict == "accept"
            assert info.reward_breakdown is not None
            assert 0.0 <= info.reward_breakdown.rigor <= 1.0

    def test_semantic_invalid_action_step_ok_with_error(self, live_server: str) -> None:
        """An infeasible proposal (999-day duration) is reported via ``info.error``.

        The step is expected to return normally with ``done`` still False
        rather than raising.
        """
        with ReplicaLabClient(live_server, transport="websocket") as client:
            client.reset(seed=42)
            infeasible = ScientistAction(
                action_type="propose_protocol",
                sample_size=5,
                controls=["baseline"],
                technique="some technique",
                duration_days=999,
                required_equipment=[],
                required_reagents=[],
                questions=[],
                rationale="Duration is impossibly long.",
            )
            outcome = client.step(infeasible)
            assert outcome.done is False

            error_text = outcome.info.error
            assert error_text is not None
            assert "Validation errors" in error_text


class TestWsContextManager:
    """Using the WebSocket client as a context manager disconnects it on exit."""

    def test_context_manager_closes(self, live_server: str) -> None:
        """Inside the ``with`` block the client is connected; afterwards it is not."""
        ws_client = ReplicaLabClient(live_server, transport="websocket")
        with ws_client:
            assert ws_client.connected
            ws_client.reset(seed=1)
        assert not ws_client.connected


class TestWsUnsupported:
    """The WebSocket transport rejects state() and replay() with NotImplementedError."""

    def test_state_not_supported(self, live_server: str) -> None:
        """state() is unavailable over WebSocket, even during an active episode."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            client.reset(seed=42)
            with pytest.raises(NotImplementedError):
                client.state()

    def test_replay_not_supported(self, live_server: str) -> None:
        """replay() is unavailable over WebSocket; no reset() is needed to see this."""
        with ReplicaLabClient(live_server, transport="websocket") as client:
            placeholder_id = "some-id"
            with pytest.raises(NotImplementedError):
                client.replay(placeholder_id)


# ---------------------------------------------------------------------------
# Constructor validation
# ---------------------------------------------------------------------------


class TestConstructor:
    """How the constructor selects and validates the transport."""

    def test_unknown_transport_raises(self) -> None:
        """An unrecognised transport name is rejected at construction time."""
        with pytest.raises(ValueError, match="Unknown transport"):
            ReplicaLabClient(transport="grpc")

    def test_not_connected_raises_on_reset(self) -> None:
        """reset() on a client that was never connected raises RuntimeError."""
        idle_client = ReplicaLabClient(transport="rest")
        with pytest.raises(RuntimeError, match="not connected"):
            idle_client.reset(seed=1)

    def test_default_transport_is_websocket(self) -> None:
        """With no ``transport`` argument the client uses the WebSocket backend."""
        transport_cls = type(ReplicaLabClient()._transport)
        # Compared by name because the transport class is private to the client module.
        assert transport_cls.__name__ == "_WsTransport"