narada-env / scripts /test_env.py
Jin2413's picture
upgrades for openenv validation
b654948
Raw
History Blame Contribute Delete
2.68 kB
"""Quick smoke test for the Narada environment.

Exercises a monogenic episode (reset, hop, backtrack, flag_causal, terminal
state reveal) and an oligogenic episode that needs several correct flags.
Assertions abort the script on the first failure; on success the last line
printed is "ENV SMOKE TEST PASSED".

NOTE(review): every check runs at module import time. Because the filename
matches ``test_*.py``, pytest collection would also execute it — confirm
that is intended.
"""
import sys
from pathlib import Path

# Resolve ``src/envs`` relative to this file (assumed to live in
# ``<repo>/scripts/``) rather than the current working directory, so the
# import works no matter where the script is launched from.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src" / "envs"))

from narada.server.environment import NaradaEnvironment  # noqa: E402
from narada.models import NaradaAction  # noqa: E402
# --- Monogenic episode: reset and make sure nothing about the answer leaks ---
env = NaradaEnvironment()
result = env.reset(seed=1, task_type="monogenic")
obs = result.observation

print(f"Reset: task={obs.task_type}")
print(f" Node: {obs.current_node.name} ({obs.current_node.id})")
print(f" Phenotypes: {obs.phenotype_names}")
print(f" Candidates: {[v.gene for v in obs.candidate_variants]}")

# The reset reward must lie strictly between the two bounds.
assert result.reward > 0.01 and result.reward < 0.99, f"Reset reward {result.reward} must be in (0.01, 0.99)"

# Before the episode is terminal, neither the disease label nor the
# ground-truth variants may be exposed.
assert "disease_name" not in obs.info, "Disease should not leak before terminal state"
pre_terminal_truth = env.state().ground_truth_variants
assert pre_terminal_truth == [], "Ground truth should not leak before done"
# Hop to the first neighbour of the current node, if it has any.
if obs.current_node.connected_node_ids:
    hop_target = obs.current_node.connected_node_ids[0]
    action = NaradaAction(action_type="hop", node_id=hop_target, reasoning="Exploring")
    result = env.step(action)
    obs = result.observation
    print(f"After hop to {hop_target}: reward={result.reward:.4f}, node={obs.current_node.name}")

# Backtrack. Keep ``obs`` in sync with the latest step (as the hop branch
# does) so the checks that follow read the current observation, not a stale one.
action = NaradaAction(action_type="backtrack", reasoning="Testing backtrack")
result = env.step(action)
obs = result.observation
print(f"After backtrack: reward={result.reward:.4f}")
# Flag the first candidate as causal; the monogenic episode is expected to
# end on this single flag.
assert obs.candidate_variants, "Observation should expose at least one candidate variant"
v_id = obs.candidate_variants[0].id
action = NaradaAction(action_type="flag_causal", variant_id=v_id, reasoning="Test flag")
result = env.step(action)
print(f"After flag ({v_id}): done={result.done} reward={result.reward:.4f}")
assert result.done, "Episode should be done after flag_causal"
# Inclusive bounds here (unlike the strict reset check), so the message uses
# closed-interval notation to match what is actually tested.
assert 0.01 <= result.reward <= 0.99, f"Reward {result.reward} out of [0.01, 0.99]"

# Once the episode is terminal, the ground truth should be revealed.
state = env.state()
print(f"State: flagged={state.flagged_variants} ground_truth={state.ground_truth_variants[:2]}")
assert state.ground_truth_variants, "Ground truth should be revealed after terminal state"
# Oligogenic cases have several causal variants: the episode must stay open
# until every one of them has been flagged, and finishing should pay more
# than the first partial flag.
env = NaradaEnvironment()
env.reset(task_type="oligogenic", seed=2)
ground_truth = list(env._case.ground_truth_variant_ids)  # test-only access
assert len(ground_truth) >= 2, (
    f"Oligogenic case should have at least 2 causal variants, got {len(ground_truth)}"
)

flag_results = []
for index, variant_id in enumerate(ground_truth, start=1):
    step_result = env.step(
        NaradaAction(
            action_type="flag_causal",
            variant_id=variant_id,
            reasoning=f"Causal variant {index} of {len(ground_truth)}",
        )
    )
    if index < len(ground_truth):
        assert not step_result.done, (
            f"Oligogenic episode should continue after correct flag {index} of {len(ground_truth)}"
        )
    else:
        assert step_result.done, "Oligogenic episode should end after all causal variants are flagged"
    flag_results.append(step_result)

first, final = flag_results[0], flag_results[-1]
assert final.reward > first.reward, "Completing oligogenic diagnosis should improve reward"
print("ENV SMOKE TEST PASSED")