File size: 8,666 Bytes
f38de94
 
 
 
 
 
 
 
 
 
c3fa67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f38de94
 
 
 
 
 
 
 
 
 
 
 
 
 
8056f5b
c3fa67d
 
8056f5b
 
 
 
 
 
c3fa67d
 
52b4770
c3fa67d
 
8056f5b
 
c3fa67d
 
 
8056f5b
 
 
52b4770
 
 
 
 
8056f5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52b4770
 
 
8056f5b
 
 
52b4770
 
8056f5b
 
52b4770
 
 
 
 
8056f5b
52b4770
 
8056f5b
52b4770
 
8056f5b
f38de94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import json

from inference import (
    MAX_TOKENS,
    build_fallback_action,
    build_user_prompt,
    extract_model_decision,
    rank_candidates,
    select_action,
)
from sakha.env import SakhaEnvironment


def test_extract_model_decision_parses_json_block():
    """A bare JSON object reply is parsed and its chosen action/patient are exposed."""
    raw_reply = (
        '{"eligible_actions": [{"id": "A1", "patient_id": 1, "eligible": true}], '
        '"chosen_action_id": "A1", "chosen_patient_id": 1}'
    )
    decision = extract_model_decision(raw_reply)

    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1


def test_extract_model_decision_returns_none_for_malformed_output():
    """Free-form text containing no JSON object yields no decision."""
    garbage_reply = "not valid json output"
    assert extract_model_decision(garbage_reply) is None


def test_extract_model_decision_parses_fenced_json():
    """A JSON object inside a ```json fence, preceded by prose, is still extracted."""
    fenced_reply = "\n".join(
        [
            "Reasoning text",
            "```json",
            '{"chosen_action_id":"A1","chosen_patient_id":1}',
            "```",
        ]
    )
    decision = extract_model_decision(fenced_reply)

    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1


def test_extract_model_decision_returns_none_for_truncated_json():
    """An object cut off mid-value is rejected rather than half-parsed."""
    truncated_reply = '{"chosen_action_id":"A1","chosen_patient_id":'
    assert extract_model_decision(truncated_reply) is None


def test_select_action_falls_back_to_pending_queue_head():
    """Without a model payload, select_action returns the same action as build_fallback_action."""
    ward = SakhaEnvironment(patient_count=18, task="hard")
    observation = ward.reset(seed=42)
    fallback = build_fallback_action(observation)

    chosen, _ = select_action(observation, model_payload=None)

    def type_value(kind):
        # Action types may be enum members or plain strings; compare underlying values.
        return getattr(kind, "value", kind)

    assert type_value(chosen.action_type) == type_value(fallback.action_type)
    assert chosen.patient_id == fallback.patient_id


def test_select_action_falls_back_on_invalid_patient():
    """A payload naming a non-existent patient is replaced by the fallback action.

    Both the action type and the target patient must match the fallback. Checking
    the type alone would let a regression that keeps the invalid patient id (999)
    pass unnoticed.
    """
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    expected = build_fallback_action(obs)
    action, _scratchpad = select_action(
        obs,
        model_payload={"chosen_action_id": "A3", "chosen_patient_id": 999},
    )
    # Action types may be enum members or plain strings; compare underlying values.
    assert getattr(action.action_type, "value", action.action_type) == getattr(
        expected.action_type, "value", expected.action_type
    )
    # Pin the target patient too, mirroring the no-payload fallback test.
    assert action.patient_id == expected.patient_id


def test_select_action_extracts_scratchpad():
    """A valid payload's ``scratchpad`` text is returned alongside the chosen action."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    head_task = obs.ward_state.pending_tasks[0]
    # Model-facing action ids and the action names they stand for.
    action_names_by_id = {
        "A0": "review_patient",
        "A1": "medication_round",
        "A2": "check_vitals",
        "A3": "escalate",
        "A4": "alert_doctor",
        "A5": "update_chart",
        "A6": "prepare_discharge",
        "A7": "ward_sweep",
    }
    action_id = next(
        (key for key, value in action_names_by_id.items() if value == head_task.required_action),
        None,
    )
    # Fail with a readable message instead of an opaque StopIteration from next().
    assert action_id is not None, (
        f"no action id mapped for required action {head_task.required_action!r}"
    )

    action, scratchpad = select_action(
        obs,
        model_payload={
            "chosen_action_id": action_id,
            "chosen_patient_id": head_task.patient_id,
            "scratchpad": "Round bed first",
        },
    )
    assert scratchpad == "Round bed first"
    assert getattr(action.action_type, "value", action.action_type) == head_task.required_action
    assert action.patient_id == head_task.patient_id


def test_select_action_handles_missing_scratchpad():
    """Omitting ``scratchpad`` from the payload yields ``None`` rather than an error."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)
    target_patient = observation.ward_state.pending_tasks[0].patient_id
    payload = {"chosen_action_id": "A0", "chosen_patient_id": target_patient}

    chosen, notes = select_action(observation, model_payload=payload)

    assert notes is None
    assert chosen.patient_id == target_patient


def _extract_state_json(prompt: str) -> dict:
    lines = prompt.splitlines()
    assert lines[0] == "STATE JSON:"
    return json.loads(lines[1])


