"""Tests for the inference helpers.

Covers model-decision parsing, action selection and its fallback path,
user-prompt construction per prompt profile, and candidate ranking.
"""

import json

from inference import (
    MAX_TOKENS,
    build_fallback_action,
    build_user_prompt,
    extract_model_decision,
    rank_candidates,
    select_action,
)
from sakha.env import SakhaEnvironment

# Seed handed to every ``env.reset`` call in this module.
_SEED = 42

# Action ids paired with the action names these tests expect them to denote.
_ACTION_NAME_BY_ID = {
    "A0": "review_patient",
    "A1": "medication_round",
    "A2": "check_vitals",
    "A3": "escalate",
    "A4": "alert_doctor",
    "A5": "update_chart",
    "A6": "prepare_discharge",
    "A7": "ward_sweep",
}


def _reset_env(patient_count: int, task: str):
    """Create a ``SakhaEnvironment`` and reset it with the module seed.

    Returns ``(env, observation)`` so the caller keeps the environment
    referenced for as long as the observation is in use.
    """
    env = SakhaEnvironment(patient_count=patient_count, task=task)
    return env, env.reset(seed=_SEED)


def _action_name(action):
    """Return ``action.action_type.value`` if it exists, else ``action_type``."""
    kind = action.action_type
    return getattr(kind, "value", kind)


def _extract_state_json(prompt: str) -> dict:
    """Decode the JSON found on the line after the ``STATE JSON:`` header."""
    prompt_lines = prompt.splitlines()
    assert prompt_lines[0] == "STATE JSON:"
    return json.loads(prompt_lines[1])


def _prompt_state(obs, *, history: list, prompt_profile: str) -> dict:
    """Build the step-1 user prompt for ``obs`` and return its decoded state."""
    prompt = build_user_prompt(
        obs,
        step=1,
        history=history,
        prompt_profile=prompt_profile,
    )
    return _extract_state_json(prompt)


def test_extract_model_decision_parses_json_block():
    """A bare JSON object is parsed and exposes the chosen action/patient."""
    raw_output = '{"eligible_actions": [{"id": "A1", "patient_id": 1, "eligible": true}], "chosen_action_id": "A1", "chosen_patient_id": 1}'

    decision = extract_model_decision(raw_output)

    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1


def test_extract_model_decision_returns_none_for_malformed_output():
    """Text that is not JSON yields ``None``."""
    decision = extract_model_decision("not valid json output")

    assert decision is None


def test_extract_model_decision_parses_fenced_json():
    """JSON inside a fenced ``json`` block after free text is still parsed."""
    raw_output = 'Reasoning text\n```json\n{"chosen_action_id":"A1","chosen_patient_id":1}\n```'

    decision = extract_model_decision(raw_output)

    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1


def test_extract_model_decision_returns_none_for_truncated_json():
    """A JSON object cut off mid-value yields ``None``."""
    decision = extract_model_decision('{"chosen_action_id":"A1","chosen_patient_id":')

    assert decision is None


def test_select_action_falls_back_to_pending_queue_head():
    """With no model payload, the selection equals ``build_fallback_action``."""
    _env, obs = _reset_env(18, "hard")
    fallback = build_fallback_action(obs)

    chosen, _scratchpad = select_action(obs, model_payload=None)

    assert _action_name(chosen) == _action_name(fallback)
    assert chosen.patient_id == fallback.patient_id


def test_select_action_falls_back_on_invalid_patient():
    """A payload naming patient 999 yields the fallback's action type."""
    _env, obs = _reset_env(5, "easy")
    fallback = build_fallback_action(obs)

    chosen, _scratchpad = select_action(
        obs,
        model_payload={"chosen_action_id": "A3", "chosen_patient_id": 999},
    )

    assert _action_name(chosen) == _action_name(fallback)


def test_select_action_extracts_scratchpad():
    """The payload's ``scratchpad`` text is returned next to the action."""
    _env, obs = _reset_env(5, "easy")
    head_task = obs.ward_state.pending_tasks[0]
    # Scan with ``==`` rather than a keyed lookup so the comparison against
    # ``required_action`` is exactly the equality check it has always been.
    matching_id = next(
        action_id
        for action_id, action_name in _ACTION_NAME_BY_ID.items()
        if action_name == head_task.required_action
    )
    model_payload = {
        "chosen_action_id": matching_id,
        "chosen_patient_id": head_task.patient_id,
        "scratchpad": "Round bed first",
    }

    chosen, scratchpad = select_action(obs, model_payload=model_payload)

    assert scratchpad == "Round bed first"
    assert _action_name(chosen) == head_task.required_action


def test_select_action_handles_missing_scratchpad():
    """A payload without ``scratchpad`` gives ``None`` for the scratchpad."""
    _env, obs = _reset_env(5, "easy")
    head_task = obs.ward_state.pending_tasks[0]

    chosen, scratchpad = select_action(
        obs,
        model_payload={"chosen_action_id": "A0", "chosen_patient_id": head_task.patient_id},
    )

    assert scratchpad is None
    assert chosen.patient_id == head_task.patient_id


