File size: 2,684 Bytes
077e61e
 
 
9e4dc47
 
077e61e
9e4dc47
077e61e
 
 
 
 
 
b654948
 
 
077e61e
 
 
 
9e4dc47
077e61e
 
 
 
 
9e4dc47
077e61e
 
 
b654948
077e61e
9e4dc47
077e61e
 
 
 
 
 
 
 
b654948
 
 
 
 
 
 
 
 
 
 
077e61e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Quick smoke test for the environment."""
import sys
sys.path.insert(0, "src/envs")
from narada.server.environment import NaradaEnvironment
from narada.models import NaradaAction

# --- Monogenic episode: reset, explore one hop, backtrack, then flag a variant ---
env = NaradaEnvironment()
result = env.reset(task_type="monogenic", seed=1)
obs = result.observation
print(f"Reset: task={obs.task_type}")
print(f"  Node: {obs.current_node.name} ({obs.current_node.id})")
print(f"  Phenotypes: {obs.phenotype_names}")
print(f"  Candidates: {[v.gene for v in obs.candidate_variants]}")
# The reset reward is checked against the open interval (strict bounds).
assert 0.01 < result.reward < 0.99, f"Reset reward {result.reward} must be in (0.01, 0.99)"
# Before the episode terminates, neither the observation nor the state may
# expose the answer.
assert "disease_name" not in obs.info, "Disease should not leak before terminal state"
assert env.state().ground_truth_variants == [], "Ground truth should not leak before done"

# Hop to first neighbor (skipped when the current node lists no connections)
if obs.current_node.connected_node_ids:
    hop_target = obs.current_node.connected_node_ids[0]
    action = NaradaAction(action_type="hop", node_id=hop_target, reasoning="Exploring")
    result = env.step(action)
    obs = result.observation
    print(f"After hop to {hop_target}: reward={result.reward:.4f}, node={obs.current_node.name}")

# Backtrack (issued unconditionally, even if the hop above was skipped)
action = NaradaAction(action_type="backtrack", reasoning="Testing backtrack")
result = env.step(action)
print(f"After backtrack: reward={result.reward:.4f}")

# Flag a candidate
# NOTE: `obs` is the observation captured before the backtrack step; the
# backtrack result's observation is intentionally not re-read here.
# Fail with a readable message instead of a bare IndexError when the
# environment offers nothing to flag.
assert obs.candidate_variants, "Observation has no candidate variants to flag"
v_id = obs.candidate_variants[0].id
action = NaradaAction(action_type="flag_causal", variant_id=v_id, reasoning="Test flag")
result = env.step(action)
print(f"After flag ({v_id}): done={result.done} reward={result.reward:.4f}")
assert result.done, "Episode should be done after flag_causal"
# The terminal reward check is inclusive at both ends, so the message reports
# the closed interval that is actually being enforced.
assert 0.01 <= result.reward <= 0.99, f"Reward {result.reward} out of [0.01, 0.99]"

# State check: once the episode is done the ground truth must be visible.
state = env.state()
print(f"State: flagged={state.flagged_variants} ground_truth={state.ground_truth_variants[:2]}")
assert state.ground_truth_variants, "Ground truth should be revealed after terminal state"

# Oligogenic should allow multiple correct flags before termination.
env = NaradaEnvironment()
# Only the side effect of starting the episode is needed; the returned
# result is not inspected in this scenario.
env.reset(task_type="oligogenic", seed=2)
ground_truth = env._case.ground_truth_variant_ids  # test-only access
# The scenario flags two distinct causal variants, so make the precondition
# explicit rather than failing later with a bare IndexError.
assert len(ground_truth) >= 2, (
    f"Oligogenic case should expose at least two causal variants, got {len(ground_truth)}"
)
first = env.step(NaradaAction(action_type="flag_causal", variant_id=ground_truth[0], reasoning="First causal variant"))
assert not first.done, "Oligogenic episode should continue after first correct flag"
second = env.step(NaradaAction(action_type="flag_causal", variant_id=ground_truth[1], reasoning="Second causal variant"))
assert second.done, "Oligogenic episode should end after all causal variants are flagged"
# Finishing the diagnosis must be rewarded more than a partial answer.
assert second.reward > first.reward, "Completing oligogenic diagnosis should improve reward"

print("ENV SMOKE TEST PASSED")