Spaces:
Sleeping
Sleeping
| """ | |
| tests/test_environment.py | |
| ------------------------- | |
| Tests for episode lifecycle and action routing in RAGDebugEnvironment. | |
| Verifies that: | |
| - reset() fully initialises state and returns a valid observation. | |
| - step() increments step count and returns bounded rewards. | |
| - Each action type modifies the correct config field. | |
| - Auto-terminate fires at max_steps. | |
| - ADJUST_CHUNK_OVERLAP now triggers _recompute_S_faulted() (bug fix). | |
| """ | |
| import pytest | |
| import numpy as np | |
| from server.rag_debug_env_environment import RAGDebugEnvironment | |
| from server.constants import _MAX_STEPS | |
| from models import ( | |
| RAGDebugAction, | |
| ActionType, | |
| EmbeddingModel, | |
| RAGDebugObservation, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Fixtures | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture(params=[1, 2, 3])
def env(request):
    """Fresh environment reset to each task.

    Parametrized fixture: every test that requests ``env`` runs once per
    task id, with the id delivered through ``request.param``. The ids
    (1, 2, 3) are the same ones ``test_reset_with_different_tasks`` resets
    to; the environment is always reset with ``seed=0``.

    Without the ``@pytest.fixture`` decorator pytest cannot resolve the
    ``env`` argument of the tests below and ``request.param`` is undefined.
    """
    e = RAGDebugEnvironment()
    e.reset(seed=0, task_id=request.param)
    return e
@pytest.fixture
def env1():
    """Fresh environment reset to task 1 with seed 0.

    Registered as a pytest fixture so tests can request it by naming an
    ``env1`` argument; each test receives its own new instance.
    """
    e = RAGDebugEnvironment()
    e.reset(seed=0, task_id=1)
    return e
def _step(env, action_type, params=None):
    """Wrap ``action_type``/``params`` in a RAGDebugAction and step ``env``.

    A missing (or otherwise falsy) ``params`` is sent as an empty dict.
    Returns whatever ``env.step`` returns.
    """
    payload = params if params else {}
    return env.step(RAGDebugAction(action_type=action_type, params=payload))
| # --------------------------------------------------------------------------- | |
| # Reset | |
| # --------------------------------------------------------------------------- | |
class TestReset:
    """reset() must return an observation and wipe per-episode state."""

    def test_reset_returns_observation(self, env):
        """The value returned by reset() is a RAGDebugObservation."""
        first_obs = env.reset(seed=1, task_id=1)
        assert isinstance(first_obs, RAGDebugObservation)

    def test_reset_clears_step_count(self, env1):
        """One step bumps the counter to 1; reset() puts it back to 0."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        assert env1._state.step_count == 1

        env1.reset(seed=99, task_id=1)
        assert env1._state.step_count == 0

    def test_reset_clears_done_flag(self, env1):
        """A finished episode is no longer marked done after reset()."""
        # SUBMIT is used to force the episode into the done state.
        _step(env1, ActionType.SUBMIT, {})
        assert env1._done is True

        env1.reset(seed=5, task_id=1)
        assert env1._done is False

    def test_reset_returns_valid_metrics(self, env):
        """Initial metrics are well-formed: ratios in [0, 1], counts >= 0."""
        metrics = env.reset(seed=2, task_id=1).metrics
        assert 0.0 <= metrics.mean_coverage <= 1.0
        assert 0.0 <= metrics.mean_precision <= 1.0
        assert metrics.n_empty_retrievals >= 0
        assert metrics.n_context_overflows >= 0

    def test_reset_with_different_tasks(self):
        """One environment instance can be reset to each task in turn."""
        shared_env = RAGDebugEnvironment()
        for wanted_task in (1, 2, 3):
            observation = shared_env.reset(seed=0, task_id=wanted_task)
            assert observation.task_id == wanted_task

    def test_reset_invalid_task_raises(self):
        """An unknown task id is rejected with a ValueError naming task_id."""
        fresh_env = RAGDebugEnvironment()
        with pytest.raises(ValueError, match="task_id"):
            fresh_env.reset(seed=0, task_id=99)

    def test_reset_clears_action_history(self, env1):
        """Action and reward histories are both empty after reset()."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        env1.reset(seed=0, task_id=1)

        internal = env1._internal_state
        assert internal.action_history == []
        assert internal.reward_history == []
| # --------------------------------------------------------------------------- | |
| # Step lifecycle | |
| # --------------------------------------------------------------------------- | |
class TestStep:
    """step() bookkeeping: counters, rewards, termination and history."""

    def test_step_increments_step_count(self, env1):
        """Each call to step() raises the step counter by exactly one."""
        for expected_count in (1, 2, 3):
            _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
            assert env1._state.step_count == expected_count

    def test_step_returns_observation(self, env1):
        """step() returns a RAGDebugObservation."""
        result = _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
        assert isinstance(result, RAGDebugObservation)

    def test_step_observation_reward_in_unit_interval(self, env1):
        """The per-step reward is present and lies within [0, 1]."""
        result = _step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.2})
        assert result.reward is not None
        assert 0.0 <= result.reward <= 1.0

    def test_step_after_done_raises(self, env1):
        """Stepping a finished episode raises RuntimeError ("already done")."""
        _step(env1, ActionType.SUBMIT)
        with pytest.raises(RuntimeError, match="already done"):
            _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})

    def test_auto_terminate_at_max_steps(self):
        """The episode ends on step number _MAX_STEPS and not before."""
        episode = RAGDebugEnvironment()
        episode.reset(seed=0, task_id=1)

        steps_before_limit = _MAX_STEPS - 1
        for _ in range(steps_before_limit):
            result = _step(episode, ActionType.ADJUST_TOP_K, {"value": 10})
            assert not result.done, "Episode should not be done before max_steps"

        # This step is the one that reaches max_steps.
        result = _step(episode, ActionType.ADJUST_TOP_K, {"value": 10})
        assert result.done, "Episode should auto-terminate at max_steps"

    def test_done_flag_propagates_to_observation(self, env1):
        """The observation returned for SUBMIT carries done=True."""
        result = _step(env1, ActionType.SUBMIT)
        assert result.done is True

    def test_action_recorded_in_history(self, env1):
        """A stepped action shows up as the only entry in action_history."""
        env1.step(
            RAGDebugAction(action_type=ActionType.ADJUST_TOP_K, params={"value": 20})
        )
        assert len(env1._internal_state.action_history) == 1
        recorded = env1._internal_state.action_history[0]
        assert recorded.action_type == ActionType.ADJUST_TOP_K
