"""
tests/test_environment.py
-------------------------
Tests for episode lifecycle and action routing in RAGDebugEnvironment.
Verifies that:
- reset() fully initialises state and returns a valid observation.
- step() increments step count and returns bounded rewards.
- Each action type modifies the correct config field.
- Auto-terminate fires at max_steps.
- ADJUST_CHUNK_OVERLAP now triggers _recompute_S_faulted() (bug fix).
"""
import pytest
import numpy as np
from server.rag_debug_env_environment import RAGDebugEnvironment
from server.constants import _MAX_STEPS
from models import (
RAGDebugAction,
ActionType,
EmbeddingModel,
RAGDebugObservation,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(params=[1, 2, 3])
def env(request):
    """Provide a freshly reset environment, once for each task id (1, 2, 3).

    Because the fixture is parametrised, every test that requests ``env``
    runs three times. The reset always uses seed 0.
    """
    environment = RAGDebugEnvironment()
    environment.reset(task_id=request.param, seed=0)
    return environment
@pytest.fixture
def env1():
    """Provide a fresh environment reset to task 1 with seed 0."""
    environment = RAGDebugEnvironment()
    environment.reset(task_id=1, seed=0)
    return environment
def _step(env, action_type, params=None):
    """Wrap *action_type* and *params* in a RAGDebugAction and step *env* once.

    A missing (or otherwise falsy) ``params`` is sent as an empty dict.
    Returns whatever ``env.step()`` returns.
    """
    payload = params if params else {}
    return env.step(RAGDebugAction(action_type=action_type, params=payload))
# ---------------------------------------------------------------------------
# Reset
# ---------------------------------------------------------------------------
class TestReset:
    """reset() must hand back a valid observation and wipe per-episode state."""

    def test_reset_returns_observation(self, env):
        """reset() returns a RAGDebugObservation instance."""
        # NOTE(review): ``env`` is parametrised over tasks 1-3, yet the reset
        # under test always targets task 1 -- confirm this is intentional.
        observation = env.reset(seed=1, task_id=1)
        assert isinstance(observation, RAGDebugObservation)

    def test_reset_clears_step_count(self, env1):
        """The step counter goes back to zero after a reset."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        assert env1._state.step_count == 1

        env1.reset(seed=99, task_id=1)
        assert env1._state.step_count == 0

    def test_reset_clears_done_flag(self, env1):
        """A finished episode is no longer flagged done after a reset."""
        # SUBMIT is used to force the done flag on.
        _step(env1, ActionType.SUBMIT)
        assert env1._done is True

        env1.reset(seed=5, task_id=1)
        assert env1._done is False

    def test_reset_returns_valid_metrics(self, env):
        """Metrics in the initial observation stay within their bounds."""
        metrics = env.reset(seed=2, task_id=1).metrics
        assert 0.0 <= metrics.mean_coverage <= 1.0
        assert 0.0 <= metrics.mean_precision <= 1.0
        assert metrics.n_empty_retrievals >= 0
        assert metrics.n_context_overflows >= 0

    def test_reset_with_different_tasks(self):
        """One environment can be reset to each task in turn."""
        environment = RAGDebugEnvironment()
        for wanted_task in (1, 2, 3):
            observation = environment.reset(seed=0, task_id=wanted_task)
            assert observation.task_id == wanted_task

    def test_reset_invalid_task_raises(self):
        """An unknown task id is rejected with a ValueError mentioning task_id."""
        environment = RAGDebugEnvironment()
        with pytest.raises(ValueError, match="task_id"):
            environment.reset(seed=0, task_id=99)

    def test_reset_clears_action_history(self, env1):
        """Action and reward histories are emptied by a reset."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        env1.reset(seed=0, task_id=1)

        internal = env1._internal_state
        assert internal.action_history == []
        assert internal.reward_history == []
# ---------------------------------------------------------------------------
# Step lifecycle
# ---------------------------------------------------------------------------
class TestStep:
    """step() lifecycle: counting, reward bounds, termination and history."""

    def test_step_increments_step_count(self, env1):
        """Each step bumps the step counter by exactly one."""
        for steps_taken in (1, 2, 3):
            _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
            assert env1._state.step_count == steps_taken

    def test_step_returns_observation(self, env1):
        """step() returns a RAGDebugObservation instance."""
        result = _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        assert isinstance(result, RAGDebugObservation)

    def test_step_observation_reward_in_unit_interval(self, env1):
        """The per-step reward is present and lies in [0, 1]."""
        result = _step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.2})
        reward = result.reward
        assert reward is not None
        assert 0.0 <= reward <= 1.0

    def test_step_after_done_raises(self, env1):
        """Stepping a finished episode raises RuntimeError."""
        _step(env1, ActionType.SUBMIT)
        with pytest.raises(RuntimeError, match="already done"):
            _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})

    def test_auto_terminate_at_max_steps(self):
        """The episode ends by itself on the _MAX_STEPS-th step, not earlier."""
        environment = RAGDebugEnvironment()
        environment.reset(seed=0, task_id=1)

        def advance():
            # Build a fresh params dict for every step.
            return _step(environment, ActionType.ADJUST_TOP_K, {"value": 10})

        for _ in range(_MAX_STEPS - 1):
            interim = advance()
            assert not interim.done, "Episode should not be done before max_steps"

        # This step is the one that reaches max_steps.
        final = advance()
        assert final.done, "Episode should auto-terminate at max_steps"

    def test_done_flag_propagates_to_observation(self, env1):
        """The observation returned by SUBMIT reports done=True."""
        result = _step(env1, ActionType.SUBMIT)
        assert result.done is True

    def test_action_recorded_in_history(self, env1):
        """A single step leaves exactly one matching entry in the action history."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 20})

        history = env1._internal_state.action_history
        assert len(history) == 1
        assert history[0].action_type == ActionType.ADJUST_TOP_K
# ---------------------------------------------------------------------------
# Action routing — each action modifies the correct config field
# ---------------------------------------------------------------------------
class TestActionRouting:
    """Each action type must write to its own config field and nothing else."""

    def _get_config(self, env):
        """Return a snapshot dict of the config fields these tests inspect.

        The values are read from ``env._config`` at call time, so a snapshot
        taken before an action can be compared with one taken afterwards.
        """
        cfg = env._config
        return {
            "chunk_size": cfg.chunk_size,
            "chunk_overlap": cfg.chunk_overlap,
            "threshold": cfg.similarity_threshold,
            "top_k": cfg.top_k,
            "model": cfg.embedding_model,
            "reranking": cfg.use_reranking,
            "context_limit": cfg.context_window_limit,
        }

    def test_adjust_chunk_size(self, env1):
        """ADJUST_CHUNK_SIZE writes config.chunk_size."""
        _step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 256})
        assert env1._config.chunk_size == 256

    def test_adjust_chunk_overlap(self, env1):
        """ADJUST_CHUNK_OVERLAP writes config.chunk_overlap."""
        _step(env1, ActionType.ADJUST_CHUNK_OVERLAP, {"value": 100})
        assert env1._config.chunk_overlap == 100

    def test_adjust_threshold(self, env1):
        """ADJUST_THRESHOLD writes config.similarity_threshold."""
        _step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.15})
        assert env1._config.similarity_threshold == pytest.approx(0.15)

    def test_adjust_top_k(self, env1):
        """ADJUST_TOP_K writes config.top_k."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 25})
        assert env1._config.top_k == 25

    def test_swap_embedding_model(self, env1):
        """SWAP_EMBEDDING_MODEL maps the model name onto the EmbeddingModel enum."""
        _step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "medical"})
        assert env1._config.embedding_model == EmbeddingModel.MEDICAL

    def test_toggle_reranking_on(self, env1):
        """TOGGLE_RERANKING(enabled=True) switches reranking on."""
        # Precondition: reranking starts disabled for this fixture.
        assert env1._config.use_reranking is False
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
        assert env1._config.use_reranking is True

    def test_toggle_reranking_off(self, env1):
        """TOGGLE_RERANKING(enabled=False) switches reranking back off."""
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": False})
        assert env1._config.use_reranking is False

    def test_adjust_context_limit(self, env1):
        """ADJUST_CONTEXT_LIMIT writes config.context_window_limit."""
        _step(env1, ActionType.ADJUST_CONTEXT_LIMIT, {"value": 8192})
        assert env1._config.context_window_limit == 8192

    def test_invalid_chunk_size_sets_error(self, env1):
        """A chunk size rejected by validation is reported via last_action_error."""
        # Set chunk_size smaller than the current chunk_overlap (default 50)
        # to trigger the model_validator "overlap must be < chunk_size".
        obs = _step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 10})
        assert obs.last_action_error is not None

    def test_invalid_model_sets_error(self, env1):
        """An unknown embedding model name is reported via last_action_error."""
        obs = _step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "nonexistent"})
        assert obs.last_action_error is not None

    def test_unrelated_fields_unchanged_after_action(self, env1):
        """ADJUST_TOP_K changes top_k and leaves every other snapshot field alone."""
        before = self._get_config(env1)
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 20})
        after = self._get_config(env1)

        # The targeted field must carry the requested value...
        assert after.pop("top_k") == 20
        before.pop("top_k")
        # ...and everything else must match. Comparing the whole snapshot
        # covers chunk_overlap too, and any field later added to _get_config.
        assert after == before
