# Source: bio-experiment/tests/test_training_script.py
# Uploaded by Ev3Dev via huggingface_hub ("Upload folder using huggingface_hub"),
# commit db03c40 (verified).
"""Tests for GRPO training helpers."""
from pathlib import Path
from models import ActionType
from training_script import (
INVALID_ACTION_PENALTY,
OpenEnvReward,
available_numeric_log_keys,
build_prompt_examples,
completion_to_text,
parse_action_completion,
save_training_plots,
select_metric_key,
select_reward_key,
)
def test_completion_to_text_from_chat_messages():
    """A single assistant chat message is flattened to its raw content string."""
    expected_text = '{"action_type":"collect_sample"}'
    chat_completion = [{"role": "assistant", "content": expected_text}]
    assert completion_to_text(chat_completion) == expected_text
def test_parse_action_completion_roundtrip():
    """Every field of a well-formed JSON action is carried onto the parsed action."""
    raw_completion = (
        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
        '"parameters":{"min_genes":200},"confidence":0.8}'
    )
    parsed = parse_action_completion(raw_completion)

    assert parsed is not None
    assert parsed.action_type == ActionType.RUN_QC
    assert parsed.method == "scanpy.pp.calculate_qc_metrics"
    assert parsed.parameters["min_genes"] == 200
    assert parsed.confidence == 0.8
def test_parse_action_completion_accepts_reasoning_alias():
    """A ``reasoning`` key in the completion populates ``justification``."""
    raw_completion = (
        '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
    )
    parsed = parse_action_completion(raw_completion)

    assert parsed is not None
    assert parsed.justification == "Measure quality before filtering."
def test_parse_action_completion_normalizes_run_agent_aliases():
    """The short ``network_inference`` name maps to the canonical action type.

    The ``method`` string is expected to pass through unchanged.
    """
    raw_completion = '{"action_type":"network_inference","method":"pySCENIC"}'
    parsed = parse_action_completion(raw_completion)

    assert parsed is not None
    assert parsed.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
    assert parsed.method == "pySCENIC"
def test_build_prompt_examples_contains_reference_action():
    """A two-step heuristic rollout yields one example per step.

    Each example carries the expected serialized reference action.
    """
    build_kwargs = dict(
        dataset_episodes=1,
        rollout_steps=2,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )
    examples = build_prompt_examples(**build_kwargs)

    assert len(examples) == 2
    first_step, second_step = examples[0], examples[1]
    assert first_step["scenario_name"] == "cardiac_disease_de"
    assert '"action_type": "collect_sample"' in first_step["reference_action"]
    assert '"action_type": "select_cohort"' in second_step["reference_action"]
def test_openenv_reward_penalizes_invalid_completion():
    """A completion that is not valid JSON scores exactly INVALID_ACTION_PENALTY."""
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    unparseable = [{"role": "assistant", "content": "not valid json"}]

    scores = scorer(
        completions=[unparseable],
        scenario_name=["cardiac_disease_de"],
        history_actions=["[]"],
    )

    assert scores == [INVALID_ACTION_PENALTY]
def test_openenv_reward_scores_valid_completion_locally():
    """Replaying a heuristic reference action through the local backend scores > 0."""
    first_example = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=1,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )[0]
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    # Feed the reference action back in as though the model had produced it.
    reference_message = {
        "role": "assistant",
        "content": first_example["reference_action"],
    }

    scores = scorer(
        completions=[[reference_message]],
        scenario_name=[first_example["scenario_name"]],
        history_actions=[first_example["history_actions"]],
    )

    assert len(scores) == 1
    assert scores[0] > 0.0
def test_log_key_selection_prefers_reward_and_metric_keys():
    """Numeric keys are listed sorted, excluding ``step``.

    The reward and metric selectors pick the expected entries.
    """
    history = [
        {"step": step, "loss": loss, "rewards/open_env_reward": reward, "objective/kl": kl}
        for step, loss, reward, kl in ((1, 1.2, 0.4, 0.05), (2, 1.0, 0.6, 0.04))
    ]

    expected_keys = ["loss", "objective/kl", "rewards/open_env_reward"]
    assert available_numeric_log_keys(history) == expected_keys

    chosen_reward = select_reward_key(history)
    assert chosen_reward == "rewards/open_env_reward"
    assert select_metric_key(history, reward_key=chosen_reward) == "objective/kl"
def test_save_training_plots_writes_expected_files(tmp_path):
    """Plotting writes loss/reward/metric/dashboard files plus a JSON manifest."""
    history = [
        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
    ]

    written = save_training_plots(history, tmp_path, metric_key="grad_norm")

    assert set(written) == {"loss", "reward", "metric", "dashboard"}
    assert all(Path(path).exists() for path in written.values())
    assert (tmp_path / "training_plot_manifest.json").exists()