File size: 9,060 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""Tests for replicalab.training.rollout — TRN 03.

Verifies that RolloutWorker can run full episodes through the client,
collect trajectories, and surface judge output for RL training.
"""

from __future__ import annotations

import threading
import time

import pytest
import uvicorn

from replicalab.agents import build_baseline_scientist_action
from replicalab.client import ReplicaLabClient
from replicalab.models import RewardBreakdown, ScientistAction, ScientistObservation
from replicalab.training.rollout import EpisodeRecord, RolloutWorker, StepRecord


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

_TEST_PORT = 18766


@pytest.fixture(scope="module")
def live_server():
    """Start a live uvicorn server for rollout tests.

    Serves ``server.app:app`` on ``127.0.0.1:_TEST_PORT`` from a daemon thread,
    polls ``/health`` (up to 50 attempts, 0.1 s apart) until it answers 200,
    and yields the base URL. Once the module's tests finish, it asks the server
    to exit and waits up to 5 s for the thread.

    Yields:
        The server's base URL, e.g. ``"http://127.0.0.1:18766"``.
    """
    from server.app import app

    import httpx

    base_url = f"http://127.0.0.1:{_TEST_PORT}"
    config = uvicorn.Config(app, host="127.0.0.1", port=_TEST_PORT, log_level="error")
    server = uvicorn.Server(config)
    thread = threading.Thread(target=server.run, daemon=True)
    thread.start()

    started = False
    thread_died = False
    for _ in range(50):
        if not thread.is_alive():
            # The server thread stopped before /health answered, so startup
            # failed. Further polling would only waste time, or worse, reach
            # some other process bound to the same port.
            thread_died = True
            break
        try:
            resp = httpx.get(f"{base_url}/health", timeout=1.0)
        except httpx.HTTPError:
            # Not accepting connections yet (or timed out); retry shortly.
            pass
        else:
            if resp.status_code == 200:
                started = True
                break
        time.sleep(0.1)

    if not started:
        # Ask the (possibly half-started) server to stop before failing setup.
        server.should_exit = True
        thread.join(timeout=5)
        if thread_died:
            pytest.fail(
                f"Live server thread exited during startup "
                f"(is port {_TEST_PORT} already in use?)"
            )
        pytest.fail("Live server did not start in time")

    yield base_url

    server.should_exit = True
    thread.join(timeout=5)


@pytest.fixture()
def client(live_server: str):
    """Yield a ``ReplicaLabClient`` connected to the live server over REST.

    The connection is opened before the test runs and closed during teardown.
    """
    rest_client = ReplicaLabClient(live_server, transport="rest")
    rest_client.connect()
    yield rest_client
    rest_client.close()


# ---------------------------------------------------------------------------
# Full episode via baseline policy
# ---------------------------------------------------------------------------


class TestBaselineRollout:
    """Run real episodes with the deterministic baseline policy.

    Each test performs its own seed-42 rollout through the function-scoped
    REST ``client`` fixture and checks one property of the ``EpisodeRecord``.
    """

    @staticmethod
    def _baseline_record(client: ReplicaLabClient) -> EpisodeRecord:
        """Roll out one seed-42 episode with the baseline scientist policy."""
        worker = RolloutWorker(client)
        return worker.rollout(build_baseline_scientist_action, seed=42)

    def test_rollout_completes(self, client: ReplicaLabClient) -> None:
        """Baseline policy finishes an episode start-to-finish."""
        record = self._baseline_record(client)

        assert isinstance(record, EpisodeRecord)
        assert record.rounds_used > 0
        assert record.verdict is not None

    def test_rollout_returns_reward(self, client: ReplicaLabClient) -> None:
        """Terminal episode has a positive total reward and reaches agreement."""
        record = self._baseline_record(client)

        assert record.total_reward > 0.0
        assert record.agreement_reached is True
        assert record.succeeded is True

    def test_rollout_returns_reward_breakdown(
        self, client: ReplicaLabClient
    ) -> None:
        """Reward breakdown has rigor, feasibility, fidelity in [0, 1]."""
        record = self._baseline_record(client)

        rb = record.reward_breakdown
        assert rb is not None
        assert isinstance(rb, RewardBreakdown)
        assert 0.0 <= rb.rigor <= 1.0
        assert 0.0 <= rb.feasibility <= 1.0
        assert 0.0 <= rb.fidelity <= 1.0

    def test_rollout_returns_judge_notes(
        self, client: ReplicaLabClient
    ) -> None:
        """Judge notes are non-empty and the verdict is a known value."""
        record = self._baseline_record(client)

        assert record.judge_notes is not None
        assert len(record.judge_notes) > 0
        assert record.verdict in ("accept", "timeout", "no_agreement")

    def test_rollout_steps_have_observations(
        self, client: ReplicaLabClient
    ) -> None:
        """Each step record contains the scientist observation and action."""
        record = self._baseline_record(client)

        # Without this guard the loop below would pass vacuously when no
        # steps were recorded at all.
        assert record.steps, "rollout recorded no steps"
        for step in record.steps:
            assert isinstance(step, StepRecord)
            assert isinstance(step.observation, ScientistObservation)
            assert isinstance(step.action, ScientistAction)

    def test_rollout_episode_id_set(self, client: ReplicaLabClient) -> None:
        """Episode ID is captured from the client."""
        record = self._baseline_record(client)

        assert record.episode_id is not None
        assert len(record.episode_id) > 0


# ---------------------------------------------------------------------------
# Determinism and configuration
# ---------------------------------------------------------------------------


class TestRolloutConfig:
    """Configuration, determinism, and edge cases."""

    def test_rollout_is_deterministic(self, client: ReplicaLabClient) -> None:
        """Same seed → same reward, verdict, and round count."""
        worker = RolloutWorker(client)

        r1 = worker.rollout(build_baseline_scientist_action, seed=99)
        r2 = worker.rollout(build_baseline_scientist_action, seed=99)

        # Exact float equality is intended: a deterministic rollout must
        # reproduce the reward bit-for-bit.
        assert r1.total_reward == r2.total_reward
        assert r1.verdict == r2.verdict
        assert r1.rounds_used == r2.rounds_used

    def test_different_seeds_produce_different_episodes(
        self, client: ReplicaLabClient
    ) -> None:
        """Rollouts with different seeds get distinct episode IDs."""
        worker = RolloutWorker(client)

        r1 = worker.rollout(build_baseline_scientist_action, seed=1)
        r2 = worker.rollout(build_baseline_scientist_action, seed=2)

        assert r1.episode_id != r2.episode_id

    @pytest.mark.parametrize(
        "template", ["math_reasoning", "ml_benchmark", "finance_trading"]
    )
    def test_rollout_across_scenarios(
        self, client: ReplicaLabClient, template: str
    ) -> None:
        """Rollout works for each of the 3 scenario families at easy difficulty.

        Parametrized so a failure names the scenario family that broke.
        """
        worker = RolloutWorker(client)

        record = worker.rollout(
            build_baseline_scientist_action,
            seed=42,
            scenario=template,
            difficulty="easy",
        )

        assert record.rounds_used > 0
        assert record.verdict is not None

    def test_rollout_metadata_matches_input(
        self, client: ReplicaLabClient
    ) -> None:
        """EpisodeRecord captures the seed, scenario, and difficulty."""
        worker = RolloutWorker(client)
        record = worker.rollout(
            build_baseline_scientist_action,
            seed=77,
            scenario="finance_trading",
            difficulty="medium",
        )

        assert record.seed == 77
        assert record.scenario == "finance_trading"
        assert record.difficulty == "medium"

    def test_max_steps_cap(self, client: ReplicaLabClient) -> None:
        """max_steps prevents infinite loops even with a bad policy."""
        def _always_propose(obs: ScientistObservation) -> ScientistAction:
            # Ignores the observation and never accepts, so only the step
            # cap can end the episode early.
            return ScientistAction(
                action_type="propose_protocol",
                sample_size=5,
                controls=["baseline"],
                technique="method",
                duration_days=1,
                required_equipment=[],
                required_reagents=[],
                questions=[],
                rationale="Repeating proposal every round.",
            )

        worker = RolloutWorker(client, max_steps=3)
        record = worker.rollout(_always_propose, seed=42)

        assert record.rounds_used <= 3
        # max_steps bounds the worker's own loop, so check the recorded
        # steps directly as well.
        assert len(record.steps) <= 3


# ---------------------------------------------------------------------------
# Error path
# ---------------------------------------------------------------------------


class TestRolloutErrors:
    """Error surfacing from env through the rollout."""

    def test_validation_error_captured_in_step(
        self, client: ReplicaLabClient
    ) -> None:
        """A semantically invalid action's error is recorded on its step.

        The policy's first action uses an impossible duration. Every later
        call delegates to the baseline policy so the episode can finish. The
        environment's ``info.error`` for that first step must appear in
        ``record.steps[0].error``.
        """
        call_count = 0

        def _bad_then_accept(obs: ScientistObservation) -> ScientistAction:
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # First call: invalid duration
                return ScientistAction(
                    action_type="propose_protocol",
                    sample_size=5,
                    controls=["baseline"],
                    technique="method",
                    duration_days=999,
                    required_equipment=[],
                    required_reagents=[],
                    questions=[],
                    rationale="Duration is impossibly long.",
                )
            # After that: use baseline to finish
            return build_baseline_scientist_action(obs)

        worker = RolloutWorker(client)
        record = worker.rollout(_bad_then_accept, seed=42)

        # Fail with a clear message rather than an IndexError if nothing
        # was recorded.
        assert record.steps, "rollout recorded no steps"
        first_error = record.steps[0].error
        # First step should have captured the validation error
        assert first_error is not None
        assert "Validation" in first_error