def test_build_user_prompt_operational_profile_payload_shape():
    """``operational_realism`` exposes counts and compact per-patient keys."""
    _env, obs = _reset_env(5, "easy")

    state = _prompt_state(
        obs,
        history=["A0(patient=1)"],
        prompt_profile="operational_realism",
    )

    assert "pending_count" in state
    assert "pending_tasks" not in state
    assert "action_result" in state
    assert isinstance(state["patients"], list)
    assert state["patients"]

    first_patient = state["patients"][0]
    expected_keys = {
        "bed",
        "diag",
        "meds",
        "med_step",
        "vitals",
        "vit_step",
        "esc",
        "rev",
        "rev_step",
        "inc",
        "discharge",
        "adm_req",
        "adm_reviewed",
        "adm_documented",
        "last_vitals",
    }
    assert set(first_patient.keys()) == expected_keys


def test_build_user_prompt_strict_bedside_excludes_pending_signals():
    """``strict_bedside`` carries neither the pending count nor the tasks."""
    _env, obs = _reset_env(5, "easy")

    state = _prompt_state(obs, history=[], prompt_profile="strict_bedside")

    assert "pending_count" not in state
    assert "pending_tasks" not in state


def test_build_user_prompt_full_legacy_includes_pending_tasks():
    """``full_legacy`` lists every pending task and uses ``bed_id`` keys."""
    _env, obs = _reset_env(5, "easy")

    state = _prompt_state(obs, history=[], prompt_profile="full_legacy")

    assert "pending_count" in state
    assert "pending_tasks" in state
    assert isinstance(state["pending_tasks"], list)
    assert len(state["pending_tasks"]) == len(obs.ward_state.pending_tasks)

    first_patient = state["patients"][0]
    assert "bed_id" in first_patient
    assert "bed" not in first_patient


def test_build_user_prompt_invalid_profile_fails_fast():
    """An unknown prompt profile must raise ``ValueError``."""
    _env, obs = _reset_env(5, "easy")

    try:
        build_user_prompt(
            obs,
            step=1,
            history=[],
            prompt_profile="invalid_profile",
        )
    except ValueError:
        return
    raise AssertionError("Expected ValueError for invalid profile")


def test_max_tokens_target_reduced():
    """``MAX_TOKENS`` is pinned at 64."""
    assert MAX_TOKENS == 64


def test_rank_candidates_independent_of_pending_tasks():
    """Emptying ``pending_tasks`` leaves the ranking unchanged."""
    _env, obs = _reset_env(8, "medium")

    with_queue = rank_candidates(obs)
    obs.ward_state.pending_tasks = []
    without_queue = rank_candidates(obs)

    assert without_queue
    assert with_queue == without_queue


def test_rank_candidates_prioritizes_incident_workflow_order():
    """The top candidate follows A2, A4, A3, A5 as incident flags are set."""
    _env, obs = _reset_env(8, "medium")
    patient = obs.ward_state.patients[0]
    patient.active_incident_id = 123
    patient.incident_deadline_step = 4
    for flag in (
        "incident_checked",
        "incident_alerted",
        "incident_escalated",
        "incident_documented",
    ):
        setattr(patient, flag, False)

    # Each row: the flag to set before ranking (None for the untouched
    # incident) and the action id expected at the top of the ranking.
    progression = (
        (None, "A2"),
        ("incident_checked", "A4"),
        ("incident_alerted", "A3"),
        ("incident_escalated", "A5"),
    )
    for completed_flag, expected_action_id in progression:
        if completed_flag is not None:
            setattr(patient, completed_flag, True)
        top = rank_candidates(obs)[0]
        assert top[1] == expected_action_id


def test_rank_candidates_tie_breaks_by_due_then_bed_then_action_id():
    """Two equally pending reviews come out ordered by bed id (1 before 2)."""
    _env, obs = _reset_env(5, "easy")
    patients = obs.ward_state.patients

    # Swap the bed ids so list position and bed order disagree.
    for patient, bed_id in ((patients[0], 2), (patients[1], 1)):
        patient.review_required = True
        patient.last_reviewed_step = -1
        patient.bed_id = bed_id

    review_entries = [entry for entry in rank_candidates(obs) if entry[1] == "A0"]

    assert review_entries[0][2] == 1
    assert review_entries[1][2] == 2


def test_rank_candidates_orders_by_due_step_for_same_bucket():
    """Among A2 candidates, the earlier vitals due step ranks first."""
    _env, obs = _reset_env(5, "easy")
    later_due = obs.ward_state.patients[0]
    sooner_due = obs.ward_state.patients[1]
    later_due.vitals_due = True
    sooner_due.vitals_due = True
    later_due.vitals_due_by_step = 7
    sooner_due.vitals_due_by_step = 3

    watched_beds = {later_due.bed_id, sooner_due.bed_id}
    vitals_entries = [
        entry
        for entry in rank_candidates(obs)
        if entry[1] == "A2" and entry[2] in watched_beds
    ]

    assert vitals_entries[0][2] == sooner_due.bed_id
    assert vitals_entries[1][2] == later_due.bed_id


def test_rank_candidates_is_stable_for_same_observation():
    """Ranking the same observation twice gives equal results."""
    _env, obs = _reset_env(8, "medium")

    first_pass = rank_candidates(obs)
    second_pass = rank_candidates(obs)

    assert first_pass == second_pass