"""
Smoke test for the Semantic RL Environment.
Exercises NurseEmbedWrapper on top of TriageEnv: observation shape, a single
step, embedding normalization, and embedding reproducibility across seeded resets.
"""
| |
|
| | import numpy as np |
| | from nursesim_rl import TriageEnv, NurseEmbedWrapper |
| |
|
def test_semantic_wrapper():
    """Smoke-test NurseEmbedWrapper wrapped around a seeded TriageEnv.

    Builds ``TriageEnv(max_steps=10, seed=42)``, wraps it with
    ``NurseEmbedWrapper(use_vitals=True)``, then resets, takes one step and
    resets again with the same seed, printing diagnostics along the way.
    Every diagnostic is also asserted, so a regression fails the run instead
    of still printing "ALL TESTS PASSED!".

    Returns:
        True when every check passes. Kept for callers that inspect the
        result; note that pytest emits ``PytestReturnNotNoneWarning`` for
        test functions that return a non-None value.

    Raises:
        AssertionError: if the observation type/shape, embedding
            normalization, or seeded-reset reproducibility check fails.
    """
    # Observation layout this test expects: a 384-dim text embedding
    # followed by 6 vital-sign features (use_vitals=True).
    embed_dim = 384
    n_vitals = 6
    expected_shape = (embed_dim + n_vitals,)

    print("=" * 60)
    print("Testing NurseEmbed Semantic RL Wrapper")
    print("=" * 60)

    print("\n[1] Creating base TriageEnv...")
    base_env = TriageEnv(max_steps=10, seed=42)

    print("[2] Wrapping with NurseEmbedWrapper...")
    semantic_env = NurseEmbedWrapper(base_env, use_vitals=True)

    try:
        print("\n[3] Observation Space Check:")
        print(f" Base Env: {type(base_env.observation_space)}")
        print(f" Semantic Env: {semantic_env.observation_space}")
        print(
            f" Expected shape: {expected_shape} "
            f"[{embed_dim} embed + {n_vitals} vitals]"
        )

        print("\n[4] Resetting environment...")
        obs, _ = semantic_env.reset(seed=42)

        print(f" Observation type: {type(obs)}")
        # Check the type before calling ndarray-only methods below.
        assert isinstance(obs, np.ndarray), f"expected np.ndarray, got {type(obs)}"
        print(f" Observation shape: {obs.shape}")
        print(f" Observation range: [{obs.min():.3f}, {obs.max():.3f}]")
        assert obs.shape == expected_shape, (
            f"reset observation shape {obs.shape} != {expected_shape}"
        )

        print("\n[5] Taking a step with action (Cat=3, Intervention=2)...")
        action = {"triage_category": 3, "intervention": 2}
        obs2, reward, terminated, _, _ = semantic_env.step(action)

        print(f" Reward: {reward}")
        print(f" Terminated: {terminated}")
        print(f" New observation shape: {obs2.shape}")
        assert obs2.shape == expected_shape, (
            f"step observation shape {obs2.shape} != {expected_shape}"
        )

        print("\n[6] Embedding quality check:")
        # Compute the norm once and reuse it for both the report and the check.
        embed_norm = np.linalg.norm(obs[:embed_dim])
        is_normalized = abs(embed_norm - 1.0) < 0.1
        print(f" Embedding L2 norm: {embed_norm:.3f}")
        print(f" Embedding is normalized: {is_normalized}")
        assert is_normalized, f"embedding L2 norm {embed_norm:.3f} is not ~1.0"

        print("\n[7] Testing embedding cache...")
        # Same seed -> same initial state -> identical embedding. This proves
        # determinism; it cannot by itself tell a cache hit from recomputation.
        obs3, _ = semantic_env.reset(seed=42)
        cache_hit = np.allclose(obs[:embed_dim], obs3[:embed_dim])
        print(f" Cache working: {cache_hit}")
        assert cache_hit, "re-seeded reset produced a different embedding"
    finally:
        # Release the env even when a check fails. Assumes the wrapper follows
        # the Gymnasium Env API (reset/step as used above) and forwards close()
        # to the wrapped TriageEnv — confirm against NurseEmbedWrapper.
        semantic_env.close()

    print("\n" + "=" * 60)
    print("ALL TESTS PASSED!")
    print("=" * 60)

    return True
| |
|
if __name__ == "__main__":
    # Script entry point: run the smoke test once. Its return value is
    # intentionally discarded; a failing check surfaces as an exception.
    test_semantic_wrapper()
| |
|