"""Tests for GRPO training helpers."""
from pathlib import Path
from models import ActionType
from training_script import (
INVALID_ACTION_PENALTY,
OpenEnvReward,
available_numeric_log_keys,
build_prompt_examples,
completion_to_text,
parse_action_completion,
save_training_plots,
select_metric_key,
select_reward_key,
)
def test_completion_to_text_from_chat_messages():
    """A chat-message completion yields the assistant message's content string."""
    expected_text = '{"action_type":"collect_sample"}'
    chat_messages = [{"role": "assistant", "content": expected_text}]

    assert completion_to_text(chat_messages) == expected_text
def test_parse_action_completion_roundtrip():
    """A JSON completion with every field set parses into a matching action."""
    raw_completion = (
        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
        '"parameters":{"min_genes":200},"confidence":0.8}'
    )

    parsed = parse_action_completion(raw_completion)

    # Guard first so a failed parse is reported as such, not as an AttributeError.
    assert parsed is not None
    assert parsed.action_type == ActionType.RUN_QC
    assert parsed.method == "scanpy.pp.calculate_qc_metrics"
    assert parsed.parameters["min_genes"] == 200
    assert parsed.confidence == 0.8
def test_parse_action_completion_accepts_reasoning_alias():
    """The ``reasoning`` key in a completion is exposed as ``justification``."""
    raw_completion = (
        '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
    )

    parsed = parse_action_completion(raw_completion)

    assert parsed is not None
    assert parsed.justification == "Measure quality before filtering."
def test_build_prompt_examples_contains_reference_action():
    """One episode of two rollout steps gives two examples carrying a reference action."""
    builder_options = {
        "dataset_episodes": 1,
        "rollout_steps": 2,
        "collection_policy": "heuristic",
        "scenario_names": ["cardiac_disease_de"],
        "seed": 0,
        "domain_randomise": False,
    }

    prompt_examples = build_prompt_examples(**builder_options)

    # Check the count before indexing so an empty result fails as an assertion.
    assert len(prompt_examples) == 2
    first_example = prompt_examples[0]
    assert first_example["scenario_name"] == "cardiac_disease_de"
    assert '"action_type": "collect_sample"' in first_example["reference_action"]
def test_openenv_reward_penalizes_invalid_completion():
    """A completion that is not valid JSON is scored with the invalid-action penalty."""
    scorer = OpenEnvReward(reward_backend="local", base_url="http://localhost:8000")
    malformed_completion = [{"role": "assistant", "content": "not valid json"}]

    scores = scorer(
        completions=[malformed_completion],
        scenario_name=["cardiac_disease_de"],
        history_actions=["[]"],
    )

    assert scores == [INVALID_ACTION_PENALTY]
def test_openenv_reward_scores_valid_completion_locally():
    """Replaying an example's reference action through the local backend scores above zero."""
    builder_options = {
        "dataset_episodes": 1,
        "rollout_steps": 1,
        "collection_policy": "heuristic",
        "scenario_names": ["cardiac_disease_de"],
        "seed": 0,
        "domain_randomise": False,
    }
    prompt_examples = build_prompt_examples(**builder_options)
    scorer = OpenEnvReward(reward_backend="local", base_url="http://localhost:8000")

    # Feed the generated reference action back in as the model's completion.
    chosen = prompt_examples[0]
    reference_completion = [
        {"role": "assistant", "content": chosen["reference_action"]}
    ]
    scores = scorer(
        completions=[reference_completion],
        scenario_name=[chosen["scenario_name"]],
        history_actions=[chosen["history_actions"]],
    )

    assert len(scores) == 1
    assert scores[0] > 0.0
def test_log_key_selection_prefers_reward_and_metric_keys():
    """Key selection picks the reward series and the KL series from a two-step log."""
    first_step = {
        "step": 1,
        "loss": 1.2,
        "rewards/open_env_reward": 0.4,
        "objective/kl": 0.05,
    }
    second_step = {
        "step": 2,
        "loss": 1.0,
        "rewards/open_env_reward": 0.6,
        "objective/kl": 0.04,
    }
    log_history = [first_step, second_step]

    # "step" is absent from the expected numeric keys.
    expected_keys = ["loss", "objective/kl", "rewards/open_env_reward"]
    assert available_numeric_log_keys(log_history) == expected_keys

    chosen_reward_key = select_reward_key(log_history)
    assert chosen_reward_key == "rewards/open_env_reward"

    chosen_metric_key = select_metric_key(log_history, reward_key=chosen_reward_key)
    assert chosen_metric_key == "objective/kl"
def test_save_training_plots_writes_expected_files(tmp_path):
    """Saving plots returns four named paths that exist, plus a manifest file on disk."""
    log_history = [
        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
    ]

    written = save_training_plots(log_history, tmp_path, metric_key="grad_norm")

    expected_names = {"loss", "reward", "metric", "dashboard"}
    assert set(written) == expected_names

    # Every reported path must point at a file that was actually created.
    for location in written.values():
        assert Path(location).exists()

    assert (tmp_path / "training_plot_manifest.json").exists()
|