Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Tests for GRPO training helpers.""" | |
| from pathlib import Path | |
| from models import ActionType | |
| from training_script import ( | |
| INVALID_ACTION_PENALTY, | |
| OpenEnvReward, | |
| available_numeric_log_keys, | |
| build_prompt_examples, | |
| completion_to_text, | |
| parse_action_completion, | |
| save_training_plots, | |
| select_metric_key, | |
| select_reward_key, | |
| ) | |
def test_completion_to_text_from_chat_messages():
    """A chat-format completion collapses to the assistant message content."""
    chat_turns = [
        {"role": "assistant", "content": '{"action_type":"collect_sample"}'}
    ]
    extracted = completion_to_text(chat_turns)
    assert extracted == '{"action_type":"collect_sample"}'
def test_parse_action_completion_roundtrip():
    """A fully-specified JSON action parses into all of its typed fields."""
    raw_payload = (
        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
        '"parameters":{"min_genes":200},"confidence":0.8}'
    )
    parsed = parse_action_completion(raw_payload)
    assert parsed is not None
    # Every field from the JSON payload survives the round trip.
    assert parsed.action_type == ActionType.RUN_QC
    assert parsed.method == "scanpy.pp.calculate_qc_metrics"
    assert parsed.parameters["min_genes"] == 200
    assert parsed.confidence == 0.8
def test_parse_action_completion_accepts_reasoning_alias():
    """The "reasoning" JSON key is accepted as an alias for justification."""
    parsed = parse_action_completion(
        '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
    )
    assert parsed is not None
    assert parsed.justification == "Measure quality before filtering."
def test_parse_action_completion_normalizes_run_agent_aliases():
    """Alias action-type names normalise to their canonical enum members."""
    parsed = parse_action_completion(
        '{"action_type":"network_inference","method":"pySCENIC"}'
    )
    assert parsed is not None
    # "network_inference" is mapped onto the canonical enum value.
    assert parsed.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
    assert parsed.method == "pySCENIC"
def test_build_prompt_examples_contains_reference_action():
    """One episode of two rollout steps yields two examples with reference actions."""
    prompt_examples = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=2,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )
    assert len(prompt_examples) == 2
    first_step, second_step = prompt_examples
    assert first_step["scenario_name"] == "cardiac_disease_de"
    # The heuristic policy starts with sample collection, then cohort selection.
    assert '"action_type": "collect_sample"' in first_step["reference_action"]
    assert '"action_type": "select_cohort"' in second_step["reference_action"]
def test_openenv_reward_penalizes_invalid_completion():
    """A completion that is not valid JSON earns the fixed invalid-action penalty."""
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    scores = scorer(
        completions=[[{"role": "assistant", "content": "not valid json"}]],
        scenario_name=["cardiac_disease_de"],
        history_actions=["[]"],
    )
    assert scores == [INVALID_ACTION_PENALTY]
def test_openenv_reward_scores_valid_completion_locally():
    """A generated reference action receives a strictly positive local reward."""
    sample = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=1,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )[0]
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    # Feed the reference action back as if the model produced it verbatim.
    scores = scorer(
        completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
        scenario_name=[sample["scenario_name"]],
        history_actions=[sample["history_actions"]],
    )
    assert len(scores) == 1
    assert scores[0] > 0.0
def test_log_key_selection_prefers_reward_and_metric_keys():
    """Reward-key selection prefers rewards/*; metric selection skips that key."""
    history = [
        {"step": 1, "loss": 1.2, "rewards/open_env_reward": 0.4, "objective/kl": 0.05},
        {"step": 2, "loss": 1.0, "rewards/open_env_reward": 0.6, "objective/kl": 0.04},
    ]
    expected_keys = [
        "loss",
        "objective/kl",
        "rewards/open_env_reward",
    ]
    assert available_numeric_log_keys(history) == expected_keys
    chosen_reward = select_reward_key(history)
    assert chosen_reward == "rewards/open_env_reward"
    # The metric key must differ from the already-chosen reward key.
    assert select_metric_key(history, reward_key=chosen_reward) == "objective/kl"
def test_save_training_plots_writes_expected_files(tmp_path):
    """Plot export writes loss/reward/metric/dashboard images plus a manifest."""
    history = [
        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
    ]
    outputs = save_training_plots(history, tmp_path, metric_key="grad_norm")
    assert set(outputs) == {"loss", "reward", "metric", "dashboard"}
    # Every advertised plot path must actually exist on disk.
    assert all(Path(p).exists() for p in outputs.values())
    assert (tmp_path / "training_plot_manifest.json").exists()