File size: 9,432 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""Contract tests for judge audit payloads (TST 11).

Verifies that terminal StepInfo and EpisodeLog contain all required
audit fields with correct types and semantics.
"""

from __future__ import annotations

import json

import pytest
from fastapi.testclient import TestClient

from replicalab.models import EpisodeLog, RewardBreakdown, StepInfo
from server.app import app


@pytest.fixture()
def client():
    return TestClient(app)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _good_action_payload(client: TestClient) -> dict:
    from replicalab.scenarios import generate_scenario

    scenario = generate_scenario(seed=42, template="math_reasoning", difficulty="easy")
    lab = scenario.lab_manager_observation
    spec = scenario.hidden_reference_spec
    return {
        "action_type": "propose_protocol",
        "sample_size": 10,
        "controls": ["baseline", "ablation"],
        "technique": spec.summary[:60] if spec.summary else "replication_plan",
        "duration_days": max(1, min(2, lab.time_limit_days)),
        "required_equipment": list(lab.equipment_available[:1]) if lab.equipment_available else [],
        "required_reagents": list(lab.reagents_in_stock[:1]) if lab.reagents_in_stock else [],
        "questions": [],
        "rationale": (
            f"Plan addresses: {', '.join(spec.required_elements[:2])}. "
            f"Target metric: {spec.target_metric}. "
            f"Target value: {spec.target_value}. "
            "Stay within budget and schedule."
        ),
    }


def _accept_action() -> dict:
    return {
        "action_type": "accept",
        "sample_size": 0,
        "controls": [],
        "technique": "",
        "duration_days": 0,
        "required_equipment": [],
        "required_reagents": [],
        "questions": [],
        "rationale": "",
    }


def _run_to_terminal(client: TestClient, *, accept: bool = True) -> tuple[str, dict]:
    """Reset, propose, and optionally accept. Returns (episode_id, terminal_response_json)."""
    reset = client.post("/reset", json={"seed": 42, "scenario": "math_reasoning", "difficulty": "easy"}).json()
    session_id = reset["session_id"]
    episode_id = reset["episode_id"]

    action = _good_action_payload(client)
    client.post("/step", json={"session_id": session_id, "action": action})

    if accept:
        resp = client.post("/step", json={"session_id": session_id, "action": _accept_action()})
    else:
        # Run to timeout
        for _ in range(10):
            resp = client.post("/step", json={"session_id": session_id, "action": action})
            if resp.json()["done"]:
                break

    return episode_id, resp.json()


# ---------------------------------------------------------------------------
# StepInfo audit field contracts
# ---------------------------------------------------------------------------


class TestStepInfoAuditContract:
    """Terminal StepInfo must contain all audit fields."""

    def test_terminal_info_has_verdict(self, client: TestClient) -> None:
        _, data = _run_to_terminal(client)
        assert data["info"]["verdict"] in ("accept", "revise", "timeout")

    def test_terminal_info_has_judge_notes(self, client: TestClient) -> None:
        _, data = _run_to_terminal(client)
        assert isinstance(data["info"]["judge_notes"], str)
        assert len(data["info"]["judge_notes"]) > 0

    def test_terminal_info_has_reward_breakdown(self, client: TestClient) -> None:
        _, data = _run_to_terminal(client)
        rb = data["info"]["reward_breakdown"]
        assert rb is not None
        for key in ("rigor", "feasibility", "fidelity", "parsimony"):
            assert key in rb
            assert 0.0 <= rb[key] <= 1.0

    def test_terminal_info_has_top_failure_reasons(self, client: TestClient) -> None:
        _, data = _run_to_terminal(client)
        reasons = data["info"]["top_failure_reasons"]
        assert isinstance(reasons, list)

    def test_terminal_info_has_agreement_reached(self, client: TestClient) -> None:
        _, data = _run_to_terminal(client)
        assert isinstance(data["info"]["agreement_reached"], bool)
        assert data["info"]["agreement_reached"] is True

    def test_non_terminal_info_has_no_verdict(self, client: TestClient) -> None:
        reset = client.post("/reset", json={"seed": 42}).json()
        action = _good_action_payload(client)
        resp = client.post("/step", json={"session_id": reset["session_id"], "action": action})
        data = resp.json()
        assert data["done"] is False
        assert data["info"]["verdict"] is None
        assert data["info"]["reward_breakdown"] is None


# ---------------------------------------------------------------------------
# EpisodeLog audit field contracts
# ---------------------------------------------------------------------------


class TestEpisodeLogAuditContract:
    """GET /replay/{episode_id} must include full audit metadata."""

    def test_replay_has_verdict_and_judge_notes(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        assert replay["verdict"] in ("accept", "revise", "timeout")
        assert isinstance(replay["judge_notes"], str)
        assert len(replay["judge_notes"]) > 0

    def test_replay_has_reward_breakdown(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        rb = replay["reward_breakdown"]
        assert rb is not None
        for key in ("rigor", "feasibility", "fidelity", "parsimony"):
            assert key in rb

    def test_replay_has_top_failure_reasons(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        assert isinstance(replay["top_failure_reasons"], list)

    def test_replay_has_transcript(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        assert isinstance(replay["transcript"], list)
        assert len(replay["transcript"]) > 0

    def test_replay_has_invalid_action_fields(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        assert "invalid_action_count" in replay
        assert "invalid_action_rate" in replay
        assert isinstance(replay["invalid_action_count"], int)
        assert isinstance(replay["invalid_action_rate"], (int, float))
        assert replay["invalid_action_count"] >= 0
        assert 0.0 <= replay["invalid_action_rate"] <= 1.0

    def test_replay_has_total_reward(self, client: TestClient) -> None:
        episode_id, _ = _run_to_terminal(client)
        replay = client.get(f"/replay/{episode_id}").json()
        assert isinstance(replay["total_reward"], (int, float))


# ---------------------------------------------------------------------------
# Pydantic model contracts (unit-level)
# ---------------------------------------------------------------------------


class TestAuditModelContracts:
    """Pydantic model round-trip for audit fields."""

    def test_step_info_default_audit_fields(self) -> None:
        info = StepInfo()
        assert info.agreement_reached is False
        assert info.verdict is None
        assert info.judge_notes is None
        assert info.reward_breakdown is None
        assert info.top_failure_reasons == []

    def test_step_info_with_audit_fields(self) -> None:
        info = StepInfo(
            agreement_reached=True,
            verdict="accept",
            judge_notes="All clear.",
            reward_breakdown=RewardBreakdown(rigor=0.9, feasibility=0.8, fidelity=0.7),
            top_failure_reasons=["minor issue"],
        )
        assert info.verdict == "accept"
        assert info.reward_breakdown.rigor == 0.9
        assert info.top_failure_reasons == ["minor issue"]

    def test_episode_log_invalid_action_fields_default(self) -> None:
        log = EpisodeLog()
        assert log.invalid_action_count == 0
        assert log.invalid_action_rate == 0.0

    def test_episode_log_invalid_action_fields_set(self) -> None:
        log = EpisodeLog(invalid_action_count=3, invalid_action_rate=0.25)
        assert log.invalid_action_count == 3
        assert log.invalid_action_rate == 0.25

    def test_episode_log_json_round_trip_preserves_audit(self) -> None:
        log = EpisodeLog(
            episode_id="audit-test",
            verdict="accept",
            judge_notes="Solid work.",
            top_failure_reasons=["none"],
            invalid_action_count=1,
            invalid_action_rate=0.1,
            reward_breakdown=RewardBreakdown(
                rigor=0.9,
                feasibility=0.85,
                fidelity=0.8,
                penalties={"stalling": 0.05},
            ),
        )
        raw = log.model_dump_json()
        restored = EpisodeLog.model_validate_json(raw)
        assert restored.verdict == "accept"
        assert restored.invalid_action_count == 1
        assert restored.invalid_action_rate == 0.1
        assert restored.reward_breakdown.penalties == {"stalling": 0.05}