# sakha/tests/test_inference_parser.py
# Last commit by atharva-again: f38de94 (unverified)
# "test(inference): add prompt profile and rank_candidates tests"
import json
from inference import (
MAX_TOKENS,
build_fallback_action,
build_user_prompt,
extract_model_decision,
rank_candidates,
select_action,
)
from sakha.env import SakhaEnvironment
def test_extract_model_decision_parses_json_block():
    """A bare JSON object is decoded and exposes the chosen action and patient."""
    raw_output = '{"eligible_actions": [{"id": "A1", "patient_id": 1, "eligible": true}], "chosen_action_id": "A1", "chosen_patient_id": 1}'
    decision = extract_model_decision(raw_output)
    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1
def test_extract_model_decision_returns_none_for_malformed_output():
    """Free text containing no JSON produces no decision."""
    assert extract_model_decision("not valid json output") is None
def test_extract_model_decision_parses_fenced_json():
    """JSON inside a ```json fence, after leading prose, is still extracted."""
    fenced_output = 'Reasoning text\n```json\n{"chosen_action_id":"A1","chosen_patient_id":1}\n```'
    decision = extract_model_decision(fenced_output)
    assert decision is not None
    assert decision["chosen_action_id"] == "A1"
    assert decision["chosen_patient_id"] == 1
def test_extract_model_decision_returns_none_for_truncated_json():
    """An object cut off mid-value (e.g. by a token limit) yields no decision."""
    truncated = '{"chosen_action_id":"A1","chosen_patient_id":'
    assert extract_model_decision(truncated) is None
def test_select_action_falls_back_to_pending_queue_head():
    """Without a model payload, select_action returns build_fallback_action's choice."""

    def action_kind(act):
        # action_type is presumably an enum; compare its raw value when it has one.
        return getattr(act.action_type, "value", act.action_type)

    env = SakhaEnvironment(patient_count=18, task="hard")
    observation = env.reset(seed=42)
    fallback = build_fallback_action(observation)
    chosen, _ = select_action(observation, model_payload=None)
    assert action_kind(chosen) == action_kind(fallback)
    assert chosen.patient_id == fallback.patient_id
def test_select_action_falls_back_on_invalid_patient():
    """A payload naming a nonexistent patient is replaced by the fallback action.

    Both the action type and the target patient must match
    ``build_fallback_action``. Checking only the type would let a result that
    still targets the bogus patient pass.
    """
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    expected = build_fallback_action(obs)
    action, _scratchpad = select_action(
        obs,
        # Patient 999 is presumably out of range for a 5-patient ward.
        model_payload={"chosen_action_id": "A3", "chosen_patient_id": 999},
    )
    assert getattr(action.action_type, "value", action.action_type) == getattr(
        expected.action_type, "value", expected.action_type
    )
    assert action.patient_id == expected.patient_id
def test_select_action_extracts_scratchpad():
    """A valid payload for the head pending task is honored and its scratchpad returned."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    head_task = obs.ward_state.pending_tasks[0]
    # Model-facing action ids and the action names they map to.
    action_names = {
        "A0": "review_patient",
        "A1": "medication_round",
        "A2": "check_vitals",
        "A3": "escalate",
        "A4": "alert_doctor",
        "A5": "update_chart",
        "A6": "prepare_discharge",
        "A7": "ward_sweep",
    }
    # Linear `==` scan, not a reversed-dict lookup: if required_action is a
    # str-valued Enum it compares equal to its value, but Enum hashes by name.
    action_id = next(
        (key for key, value in action_names.items() if value == head_task.required_action),
        None,
    )
    # Without a default, an unmapped action would surface as a bare StopIteration.
    assert action_id is not None, (
        f"no action id maps to required_action {head_task.required_action!r}"
    )
    action, scratchpad = select_action(
        obs,
        model_payload={
            "chosen_action_id": action_id,
            "chosen_patient_id": head_task.patient_id,
            "scratchpad": "Round bed first",
        },
    )
    assert scratchpad == "Round bed first"
    assert getattr(action.action_type, "value", action.action_type) == head_task.required_action
def test_select_action_handles_missing_scratchpad():
    """A payload without a "scratchpad" key yields a None scratchpad."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    head_task = obs.ward_state.pending_tasks[0]
    payload = {"chosen_action_id": "A0", "chosen_patient_id": head_task.patient_id}
    action, scratchpad = select_action(obs, model_payload=payload)
    assert scratchpad is None
    assert action.patient_id == head_task.patient_id
def _extract_state_json(prompt: str) -> dict:
lines = prompt.splitlines()
assert lines[0] == "STATE JSON:"
return json.loads(lines[1])
def test_build_user_prompt_operational_profile_payload_shape():
    """operational_realism exposes a pending count and compact per-patient keys."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    payload = _extract_state_json(
        build_user_prompt(
            obs,
            step=1,
            history=["A0(patient=1)"],
            prompt_profile="operational_realism",
        )
    )

    # The pending count is present; the raw pending_tasks list is not.
    assert "pending_count" in payload
    assert "pending_tasks" not in payload
    assert "action_result" in payload

    patients = payload["patients"]
    assert isinstance(patients, list)
    assert patients

    compact_patient_keys = {
        "bed", "diag",
        "meds", "med_step",
        "vitals", "vit_step",
        "esc",
        "rev", "rev_step",
        "inc",
        "discharge",
        "adm_req", "adm_reviewed", "adm_documented",
        "last_vitals",
    }
    assert set(patients[0].keys()) == compact_patient_keys
def test_build_user_prompt_strict_bedside_excludes_pending_signals():
    """strict_bedside hides both the pending count and the pending task list."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    prompt = build_user_prompt(obs, step=1, history=[], prompt_profile="strict_bedside")
    payload = _extract_state_json(prompt)
    for hidden_key in ("pending_count", "pending_tasks"):
        assert hidden_key not in payload
