Spaces:
Running
Running
File size: 3,415 Bytes
80d8c84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | """Additional rollout tests for TRN 04 metadata and collection behavior."""
from __future__ import annotations
from replicalab.models import Observation, RewardBreakdown, ScientistAction, StepInfo, StepResult
from replicalab.scenarios import generate_scenario
from replicalab.training.rollout import RolloutWorker
def _scientist_obs():
    """Return the scientist-side observation of a fixed test scenario.

    The scenario is always requested with seed 7, the ``math_reasoning``
    template and ``easy`` difficulty. This presumes ``generate_scenario`` is
    deterministic for a given seed, in which case every call yields an
    equivalent observation.
    """
    scenario = generate_scenario(
        seed=7,
        template="math_reasoning",
        difficulty="easy",
    )
    return scenario.scientist_observation
def _accept_action() -> ScientistAction:
    """Build a ScientistAction whose ``action_type`` is ``"accept"``.

    Every other field is filled with a zero or empty placeholder. Fresh list
    objects are created on each call, so callers never share mutable state.
    """
    placeholder_fields = dict(
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=[],
        rationale="",
    )
    return ScientistAction(action_type="accept", **placeholder_fields)
class _FakeClient:
    """Scripted stand-in for the environment client driven by RolloutWorker.

    Every episode lasts exactly two ``step`` calls:

    1. A non-terminal step with reward 0.0, no agreement, and a single
       ``search_evidence`` tool trace.
    2. A terminal step with reward 3.5, agreement reached, an ``accept``
       verdict, a reward breakdown, and a single ``run_code_check`` tool trace.

    The action passed to ``step`` is ignored. Every observation is built from
    ``_scientist_obs()`` with no lab-manager view.
    """

    def __init__(self) -> None:
        self.episode_id = "episode-1"
        # Number of step() calls made since the most recent reset().
        self._step_count = 0

    def reset(self, *, seed: int, scenario: str, difficulty: str) -> Observation:
        """Begin a new scripted episode and return its initial observation.

        ``seed``, ``scenario`` and ``difficulty`` are accepted only for
        interface compatibility. They do not influence the script.
        """
        # Restart the script. Without this, a client reused across several
        # episodes (e.g. one per seed in collect_rollouts) would skip the
        # non-terminal first step on every episode after the first.
        self._step_count = 0
        return self._observation()

    def step(self, action: ScientistAction) -> StepResult:
        """Advance the two-step script by one step; ``action`` is ignored."""
        self._step_count += 1
        if self._step_count == 1:
            return StepResult(
                observation=self._observation(),
                reward=0.0,
                done=False,
                info=StepInfo(
                    agreement_reached=False,
                    error=None,
                    round=1,
                    tool_traces=[
                        {
                            "tool": "search_evidence",
                            "status": "ok",
                            "query": "baseline reference",
                        }
                    ],
                ),
            )
        # Any later step within the episode is the terminal, agreed outcome.
        return StepResult(
            observation=self._observation(),
            reward=3.5,
            done=True,
            info=StepInfo(
                agreement_reached=True,
                error=None,
                reward_breakdown=RewardBreakdown(
                    rigor=0.8,
                    feasibility=0.7,
                    fidelity=0.75,
                ),
                judge_notes="Deterministic terminal note.",
                verdict="accept",
                round=2,
                tool_traces=[
                    {
                        "tool": "run_code_check",
                        "status": "ok",
                        "task_type": "metric_check",
                    }
                ],
            ),
        )

    @staticmethod
    def _observation() -> Observation:
        """Wrap the fixed scientist observation; the lab-manager view is None."""
        return Observation(scientist=_scientist_obs(), lab_manager=None)
def test_rollout_captures_tool_traces_from_step_info_extras() -> None:
    """A single rollout must keep the tool traces reported in each StepInfo."""

    def always_accept(_obs):
        return _accept_action()

    record = RolloutWorker(_FakeClient()).rollout(always_accept, seed=11)

    # One trace per scripted step: search on step 1, code check on step 2.
    assert record.tool_trace_count == 2
    steps = record.steps
    assert steps[0].tool_traces[0]["tool"] == "search_evidence"
    assert steps[1].tool_traces[0]["tool"] == "run_code_check"
    assert record.terminal_info is not None
def test_collect_rollouts_returns_one_record_per_seed() -> None:
    """collect_rollouts must produce one record per seed, in seed order."""

    def always_accept(_obs):
        return _accept_action()

    worker = RolloutWorker(_FakeClient())
    records = worker.collect_rollouts(
        always_accept,
        seeds=[1, 2, 3],
        scenario="math_reasoning",
        difficulty="easy",
    )

    observed_seeds = [rec.seed for rec in records]
    assert observed_seeds == [1, 2, 3]
    assert all(rec.verdict == "accept" for rec in records)
|