def test_build_user_prompt_operational_profile_payload_shape():
    """The operational profile exposes a pending count and compact per-patient keys."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)

    state = _extract_state_json(
        build_user_prompt(
            observation,
            step=1,
            history=["A0(patient=1)"],
            prompt_profile="operational_realism",
        )
    )

    # Only the aggregate count is present; the task list itself is withheld.
    assert "pending_count" in state
    assert "pending_tasks" not in state
    assert "action_result" in state

    patients = state["patients"]
    assert isinstance(patients, list)
    assert patients

    compact_keys = {
        "bed", "diag",
        "meds", "med_step",
        "vitals", "vit_step",
        "esc",
        "rev", "rev_step",
        "inc",
        "discharge",
        "adm_req", "adm_reviewed", "adm_documented",
        "last_vitals",
    }
    assert set(patients[0].keys()) == compact_keys


def test_build_user_prompt_strict_bedside_excludes_pending_signals():
    """The strict bedside profile omits both the pending count and the pending task list."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)

    state = _extract_state_json(
        build_user_prompt(observation, step=1, history=[], prompt_profile="strict_bedside")
    )

    for hidden_key in ("pending_count", "pending_tasks"):
        assert hidden_key not in state


def test_build_user_prompt_full_legacy_includes_pending_tasks():
    """The legacy profile exposes the full pending-task list and the long patient field names."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)

    state = _extract_state_json(
        build_user_prompt(observation, step=1, history=[], prompt_profile="full_legacy")
    )

    assert "pending_count" in state
    assert "pending_tasks" in state
    listed_tasks = state["pending_tasks"]
    assert isinstance(listed_tasks, list)
    assert len(listed_tasks) == len(observation.ward_state.pending_tasks)

    first_patient = state["patients"][0]
    # Legacy output keeps the full ``bed_id`` name instead of the compact ``bed`` alias.
    assert "bed_id" in first_patient
    assert "bed" not in first_patient


def test_build_user_prompt_invalid_profile_fails_fast():
    """An unknown prompt profile is rejected with ValueError instead of being ignored."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)

    try:
        build_user_prompt(
            observation,
            step=1,
            history=[],
            prompt_profile="invalid_profile",
        )
    except ValueError:
        return
    raise AssertionError("Expected ValueError for invalid profile")


def test_max_tokens_target_reduced():
    """The completion token budget is pinned to the reduced target of 64."""
    expected_budget = 64
    assert MAX_TOKENS == expected_budget


def test_rank_candidates_independent_of_pending_tasks():
    """Emptying the pending-task queue leaves the candidate ranking unchanged."""
    ward = SakhaEnvironment(patient_count=8, task="medium")
    observation = ward.reset(seed=42)

    baseline = rank_candidates(observation)
    # The ranking is expected not to read the queue, so clearing it must be a no-op.
    observation.ward_state.pending_tasks = []
    without_queue = rank_candidates(observation)

    assert without_queue
    assert baseline == without_queue


def test_rank_candidates_prioritizes_incident_workflow_order():
    """An open incident is worked in order: check (A2), alert (A4), escalate (A3), chart (A5)."""
    ward = SakhaEnvironment(patient_count=8, task="medium")
    observation = ward.reset(seed=42)

    patient = observation.ward_state.patients[0]
    patient.active_incident_id = 123
    patient.incident_deadline_step = 4
    for stage_flag in (
        "incident_checked",
        "incident_alerted",
        "incident_escalated",
        "incident_documented",
    ):
        setattr(patient, stage_flag, False)

    # Completing each stage should promote the next workflow action to the top.
    workflow = (
        (None, "A2"),
        ("incident_checked", "A4"),
        ("incident_alerted", "A3"),
        ("incident_escalated", "A5"),
    )
    for completed_flag, expected_top_action in workflow:
        if completed_flag is not None:
            setattr(patient, completed_flag, True)
        assert rank_candidates(observation)[0][1] == expected_top_action


def test_rank_candidates_tie_breaks_by_due_then_bed_then_action_id():
    """Two equally-due review candidates are ordered by ascending bed number.

    Only the review actions for the two patients adjusted here are inspected, as
    in ``test_rank_candidates_orders_by_due_step_for_same_bucket``. Review work
    other patients may need after reset therefore cannot shift the indices
    checked below.
    """
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)

    first = obs.ward_state.patients[0]
    second = obs.ward_state.patients[1]

    # Same review urgency for both; swap bed numbers so list order != bed order.
    first.review_required = True
    first.last_reviewed_step = -1
    first.bed_id = 2

    second.review_required = True
    second.last_reviewed_step = -1
    second.bed_id = 1

    ranked = rank_candidates(obs)
    adjusted_beds = {first.bed_id, second.bed_id}
    review_actions = [r for r in ranked if r[1] == "A0" and r[2] in adjusted_beds]
    assert review_actions[0][2] == 1
    assert review_actions[1][2] == 2


def test_rank_candidates_orders_by_due_step_for_same_bucket():
    """Within the vitals bucket, the patient due sooner is ranked ahead."""
    ward = SakhaEnvironment(patient_count=5, task="easy")
    observation = ward.reset(seed=42)

    due_later = observation.ward_state.patients[0]
    due_sooner = observation.ward_state.patients[1]
    for patient, due_step in ((due_later, 7), (due_sooner, 3)):
        patient.vitals_due = True
        patient.vitals_due_by_step = due_step

    watched_beds = {due_later.bed_id, due_sooner.bed_id}
    vitals_beds = [
        entry[2]
        for entry in rank_candidates(observation)
        if entry[1] == "A2" and entry[2] in watched_beds
    ]
    assert vitals_beds[0] == due_sooner.bed_id
    assert vitals_beds[1] == due_later.bed_id


def test_rank_candidates_is_stable_for_same_observation():
    """Ranking the same unmodified observation twice gives identical output."""
    ward = SakhaEnvironment(patient_count=8, task="medium")
    observation = ward.reset(seed=42)
    runs = [rank_candidates(observation) for _ in range(2)]
    assert runs[0] == runs[1]