def test_build_user_prompt_full_legacy_includes_pending_tasks():
    """full_legacy keeps the full pending-task queue and the verbose bed_id key."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    prompt = build_user_prompt(obs, step=1, history=[], prompt_profile="full_legacy")
    payload = _extract_state_json(prompt)

    assert "pending_count" in payload
    assert "pending_tasks" in payload
    serialized_tasks = payload["pending_tasks"]
    assert isinstance(serialized_tasks, list)
    # Every pending task in the observation is serialized.
    assert len(serialized_tasks) == len(obs.ward_state.pending_tasks)

    # Legacy patient records use "bed_id" rather than the compact "bed".
    first_patient = payload["patients"][0]
    assert "bed_id" in first_patient
    assert "bed" not in first_patient
def test_build_user_prompt_invalid_profile_fails_fast():
    """An unknown prompt_profile must raise ValueError."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    try:
        build_user_prompt(
            obs,
            step=1,
            history=[],
            prompt_profile="invalid_profile",
        )
    except ValueError:
        pass
    else:
        # Reached only when no exception was raised at all.
        raise AssertionError("Expected ValueError for invalid profile")
def test_max_tokens_target_reduced():
    """Pin the completion token budget so changes to it are deliberate."""
    expected_budget = 64
    assert MAX_TOKENS == expected_budget
def test_rank_candidates_independent_of_pending_tasks():
    """Ranking must not change when the pending task queue is emptied."""
    env = SakhaEnvironment(patient_count=8, task="medium")
    obs = env.reset(seed=42)
    baseline = rank_candidates(obs)

    # Remove the queue entirely and rank again from patient state alone.
    obs.ward_state.pending_tasks = []
    without_queue = rank_candidates(obs)

    assert without_queue
    assert baseline == without_queue
def test_rank_candidates_prioritizes_incident_workflow_order():
    """An active incident walks check (A2) -> alert (A4) -> escalate (A3) -> document (A5)."""
    env = SakhaEnvironment(patient_count=8, task="medium")
    obs = env.reset(seed=42)
    patient = obs.ward_state.patients[0]
    patient.active_incident_id = 123
    patient.incident_deadline_step = 4
    patient.incident_checked = False
    patient.incident_alerted = False
    patient.incident_escalated = False
    patient.incident_documented = False

    # Each stage sets one more workflow flag (None = initial state) and
    # expects the next action id at the top of the ranking.
    stages = [
        (None, "A2"),
        ("incident_checked", "A4"),
        ("incident_alerted", "A3"),
        ("incident_escalated", "A5"),
    ]
    for completed_flag, expected_action_id in stages:
        if completed_flag is not None:
            setattr(patient, completed_flag, True)
        best = rank_candidates(obs)[0]
        assert best[1] == expected_action_id
def test_rank_candidates_tie_breaks_by_due_then_bed_then_action_id():
    """Equal-priority review candidates are ordered by ascending bed id.

    Two patients get identical review state and reassigned bed ids
    (patients[1] -> bed 1, patients[0] -> bed 2), so only the bed tie-break
    can decide their relative order.
    """
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    first = obs.ward_state.patients[0]
    second = obs.ward_state.patients[1]
    first.review_required = True
    first.last_reviewed_step = -1
    first.bed_id = 2
    second.review_required = True
    second.last_reviewed_step = -1
    second.bed_id = 1
    ranked = rank_candidates(obs)
    # Keep only A0 candidates for the two beds under test, as the due-step
    # test does, so a review candidate for another patient cannot displace
    # them. Candidate tuples carry the action id at [1] and the bed id at [2].
    # NOTE(review): assumes no other patient already occupies bed 1 or 2.
    review_beds = [r[2] for r in ranked if r[1] == "A0" and r[2] in {1, 2}]
    assert review_beds[:2] == [1, 2]
def test_rank_candidates_orders_by_due_step_for_same_bucket():
    """Within the vitals bucket (A2), the earlier due step ranks first."""
    env = SakhaEnvironment(patient_count=5, task="easy")
    obs = env.reset(seed=42)
    late, early = obs.ward_state.patients[0], obs.ward_state.patients[1]
    late.vitals_due = True
    early.vitals_due = True
    late.vitals_due_by_step = 7
    early.vitals_due_by_step = 3

    ranked = rank_candidates(obs)
    beds_under_test = {late.bed_id, early.bed_id}
    vitals_entries = [
        entry for entry in ranked if entry[1] == "A2" and entry[2] in beds_under_test
    ]
    assert vitals_entries[0][2] == early.bed_id
    assert vitals_entries[1][2] == late.bed_id
def test_rank_candidates_is_stable_for_same_observation():
    """Ranking the same observation twice gives identical results."""
    env = SakhaEnvironment(patient_count=8, task="medium")
    obs = env.reset(seed=42)
    runs = [rank_candidates(obs) for _ in range(2)]
    assert runs[0] == runs[1]