File size: 1,753 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Smoke tests for the main environment."""
import pytest
import numpy as np
from env.spindleflow_env import SpindleFlowEnv


@pytest.fixture
def env():
    """Provide a phase-1 ``SpindleFlowEnv`` with the real SpindleFlow backend off.

    The environment is built from the training config and specialist catalog
    under ``configs/``.  After the requesting test finishes, the code after the
    ``yield`` closes it.
    """
    spindle_env = SpindleFlowEnv(
        config_path="configs/training_config.yaml",
        catalog_path="configs/specialist_catalog.yaml",
        use_real_spindleflow=False,  # keep tests off the real SpindleFlow backend
        phase=1,
    )
    yield spindle_env
    # Teardown: release whatever resources the environment holds.
    spindle_env.close()


def test_env_reset(env):
    """``reset()`` returns a float32 ndarray observation shaped like the observation space."""
    initial_obs, _reset_info = env.reset()
    assert isinstance(initial_obs, np.ndarray)
    assert initial_obs.dtype == np.float32
    assert initial_obs.shape == env.observation_space.shape


def test_env_step_stop(env):
    """A STOP action (``action[0] = 1.0``) yields a well-formed five-value step result.

    Every element returned by ``env.step`` is checked: the observation keeps
    the declared observation-space shape, the reward is a ``float``, both the
    ``terminated`` and ``truncated`` flags are ``bool``, and ``info`` is a ``dict``.
    The original test only checked ``reward`` and ``terminated``.
    """
    obs, _ = env.reset()
    action = np.zeros(env.action_space.shape, dtype=np.float32)
    action[0] = 1.0  # STOP action
    obs, reward, terminated, truncated, info = env.step(action)
    assert obs.shape == env.observation_space.shape
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    # Checked as strictly as `terminated`, so both episode-end flags get the same contract.
    assert isinstance(truncated, bool)
    assert isinstance(info, dict)


def test_env_step_call_specialist(env):
    """Calling the first specialist still yields an observation of the declared shape."""
    _start_obs, _start_info = env.reset()

    call_action = np.zeros(env.action_space.shape, dtype=np.float32)
    call_action[0] = 0.0  # action[0] == 0.0 selects CALL_SPECIALIST
    call_action[1] = 1.0  # pick the first specialist

    next_obs, _reward, _terminated, _truncated, _info = env.step(call_action)
    assert next_obs.shape == env.observation_space.shape


def test_observation_space_shape(env):
    """The observation space is 1-D, with the width EpisodeState reports for this env."""
    from env.state import EpisodeState

    obs_width = EpisodeState.observation_dim(env.max_specialists)
    declared_shape = env.observation_space.shape
    assert declared_shape == (obs_width,)


def test_episode_runs_to_completion(env):
    """Sampled actions must end an episode (terminated or truncated) within a step budget.

    The action space is seeded so the sampled action sequence is the same on
    every run. With an unseeded ``sample()``, a failure could not be replayed.
    """
    # NOTE(review): assumed to be >= the env's configured max_steps — confirm
    # against configs/training_config.yaml.
    step_budget = 15
    obs, _ = env.reset()
    env.action_space.seed(0)  # reproducible action sequence
    done = False
    steps = 0
    while not done and steps < step_budget:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        steps += 1
    assert done, f"episode did not terminate within {step_budget} steps"