| # --------------------------------------------------------------------------- | |
| # Action routing — each action modifies the correct config field | |
| # --------------------------------------------------------------------------- | |
class TestActionRouting:
    """Each action type must write to its own config field and no other."""

    def _get_config(self, env):
        """Grab a copy of the current config fields as a dict."""
        cfg = env._config
        return {
            "chunk_size": cfg.chunk_size,
            "chunk_overlap": cfg.chunk_overlap,
            "threshold": cfg.similarity_threshold,
            "top_k": cfg.top_k,
            "model": cfg.embedding_model,
            "reranking": cfg.use_reranking,
            "context_limit": cfg.context_window_limit,
        }

    def test_adjust_chunk_size(self, env1):
        """ADJUST_CHUNK_SIZE writes config.chunk_size."""
        _step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 256})
        assert env1._config.chunk_size == 256

    def test_adjust_chunk_overlap(self, env1):
        """ADJUST_CHUNK_OVERLAP writes config.chunk_overlap."""
        _step(env1, ActionType.ADJUST_CHUNK_OVERLAP, {"value": 100})
        assert env1._config.chunk_overlap == 100

    def test_adjust_threshold(self, env1):
        """ADJUST_THRESHOLD writes config.similarity_threshold."""
        _step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.15})
        assert env1._config.similarity_threshold == pytest.approx(0.15)

    def test_adjust_top_k(self, env1):
        """ADJUST_TOP_K writes config.top_k."""
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 25})
        assert env1._config.top_k == 25

    def test_swap_embedding_model(self, env1):
        """SWAP_EMBEDDING_MODEL maps the "medical" string to the enum member."""
        _step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "medical"})
        assert env1._config.embedding_model == EmbeddingModel.MEDICAL

    def test_toggle_reranking_on(self, env1):
        """Reranking starts disabled and TOGGLE_RERANKING can enable it."""
        assert env1._config.use_reranking is False
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
        assert env1._config.use_reranking is True

    def test_toggle_reranking_off(self, env1):
        """Reranking can be switched back off after being enabled."""
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
        _step(env1, ActionType.TOGGLE_RERANKING, {"enabled": False})
        assert env1._config.use_reranking is False

    def test_adjust_context_limit(self, env1):
        """ADJUST_CONTEXT_LIMIT writes config.context_window_limit."""
        _step(env1, ActionType.ADJUST_CONTEXT_LIMIT, {"value": 8192})
        assert env1._config.context_window_limit == 8192

    def test_invalid_chunk_size_sets_error(self, env1):
        """A rejected chunk size is reported through last_action_error."""
        # Set chunk_size smaller than the current chunk_overlap (default 50)
        # to trigger the model_validator "overlap must be < chunk_size".
        obs = _step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 10})
        assert obs.last_action_error is not None

    def test_invalid_model_sets_error(self, env1):
        """An unknown embedding model name is reported through last_action_error."""
        obs = _step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "nonexistent"})
        assert obs.last_action_error is not None

    def test_unrelated_fields_unchanged_after_action(self, env1):
        """ADJUST_TOP_K must set top_k and leave every other field alone."""
        before = self._get_config(env1)
        _step(env1, ActionType.ADJUST_TOP_K, {"value": 20})
        after = self._get_config(env1)

        # The action itself must have taken effect; otherwise the
        # "nothing else changed" check below would pass vacuously.
        assert after["top_k"] == 20

        # Compare every snapshotted field (including chunk_overlap) so a
        # newly added field in _get_config is covered automatically.
        changed = {key for key in before if after[key] != before[key]}
        assert changed <= {"top_k"}, (
            f"ADJUST_TOP_K modified unrelated config fields: {sorted(changed - {'top_k'})}"
        )