# ---------------------------------------------------------------------------
# Bug fix: ADJUST_CHUNK_OVERLAP must trigger _recompute_S_faulted()
# ---------------------------------------------------------------------------
class TestChunkOverlapRecompute:
    """
    Regression tests for the bug where ADJUST_CHUNK_OVERLAP skipped
    _recompute_S_faulted(), so a new overlap value did not affect retrieval
    scores until some other action happened to trigger a recomputation.
    """

    def _make_env_with_chunk_too_small(self, overlap_value):
        """
        Build an environment with only the CHUNK_TOO_SMALL fault injected,
        apply ADJUST_CHUNK_OVERLAP with ``overlap_value``, and return a copy
        of the resulting S_faulted matrix.

        chunk_size is left at its default (512 per the original note), so the
        overlaps used by the tests (0 and 450) both satisfy
        overlap < chunk_size.
        """
        from models import FaultConfig, FaultType as FT

        environment = RAGDebugEnvironment()
        environment.reset(seed=42, task_id=1)
        # Overlap modulation is only relevant while CHUNK_TOO_SMALL is active,
        # so replace the injected faults with exactly that one.
        environment._injected_faults = [FaultConfig(fault_type=FT.CHUNK_TOO_SMALL)]
        # Now apply the overlap under test.
        _step(environment, ActionType.ADJUST_CHUNK_OVERLAP, {"value": overlap_value})
        return environment._S_faulted.copy()

    def test_overlap_recompute_changes_s_faulted(self):
        """
        Environments that differ only in chunk_overlap must end up with
        different S_faulted matrices right after ADJUST_CHUNK_OVERLAP, which
        shows the recomputation actually ran.
        """
        matrix_at_zero = self._make_env_with_chunk_too_small(overlap_value=0)
        matrix_at_450 = self._make_env_with_chunk_too_small(overlap_value=450)

        # With CHUNK_TOO_SMALL active, higher overlap reduces noise sigma,
        # so the two matrices should differ.
        matrices_match = np.allclose(matrix_at_zero, matrix_at_450)
        assert not matrices_match, (
            "ADJUST_CHUNK_OVERLAP should immediately recompute S_faulted; "
            "different overlap values should yield different matrices."
        )

    def test_overlap_high_reduces_noise_magnitude(self):
        """
        With the fix in place, a higher overlap should shrink the noise that
        CHUNK_TOO_SMALL adds, pulling S_faulted closer to S_true.
        chunk_size stays at its default (512 per the original note), so both
        overlaps (0, 450) are valid.
        """
        from models import FaultConfig, FaultType as FT

        def mean_abs_deviation(overlap_value):
            environment = RAGDebugEnvironment()
            environment.reset(seed=7, task_id=1)
            environment._injected_faults = [
                FaultConfig(fault_type=FT.CHUNK_TOO_SMALL)
            ]
            # Snapshot S_true before the overlap action is applied.
            reference = environment._s_true_episode["general"].copy()
            _step(
                environment,
                ActionType.ADJUST_CHUNK_OVERLAP,
                {"value": overlap_value},
            )
            deviation = np.abs(environment._S_faulted - reference)
            return float(np.mean(deviation))

        diff_low = mean_abs_deviation(0)
        diff_high = mean_abs_deviation(450)
        assert diff_high < diff_low, (
            "Higher overlap should reduce CHUNK_TOO_SMALL noise, "
            "making S_faulted closer to S_true"
        )
