"""Tests for GRPO training helpers."""
from pathlib import Path
from models import ActionType
from training_script import (
INVALID_ACTION_PENALTY,
OpenEnvReward,
available_numeric_log_keys,
build_prompt_examples,
completion_to_text,
parse_action_completion,
save_training_plots,
select_metric_key,
select_reward_key,
)
def test_completion_to_text_from_chat_messages():
    """A chat-format completion collapses to its assistant message content."""
    chat_messages = [
        {"role": "assistant", "content": '{"action_type":"collect_sample"}'}
    ]
    extracted = completion_to_text(chat_messages)
    assert extracted == '{"action_type":"collect_sample"}'
def test_parse_action_completion_roundtrip():
    """A fully-populated JSON payload parses into a matching action object."""
    payload = (
        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
        '"parameters":{"min_genes":200},"confidence":0.8}'
    )
    parsed = parse_action_completion(payload)
    assert parsed is not None
    assert parsed.action_type == ActionType.RUN_QC
    assert parsed.method == "scanpy.pp.calculate_qc_metrics"
    assert parsed.parameters["min_genes"] == 200
    assert parsed.confidence == 0.8
def test_parse_action_completion_accepts_reasoning_alias():
    """The "reasoning" JSON key is accepted as an alias for justification."""
    parsed = parse_action_completion(
        '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
    )
    assert parsed is not None
    assert parsed.justification == "Measure quality before filtering."
def test_parse_action_completion_normalizes_run_agent_aliases():
    """Alias action-type strings are normalised to their canonical enum member."""
    parsed = parse_action_completion(
        '{"action_type":"network_inference","method":"pySCENIC"}'
    )
    assert parsed is not None
    assert parsed.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
    assert parsed.method == "pySCENIC"
def test_build_prompt_examples_contains_reference_action():
    """One episode of two rollout steps yields two examples with reference actions."""
    prompt_examples = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=2,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )
    assert len(prompt_examples) == 2
    first_example, second_example = prompt_examples
    assert first_example["scenario_name"] == "cardiac_disease_de"
    assert '"action_type": "collect_sample"' in first_example["reference_action"]
    assert '"action_type": "select_cohort"' in second_example["reference_action"]
def test_openenv_reward_penalizes_invalid_completion():
    """Completions that are not parseable JSON receive the invalid-action penalty."""
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    scores = scorer(
        completions=[[{"role": "assistant", "content": "not valid json"}]],
        scenario_name=["cardiac_disease_de"],
        history_actions=["[]"],
    )
    assert scores == [INVALID_ACTION_PENALTY]
def test_openenv_reward_scores_valid_completion_locally():
    """A reference action scored by the local backend yields a positive reward."""
    prompt_examples = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=1,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    example = prompt_examples[0]
    scores = scorer(
        completions=[[{"role": "assistant", "content": example["reference_action"]}]],
        scenario_name=[example["scenario_name"]],
        history_actions=[example["history_actions"]],
    )
    assert len(scores) == 1
    assert scores[0] > 0.0
def test_log_key_selection_prefers_reward_and_metric_keys():
    """Numeric key discovery, reward-key and metric-key selection all agree."""
    history = [
        {"step": 1, "loss": 1.2, "rewards/open_env_reward": 0.4, "objective/kl": 0.05},
        {"step": 2, "loss": 1.0, "rewards/open_env_reward": 0.6, "objective/kl": 0.04},
    ]
    expected_keys = [
        "loss",
        "objective/kl",
        "rewards/open_env_reward",
    ]
    assert available_numeric_log_keys(history) == expected_keys
    chosen_reward_key = select_reward_key(history)
    assert chosen_reward_key == "rewards/open_env_reward"
    assert select_metric_key(history, reward_key=chosen_reward_key) == "objective/kl"
def test_save_training_plots_writes_expected_files(tmp_path):
    """Plot generation writes one file per plot plus a JSON manifest."""
    history = [
        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
    ]
    produced = save_training_plots(history, tmp_path, metric_key="grad_norm")
    assert set(produced) == {"loss", "reward", "metric", "dashboard"}
    for output_path in produced.values():
        assert Path(output_path).exists()
    assert (tmp_path / "training_plot_manifest.json").exists()