| # --------------------------------------------------------------------------- | |
| # Bug fix: ADJUST_CHUNK_OVERLAP must trigger _recompute_S_faulted() | |
| # --------------------------------------------------------------------------- | |
class TestChunkOverlapRecompute:
    """
    Regression tests for the bug where ADJUST_CHUNK_OVERLAP did not call
    _recompute_S_faulted(), so a new overlap value had no effect on the
    retrieval scores until some other action triggered a recomputation.
    """

    @staticmethod
    def _fresh_env_with_fault(seed):
        """
        Build a task-1 environment for ``seed`` and replace its injected
        faults with a single CHUNK_TOO_SMALL fault.
        """
        from models import FaultConfig, FaultType as FT

        environment = RAGDebugEnvironment()
        environment.reset(seed=seed, task_id=1)
        # Force CHUNK_TOO_SMALL so that the overlap setting is relevant.
        environment._injected_faults = [FaultConfig(fault_type=FT.CHUNK_TOO_SMALL)]
        return environment

    @staticmethod
    def _overlap_action(overlap_value):
        """Return an ADJUST_CHUNK_OVERLAP action carrying ``overlap_value``."""
        return RAGDebugAction(
            action_type=ActionType.ADJUST_CHUNK_OVERLAP,
            params={"value": overlap_value},
        )

    def _make_env_with_chunk_too_small(self, overlap_value):
        """
        With CHUNK_TOO_SMALL forced on (seed 42), apply ``overlap_value``
        through ADJUST_CHUNK_OVERLAP and return a copy of S_faulted.

        Assumes the post-reset chunk_size exceeds the overlaps used by the
        tests (0 and 450); the original author notes a default of 512.
        """
        environment = self._fresh_env_with_fault(seed=42)
        environment.step(self._overlap_action(overlap_value))
        return environment._S_faulted.copy()

    def test_overlap_recompute_changes_s_faulted(self):
        """
        Two otherwise identical environments that receive different
        chunk_overlap values must end up with different S_faulted matrices,
        which shows the recomputation runs as part of the action.
        """
        matrix_overlap_0 = self._make_env_with_chunk_too_small(overlap_value=0)
        matrix_overlap_450 = self._make_env_with_chunk_too_small(overlap_value=450)

        # Expectation under CHUNK_TOO_SMALL: more overlap means less noise,
        # hence the matrices cannot be (numerically) equal.
        matrices_match = np.allclose(matrix_overlap_0, matrix_overlap_450)
        assert not matrices_match, (
            "ADJUST_CHUNK_OVERLAP should immediately recompute S_faulted; "
            "different overlap values should yield different matrices."
        )

    def test_overlap_high_reduces_noise_magnitude(self):
        """
        With the fix in place, a higher overlap should shrink the noise that
        CHUNK_TOO_SMALL adds, pulling S_faulted closer to S_true.

        Assumes the post-reset chunk_size exceeds 450 (the original author
        notes a default of 512), so both overlap values are accepted.
        """

        def mean_abs_deviation(overlap_value):
            environment = self._fresh_env_with_fault(seed=7)
            # Snapshot S_true before the overlap action is applied.
            reference = environment._s_true_episode["general"].copy()
            environment.step(self._overlap_action(overlap_value))
            deviation = np.abs(environment._S_faulted - reference)
            return float(deviation.mean())

        deviation_overlap_0 = mean_abs_deviation(0)
        deviation_overlap_450 = mean_abs_deviation(450)
        assert deviation_overlap_450 < deviation_overlap_0, (
            "Higher overlap should reduce CHUNK_TOO_SMALL noise, "
            "making S_faulted closer to S_true"
        )