# ---------------------------------------------------------------------------
# SUBMIT grading
# ---------------------------------------------------------------------------
class TestSubmit:
    """SUBMIT must end the episode and return a reward within [0, 1]."""

    def test_submit_sets_done(self, env1):
        """The observation returned by SUBMIT has done=True."""
        result = _step(env1, ActionType.SUBMIT)
        assert result.done is True

    def test_submit_success_reward_in_range(self):
        """Submitting after some tuning yields a reward inside the unit interval."""
        environment = RAGDebugEnvironment()
        environment.reset(seed=0, task_id=1)

        # Lower the threshold and raise top_k to maximise coverage, then submit.
        _step(environment, ActionType.ADJUST_THRESHOLD, {"value": 0.05})
        _step(environment, ActionType.ADJUST_TOP_K, {"value": 50})
        result = _step(environment, ActionType.SUBMIT)

        # Depending on success the reward is expected in [0.7, 1.0] or
        # [0.0, 0.2]; only the outer bounds are asserted here.
        reward = result.reward
        assert reward is not None
        assert 0.0 <= reward <= 1.0

    def test_early_submit_penalty_reward_low(self, env1):
        """Submitting immediately, with nothing fixed, still yields a bounded reward."""
        result = _step(env1, ActionType.SUBMIT)

        # An immediate submit most likely lands in the failure band [0, 0.2],
        # but that is not guaranteed for every episode, so only the overall
        # [0, 1] bounds are checked.
        reward = result.reward
        assert reward is not None
        assert 0.0 <= reward <= 1.0
|