"""Tests for the HallucinationGuard environment."""
import pytest
from server.environment import HallucinationGuardEnvironment
class TestEnvironmentReset:
    """Checks on the observation handed back by ``reset()``."""

    def test_reset_returns_observation(self):
        """A reset yields a non-None observation exposing the core fields."""
        first_obs = HallucinationGuardEnvironment().reset()
        assert first_obs is not None
        # The observation must expose each of these attributes.
        for field_name in ('question', 'context', 'reward', 'done'):
            assert hasattr(first_obs, field_name)

    def test_reset_sets_initial_reward_to_zero(self):
        """The observation from a reset carries a reward of 0.0."""
        first_obs = HallucinationGuardEnvironment().reset()
        assert first_obs.reward == 0.0

    def test_reset_sets_done_to_false(self):
        """The observation from a reset has ``done`` set to False."""
        first_obs = HallucinationGuardEnvironment().reset()
        assert first_obs.done is False

    def test_reset_provides_attempts_remaining(self):
        """The observation from a reset reports a positive attempt budget."""
        first_obs = HallucinationGuardEnvironment().reset()
        assert first_obs.attempts_remaining > 0

    def test_reset_with_task_id(self):
        """Passing ``task_id`` to reset still produces an observation."""
        guard_env = HallucinationGuardEnvironment()
        task_obs = guard_env.reset(task_id="task_1_factual_grounding")
        assert task_obs is not None

    def test_reset_clears_previous_state(self):
        """Two back-to-back resets leave the second observation clean."""
        guard_env = HallucinationGuardEnvironment()
        guard_env.reset()
        second_obs = guard_env.reset()
        assert second_obs.reward == 0.0
        assert second_obs.done is False
class TestEnvironmentStep:
    """Tests for environment step functionality."""

    @staticmethod
    def _make_action(answer="test answer", confidence=0.5):
        """Build the action dict passed to ``env.step``.

        Only ``answer`` and ``confidence`` vary between the tests in this
        class; the remaining keys are always sent empty. A new dict (and a
        new ``uncertainty_flags`` list) is created on every call so tests
        never share mutable state.
        """
        return {
            "answer": answer,
            "confidence": confidence,
            "source_quote": "",
            "reasoning": "",
            "uncertainty_flags": [],
        }

    def test_step_returns_observation(self):
        """Step should return a non-None observation carrying a reward."""
        env = HallucinationGuardEnvironment()
        env.reset()
        obs = env.step(self._make_action(answer="test answer", confidence=0.8))
        assert obs is not None
        assert hasattr(obs, 'reward')

    def test_step_reward_in_valid_range(self):
        """Step reward should be in the [-1.0, 1.0] range."""
        env = HallucinationGuardEnvironment()
        env.reset()
        obs = env.step(self._make_action(answer="test answer", confidence=0.5))
        # NOTE(review): the bound enforced here is [-1.0, 1.0]; tighten it to
        # [0.0, 1.0] if the environment guarantees non-negative rewards.
        assert -1.0 <= obs.reward <= 1.0

    def test_step_with_high_confidence(self):
        """Step with confidence 1.0 should return an observation."""
        env = HallucinationGuardEnvironment()
        env.reset()
        obs = env.step(self._make_action(answer="test answer", confidence=1.0))
        assert obs is not None

    def test_step_with_low_confidence(self):
        """Step with confidence 0.1 should return an observation."""
        env = HallucinationGuardEnvironment()
        env.reset()
        obs = env.step(self._make_action(answer="test answer", confidence=0.1))
        assert obs is not None

    def test_step_updates_attempts(self):
        """Step should leave fewer attempts remaining than reset reported."""
        env = HallucinationGuardEnvironment()
        obs1 = env.reset()
        obs2 = env.step(self._make_action(answer="test", confidence=0.5))
        assert obs2.attempts_remaining < obs1.attempts_remaining
class TestEnvironmentState:
    """Tests for environment state functionality."""

    @staticmethod
    def _exposes(state, field):
        """Return True if ``state`` offers ``field`` as an attribute or member.

        ``field in state`` raises TypeError when ``state`` is not a container
        or iterable. That case is reported as "not present" so the calling
        test fails on its assertion instead of erroring out.
        """
        if hasattr(state, field):
            return True
        try:
            return field in state
        except TypeError:
            return False

    def test_state_returns_metadata(self):
        """State should expose ``episode_id`` or ``step_count``.

        Either attribute access or membership (e.g. a dict key) counts.
        """
        env = HallucinationGuardEnvironment()
        env.reset()
        state = env.state()
        assert state is not None
        assert any(
            self._exposes(state, field)
            for field in ('episode_id', 'step_count')
        )

    def test_state_tracks_step_count(self):
        """State should still be available after a step has been taken."""
        env = HallucinationGuardEnvironment()
        env.reset()
        action = {
            "answer": "test",
            "confidence": 0.5,
            "source_quote": "",
            "reasoning": "",
            "uncertainty_flags": []
        }
        env.step(action)
        state = env.state()
        # NOTE(review): this only checks that a state object exists; it does
        # not yet verify the step count itself. Assert on the counter once the
        # state schema (attribute vs. key, starting value) is confirmed.
        assert state is not None
class TestTaskSelection:
    """Resetting the environment into each named task."""

    @staticmethod
    def _reset_into(task_name):
        """Create a new environment and reset it with ``task_id=task_name``."""
        return HallucinationGuardEnvironment().reset(task_id=task_name)

    def test_reset_with_task_1(self):
        """Resetting into task_1_factual_grounding yields an observation."""
        assert self._reset_into("task_1_factual_grounding") is not None

    def test_reset_with_task_2(self):
        """Resetting into task_2_multi_hop_synthesis yields an observation."""
        assert self._reset_into("task_2_multi_hop_synthesis") is not None

    def test_reset_with_task_3(self):
        """Resetting into task_3_adversarial_resistance yields an observation."""
        assert self._reset_into("task_3_adversarial_resistance") is not None