| # --------------------------------------------------------------------------- | |
| # SUBMIT grading | |
| # --------------------------------------------------------------------------- | |
class TestSubmit:
    """SUBMIT ends the episode and yields a reward inside [0, 1]."""

    def test_submit_sets_done(self, env1):
        """The observation returned for SUBMIT has done=True."""
        final_obs = _step(env1, ActionType.SUBMIT)
        assert final_obs.done is True

    def test_submit_success_reward_in_range(self):
        """
        Loosen the threshold and raise top_k, then submit. Only checks that
        a reward is present and lies within [0, 1].
        """
        episode = RAGDebugEnvironment()
        episode.reset(seed=0, task_id=1)

        # A low threshold plus a large top_k aims at maximal coverage.
        tuning_steps = (
            (ActionType.ADJUST_THRESHOLD, {"value": 0.05}),
            (ActionType.ADJUST_TOP_K, {"value": 50}),
        )
        for kind, payload in tuning_steps:
            _step(episode, kind, payload)
        final_obs = _step(episode, ActionType.SUBMIT)

        # Whether the submit counts as success ([0.7, 1.0]) or failure
        # ([0.0, 0.2]) is not pinned down here; only the outer bounds are.
        assert final_obs.reward is not None
        assert 0.0 <= final_obs.reward <= 1.0

    def test_early_submit_penalty_reward_low(self, env1):
        """
        Submit immediately without fixing anything. Only checks that a
        reward is present and lies within [0, 1].
        """
        final_obs = _step(env1, ActionType.SUBMIT)

        # An untouched, poorly tuned episode is expected to land in the
        # failure band [0, 0.2], but that is episode-dependent, so the
        # assertion only enforces the outer bounds.
        assert final_obs.reward is not None
        assert 0.0 <= final_obs.reward <= 1.0