# Spaces: Running (page-status residue captured along with this file; not part of the script)
"""Quick smoke test for the environment.

Runs two short episodes against ``NaradaEnvironment`` and asserts on the
results:

1. Monogenic episode (seed 1): reset, hop to the first connected node and
   backtrack (only when the start node has connections), then flag the first
   candidate variant and check that the episode terminates.
2. Oligogenic episode (seed 2): flag two ground-truth variants in turn and
   check that the episode ends only after the second flag.

Prints ``ENV SMOKE TEST PASSED`` when every check succeeds; any failed check
raises ``AssertionError`` and aborts the script.
"""
import sys

# The path is relative to the current working directory, so the script has to
# be launched from the directory that contains ``src``.
sys.path.insert(0, "src/envs")

from narada.server.environment import NaradaEnvironment
from narada.models import NaradaAction

# --- Monogenic episode -------------------------------------------------------
env = NaradaEnvironment()
result = env.reset(task_type="monogenic", seed=1)
obs = result.observation
print(f"Reset: task={obs.task_type}")
print(f" Node: {obs.current_node.name} ({obs.current_node.id})")
print(f" Phenotypes: {obs.phenotype_names}")
print(f" Candidates: {[v.gene for v in obs.candidate_variants]}")
assert 0.01 < result.reward < 0.99, f"Reset reward {result.reward} must be in (0.01, 0.99)"

# Nothing about the answer may be visible while the episode is still running.
assert "disease_name" not in obs.info, "Disease should not leak before terminal state"
assert env.state().ground_truth_variants == [], "Ground truth should not leak before done"

# Hop to first neighbor, then backtrack. Both steps are skipped when the start
# node has no connections, since there is nothing to hop to or return from.
if obs.current_node.connected_node_ids:
    hop_target = obs.current_node.connected_node_ids[0]
    action = NaradaAction(action_type="hop", node_id=hop_target, reasoning="Exploring")
    result = env.step(action)
    obs = result.observation
    print(f"After hop to {hop_target}: reward={result.reward:.4f}, node={obs.current_node.name}")

    # Backtrack
    action = NaradaAction(action_type="backtrack", reasoning="Testing backtrack")
    result = env.step(action)
    print(f"After backtrack: reward={result.reward:.4f}")

# Flag a candidate. Check the list first so an empty case fails with a clear
# message instead of a bare IndexError.
assert obs.candidate_variants, "Monogenic case should expose at least one candidate variant"
v_id = obs.candidate_variants[0].id
action = NaradaAction(action_type="flag_causal", variant_id=v_id, reasoning="Test flag")
result = env.step(action)
print(f"After flag ({v_id}): done={result.done} reward={result.reward:.4f}")
assert result.done, "Episode should be done after flag_causal"
# The bounds are inclusive here, and the message now says so.
assert 0.01 <= result.reward <= 0.99, f"Reward {result.reward} out of [0.01, 0.99]"

# State check
state = env.state()
print(f"State: flagged={state.flagged_variants} ground_truth={state.ground_truth_variants[:2]}")
assert state.ground_truth_variants, "Ground truth should be revealed after terminal state"

# --- Oligogenic episode ------------------------------------------------------
# Oligogenic should allow multiple correct flags before termination.
env = NaradaEnvironment()
env.reset(task_type="oligogenic", seed=2)
ground_truth = env._case.ground_truth_variant_ids  # test-only access
# Two variants are flagged below, so require at least two up front.
assert len(ground_truth) >= 2, "Oligogenic case should have at least two causal variants"
first = env.step(
    NaradaAction(
        action_type="flag_causal",
        variant_id=ground_truth[0],
        reasoning="First causal variant",
    )
)
assert not first.done, "Oligogenic episode should continue after first correct flag"
second = env.step(
    NaradaAction(
        action_type="flag_causal",
        variant_id=ground_truth[1],
        reasoning="Second causal variant",
    )
)
assert second.done, "Oligogenic episode should end after all causal variants are flagged"
assert second.reward > first.reward, "Completing oligogenic diagnosis should improve reward"

print("ENV SMOKE TEST PASSED")