Spaces:
Paused
Paused
| """Smoke tests for the main environment.""" | |
| import pytest | |
| import numpy as np | |
| from env.spindleflow_env import SpindleFlowEnv | |
@pytest.fixture
def env():
    """Yield a phase-1 ``SpindleFlowEnv`` for a single test, then close it.

    The environment is built from the YAML files under ``configs/``. These
    paths are relative, so the tests presumably run from the repository
    root (TODO confirm). ``use_real_spindleflow=False`` is forwarded to the
    constructor; its exact effect is defined by ``SpindleFlowEnv``.

    Yields:
        The constructed environment. ``close()`` runs after the test body,
        whether the test passed or failed (pytest yield-fixture teardown).
    """
    # The decorator above is required: without it, tests that take an
    # ``env`` argument fail at setup with "fixture 'env' not found".
    e = SpindleFlowEnv(
        config_path="configs/training_config.yaml",
        catalog_path="configs/specialist_catalog.yaml",
        use_real_spindleflow=False,
        phase=1,
    )
    yield e
    e.close()
def test_env_reset(env):
    """Check that ``reset()`` yields a float32 ndarray sized like the observation space."""
    observation, _reset_info = env.reset()

    assert isinstance(observation, np.ndarray)
    assert observation.dtype == np.float32
    assert observation.shape == env.observation_space.shape
def test_env_step_stop(env):
    """A STOP step must return values that follow the ``step`` 5-tuple contract.

    Checks that the observation keeps the observation-space shape, that the
    reward is a Python ``float``, and that both ``terminated`` and
    ``truncated`` are Python ``bool`` values.
    """
    _obs, _ = env.reset()
    action = np.zeros(env.action_space.shape, dtype=np.float32)
    action[0] = 1.0  # STOP action
    obs2, reward, terminated, truncated, _info = env.step(action)
    assert obs2.shape == env.observation_space.shape
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
def test_env_step_call_specialist(env):
    """Check that a CALL_SPECIALIST step on the first specialist keeps the observation shape."""
    _initial_obs, _reset_info = env.reset()

    call_action = np.zeros(env.action_space.shape, dtype=np.float32)
    call_action[0] = 0.0  # CALL_SPECIALIST
    call_action[1] = 1.0  # Select first specialist

    next_obs, _reward, _terminated, _truncated, _step_info = env.step(call_action)
    assert next_obs.shape == env.observation_space.shape
def test_observation_space_shape(env):
    """Check that the observation space is 1-D with the length ``EpisodeState`` reports."""
    # ``env`` in the import path is the project package, not the fixture
    # argument; import statements resolve modules, not local variables.
    from env.state import EpisodeState

    wanted_shape = (EpisodeState.observation_dim(env.max_specialists),)
    assert env.observation_space.shape == wanted_shape
def test_episode_runs_to_completion(env):
    """Random actions must end an episode within a fixed step budget.

    The action space is seeded so that the sequence of sampled actions is the
    same on every run, which makes a failure reproducible. Any randomness
    inside ``env.reset()`` itself is not controlled here.
    """
    # Assumed to be >= the env's max_steps from training_config.yaml — TODO confirm.
    step_budget = 15
    env.action_space.seed(0)
    _obs, _ = env.reset()
    done = False
    steps = 0
    while not done and steps < step_budget:
        action = env.action_space.sample()
        _obs, _reward, terminated, truncated, _info = env.step(action)
        done = terminated or truncated
        steps += 1
    # The episode must terminate (or be truncated) within the budget.
    assert done, f"episode still running after {steps} steps"