File size: 4,583 Bytes
5c3cfae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db03c40
 
 
 
 
 
 
 
 
5c3cfae
 
 
 
 
 
 
 
 
 
 
 
db03c40
5c3cfae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Tests for GRPO training helpers."""

from pathlib import Path

from models import ActionType
from training_script import (
    INVALID_ACTION_PENALTY,
    OpenEnvReward,
    available_numeric_log_keys,
    build_prompt_examples,
    completion_to_text,
    parse_action_completion,
    save_training_plots,
    select_metric_key,
    select_reward_key,
)


def test_completion_to_text_from_chat_messages():
    """Chat-format completions collapse to the assistant message's raw content."""
    chat_completion = [
        {"role": "assistant", "content": '{"action_type":"collect_sample"}'}
    ]
    extracted = completion_to_text(chat_completion)
    assert extracted == '{"action_type":"collect_sample"}'


def test_parse_action_completion_roundtrip():
    """A fully-populated action JSON parses into an action with matching fields."""
    payload = (
        '{"action_type":"run_qc","method":"scanpy.pp.calculate_qc_metrics",'
        '"parameters":{"min_genes":200},"confidence":0.8}'
    )
    parsed = parse_action_completion(payload)
    assert parsed is not None
    # Verify every field survived the JSON round-trip.
    assert parsed.action_type == ActionType.RUN_QC
    assert parsed.method == "scanpy.pp.calculate_qc_metrics"
    assert parsed.parameters["min_genes"] == 200
    assert parsed.confidence == 0.8


def test_parse_action_completion_accepts_reasoning_alias():
    """The "reasoning" key is accepted as an alias for the justification field."""
    payload = '{"action_type":"run_qc","reasoning":"Measure quality before filtering."}'
    parsed = parse_action_completion(payload)
    assert parsed is not None
    assert parsed.justification == "Measure quality before filtering."


def test_parse_action_completion_normalizes_run_agent_aliases():
    """Alias action-type names are normalized onto canonical ActionType members."""
    payload = '{"action_type":"network_inference","method":"pySCENIC"}'
    parsed = parse_action_completion(payload)
    assert parsed is not None
    # "network_inference" should map to the canonical enum member.
    assert parsed.action_type == ActionType.REGULATORY_NETWORK_INFERENCE
    assert parsed.method == "pySCENIC"


def test_build_prompt_examples_contains_reference_action():
    """Generated prompt examples carry the scenario name and serialized reference actions."""
    prompt_examples = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=2,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )
    # One example per rollout step.
    assert len(prompt_examples) == 2
    first_step, second_step = prompt_examples
    assert first_step["scenario_name"] == "cardiac_disease_de"
    assert '"action_type": "collect_sample"' in first_step["reference_action"]
    assert '"action_type": "select_cohort"' in second_step["reference_action"]


def test_openenv_reward_penalizes_invalid_completion():
    """Unparseable completions are scored with the invalid-action penalty."""
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    bad_completion = [{"role": "assistant", "content": "not valid json"}]
    scores = scorer(
        completions=[bad_completion],
        scenario_name=["cardiac_disease_de"],
        history_actions=["[]"],
    )
    assert scores == [INVALID_ACTION_PENALTY]


def test_openenv_reward_scores_valid_completion_locally():
    """A dataset reference action earns a strictly positive reward from the local backend."""
    example = build_prompt_examples(
        dataset_episodes=1,
        rollout_steps=1,
        collection_policy="heuristic",
        scenario_names=["cardiac_disease_de"],
        seed=0,
        domain_randomise=False,
    )[0]
    scorer = OpenEnvReward(
        reward_backend="local",
        base_url="http://localhost:8000",
    )
    reference_completion = [
        {"role": "assistant", "content": example["reference_action"]}
    ]
    scores = scorer(
        completions=[reference_completion],
        scenario_name=[example["scenario_name"]],
        history_actions=[example["history_actions"]],
    )
    assert len(scores) == 1
    assert scores[0] > 0.0


def test_log_key_selection_prefers_reward_and_metric_keys():
    """Key discovery lists numeric columns and selection picks reward vs. metric keys."""
    history = [
        {"step": 1, "loss": 1.2, "rewards/open_env_reward": 0.4, "objective/kl": 0.05},
        {"step": 2, "loss": 1.0, "rewards/open_env_reward": 0.6, "objective/kl": 0.04},
    ]
    expected_keys = [
        "loss",
        "objective/kl",
        "rewards/open_env_reward",
    ]
    assert available_numeric_log_keys(history) == expected_keys
    chosen_reward = select_reward_key(history)
    assert chosen_reward == "rewards/open_env_reward"
    # The metric key must differ from the already-selected reward key.
    assert select_metric_key(history, reward_key=chosen_reward) == "objective/kl"


def test_save_training_plots_writes_expected_files(tmp_path):
    """Plot saving emits one file per plot kind plus a JSON manifest."""
    history = [
        {"step": 1, "loss": 1.2, "reward": 0.4, "grad_norm": 0.8},
        {"step": 2, "loss": 0.9, "reward": 0.7, "grad_norm": 0.5},
    ]
    produced = save_training_plots(history, tmp_path, metric_key="grad_norm")

    assert set(produced) == {"loss", "reward", "metric", "dashboard"}
    # Every reported path must exist on disk.
    for reported_path in produced.values():
        assert Path(reported_path).exists()

    assert (tmp_path / "training_plot_manifest.json").exists()