# NurseSim-Triage-Demo / test_semantic.py
# Last commit: 3d38e23 by NurseCitizenDeveloper — "chore: Update workspace before push"
"""
Test script for the Semantic RL Environment.
Validates that the NurseEmbedWrapper works correctly.
"""
import numpy as np
from nursesim_rl import TriageEnv, NurseEmbedWrapper
def test_semantic_wrapper():
    """Smoke-test ``NurseEmbedWrapper`` on top of a seeded ``TriageEnv``.

    Builds a ``TriageEnv(max_steps=10, seed=42)``, wraps it with
    ``NurseEmbedWrapper(use_vitals=True)`` and checks that:

    * ``reset`` returns a finite observation of the expected shape
      (384 embedding dims followed by 6 vitals, per the expected-shape
      message printed below);
    * ``step`` accepts a dict action and returns an observation of the
      same shape as ``reset``;
    * the leading embedding slice is not all zeros;
    * resetting again with the same seed reproduces the same embedding.

    Progress is printed for manual inspection. Every check is also an
    ``assert``, so a regression fails the test instead of still printing
    "ALL TESTS PASSED!".

    Returns:
        True once every check has passed. This is kept for existing callers.
        NOTE(review): pytest >= 7.2 emits ``PytestReturnNotNoneWarning``
        for test functions that return a value.

    Raises:
        AssertionError: if any check fails.
    """
    embed_dim = 384  # embedding width the test expects ("384 embed")
    vitals_dim = 6  # vitals appended because use_vitals=True ("6 vitals")
    expected_shape = (embed_dim + vitals_dim,)

    print("=" * 60)
    print("Testing NurseEmbed Semantic RL Wrapper")
    print("=" * 60)

    # Create base environment
    print("\n[1] Creating base TriageEnv...")
    base_env = TriageEnv(max_steps=10, seed=42)

    # Wrap with NurseEmbed
    print("[2] Wrapping with NurseEmbedWrapper...")
    semantic_env = NurseEmbedWrapper(base_env, use_vitals=True)

    # Check observation space
    print("\n[3] Observation Space Check:")
    print(f" Base Env: {type(base_env.observation_space)}")
    print(f" Semantic Env: {semantic_env.observation_space}")
    print(" Expected shape: (390,) [384 embed + 6 vitals]")

    # Reset and get observation
    print("\n[4] Resetting environment...")
    obs, info = semantic_env.reset(seed=42)
    print(f" Observation type: {type(obs)}")
    print(f" Observation shape: {obs.shape}")
    print(f" Observation range: [{obs.min():.3f}, {obs.max():.3f}]")
    assert obs.shape == expected_shape, (
        f"reset observation shape {obs.shape} != expected {expected_shape}"
    )
    assert np.all(np.isfinite(obs)), "reset observation contains NaN/inf"

    # Take a step
    print("\n[5] Taking a step with action (Cat=3, Intervention=2)...")
    action = {"triage_category": 3, "intervention": 2}
    obs2, reward, terminated, truncated, info = semantic_env.step(action)
    print(f" Reward: {reward}")
    print(f" Terminated: {terminated}")
    print(f" New observation shape: {obs2.shape}")
    assert obs2.shape == obs.shape, (
        f"step observation shape {obs2.shape} != reset shape {obs.shape}"
    )

    # Verify embedding is meaningful (not just zeros)
    print("\n[6] Embedding quality check:")
    embed_part = obs[:embed_dim]  # leading slice is the text embedding
    embed_norm = np.linalg.norm(embed_part)
    print(f" Embedding L2 norm: {embed_norm:.3f}")
    print(f" Embedding is normalized: {abs(embed_norm - 1.0) < 0.1}")
    # Only "non-zero" is enforced. Unit norm is reported but not asserted,
    # because whether the wrapper normalises is not visible from this file.
    assert embed_norm > 0.0, "embedding slice is all zeros"

    # Same seed should yield the same patient, and therefore the same
    # embedding. This checks reproducibility; it cannot observe whether the
    # wrapper's cache was actually hit.
    print("\n[7] Testing embedding cache...")
    obs3, _ = semantic_env.reset(seed=42)  # Same seed = same patient
    cache_hit = np.allclose(obs[:embed_dim], obs3[:embed_dim])
    print(f" Cache working: {cache_hit}")
    assert cache_hit, "re-seeded reset produced a different embedding"

    print("\n" + "=" * 60)
    print("ALL TESTS PASSED!")
    print("=" * 60)
    return True
if __name__ == "__main__":
    # Script entry point: run the smoke test once. The function returns True
    # when it completes (any exception propagates as a traceback), so a
    # normal run ends with exit status 0.
    raise SystemExit(0 if test_semantic_wrapper() else 1)