File size: 3,415 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Additional rollout tests for TRN 04 metadata and collection behavior."""

from __future__ import annotations

from replicalab.models import Observation, RewardBreakdown, ScientistAction, StepInfo, StepResult
from replicalab.scenarios import generate_scenario
from replicalab.training.rollout import RolloutWorker


def _scientist_obs():
    """Return the scientist-side observation of a fixed scenario.

    The scenario is always generated from seed 7 with the ``math_reasoning``
    template at ``easy`` difficulty, so repeated calls build equivalent data.
    """
    scenario = generate_scenario(seed=7, template="math_reasoning", difficulty="easy")
    return scenario.scientist_observation


def _accept_action() -> ScientistAction:
    """Build a bare ``accept`` action whose payload fields are all empty or zero."""
    # Fresh empty lists on every call so no two actions share a list object.
    empty_collections = {
        field_name: []
        for field_name in (
            "controls",
            "required_equipment",
            "required_reagents",
            "questions",
        )
    }
    return ScientistAction(
        action_type="accept",
        sample_size=0,
        duration_days=0,
        technique="",
        rationale="",
        **empty_collections,
    )


class _FakeClient:
    """Scripted stand-in for the environment client driven by ``RolloutWorker``.

    After each ``reset`` the scripted episode is:

    1. first ``step``: non-terminal (reward 0.0, round 1, no agreement) with a
       single ``search_evidence`` tool trace;
    2. every later ``step``: terminal and agreed (reward 3.5, round 2, verdict
       ``"accept"``) with a reward breakdown, judge notes, and a single
       ``run_code_check`` tool trace.

    Every observation wraps ``_scientist_obs()`` with no lab-manager side.
    The arguments to ``reset`` and the action passed to ``step`` are ignored.
    """

    def __init__(self) -> None:
        self.episode_id = "episode-1"
        # Number of ``step`` calls made since the most recent ``reset``.
        self._step_count = 0

    def _observation(self) -> Observation:
        """Build the fixed observation returned by both ``reset`` and ``step``."""
        return Observation(scientist=_scientist_obs(), lab_manager=None)

    def reset(self, *, seed: int, scenario: str, difficulty: str) -> Observation:
        """Start a new scripted episode and return its initial observation.

        ``seed``, ``scenario`` and ``difficulty`` are accepted for interface
        compatibility only; they do not affect the scripted episode.
        """
        # Restart the script so a reused client (presumably one reset per seed
        # in ``collect_rollouts``) replays both steps. Without this, every
        # episode after the first would jump straight to the terminal step.
        self._step_count = 0
        return self._observation()

    def step(self, action: ScientistAction) -> StepResult:
        """Advance the scripted episode by one step; ``action`` is ignored."""
        self._step_count += 1
        if self._step_count == 1:
            return StepResult(
                observation=self._observation(),
                reward=0.0,
                done=False,
                info=StepInfo(
                    agreement_reached=False,
                    error=None,
                    round=1,
                    tool_traces=[
                        {
                            "tool": "search_evidence",
                            "status": "ok",
                            "query": "baseline reference",
                        }
                    ],
                ),
            )
        return StepResult(
            observation=self._observation(),
            reward=3.5,
            done=True,
            info=StepInfo(
                agreement_reached=True,
                error=None,
                reward_breakdown=RewardBreakdown(
                    rigor=0.8,
                    feasibility=0.7,
                    fidelity=0.75,
                ),
                judge_notes="Deterministic terminal note.",
                verdict="accept",
                round=2,
                tool_traces=[
                    {
                        "tool": "run_code_check",
                        "status": "ok",
                        "task_type": "metric_check",
                    }
                ],
            ),
        )


def test_rollout_captures_tool_traces_from_step_info_extras() -> None:
    """A rollout keeps the tool traces attached to each step's info."""

    def always_accept(_obs):
        return _accept_action()

    rollout_worker = RolloutWorker(_FakeClient())
    result = rollout_worker.rollout(always_accept, seed=11)

    # One trace per scripted step: search on step 1, code check on step 2.
    assert result.tool_trace_count == 2
    assert result.steps[0].tool_traces[0]["tool"] == "search_evidence"
    assert result.steps[1].tool_traces[0]["tool"] == "run_code_check"
    assert result.terminal_info is not None


def test_collect_rollouts_returns_one_record_per_seed() -> None:
    """``collect_rollouts`` yields one record per seed, in the order given."""

    def always_accept(_obs):
        return _accept_action()

    rollout_worker = RolloutWorker(_FakeClient())
    collected = rollout_worker.collect_rollouts(
        always_accept,
        seeds=[1, 2, 3],
        scenario="math_reasoning",
        difficulty="easy",
    )

    observed_seeds = [rec.seed for rec in collected]
    assert observed_seeds == [1, 2, 3]
    assert all(rec.verdict == "accept" for rec in collected)