rag_debug_env / tests /test_environment.py
vankap-grover's picture
Upload folder using huggingface_hub
ac224ce verified
"""
tests/test_environment.py
-------------------------
Tests for episode lifecycle and action routing in RAGDebugEnvironment.
Verifies that:
- reset() fully initialises state and returns a valid observation.
- step() increments step count and returns bounded rewards.
- Each action type modifies the correct config field.
- Auto-terminate fires at max_steps.
- ADJUST_CHUNK_OVERLAP now triggers _recompute_S_faulted() (bug fix).
"""
import pytest
import numpy as np
from server.rag_debug_env_environment import RAGDebugEnvironment
from server.constants import _MAX_STEPS
from models import (
RAGDebugAction,
ActionType,
EmbeddingModel,
RAGDebugObservation,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(params=[1, 2, 3])
def env(request):
"""Fresh environment reset to each task."""
e = RAGDebugEnvironment()
e.reset(seed=0, task_id=request.param)
return e
@pytest.fixture
def env1():
e = RAGDebugEnvironment()
e.reset(seed=0, task_id=1)
return e
def _step(env, action_type, params=None):
action = RAGDebugAction(action_type=action_type, params=params or {})
return env.step(action)
# ---------------------------------------------------------------------------
# Reset
# ---------------------------------------------------------------------------
class TestReset:
def test_reset_returns_observation(self, env):
obs = env.reset(seed=1, task_id=1)
assert isinstance(obs, RAGDebugObservation)
def test_reset_clears_step_count(self, env1):
env1.step(RAGDebugAction(action_type=ActionType.ADJUST_TOP_K, params={"value": 15}))
assert env1._state.step_count == 1
env1.reset(seed=99, task_id=1)
assert env1._state.step_count == 0
def test_reset_clears_done_flag(self, env1):
# Force done via SUBMIT
env1.step(RAGDebugAction(action_type=ActionType.SUBMIT, params={}))
assert env1._done is True
env1.reset(seed=5, task_id=1)
assert env1._done is False
def test_reset_returns_valid_metrics(self, env):
obs = env.reset(seed=2, task_id=1)
m = obs.metrics
assert 0.0 <= m.mean_coverage <= 1.0
assert 0.0 <= m.mean_precision <= 1.0
assert m.n_empty_retrievals >= 0
assert m.n_context_overflows >= 0
def test_reset_with_different_tasks(self):
e = RAGDebugEnvironment()
for task_id in (1, 2, 3):
obs = e.reset(seed=0, task_id=task_id)
assert obs.task_id == task_id
def test_reset_invalid_task_raises(self):
e = RAGDebugEnvironment()
with pytest.raises(ValueError, match="task_id"):
e.reset(seed=0, task_id=99)
def test_reset_clears_action_history(self, env1):
env1.step(RAGDebugAction(action_type=ActionType.ADJUST_TOP_K, params={"value": 15}))
env1.reset(seed=0, task_id=1)
assert env1._internal_state.action_history == []
assert env1._internal_state.reward_history == []
# ---------------------------------------------------------------------------
# Step lifecycle
# ---------------------------------------------------------------------------
class TestStep:
def test_step_increments_step_count(self, env1):
for expected in range(1, 4):
_step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
assert env1._state.step_count == expected
def test_step_returns_observation(self, env1):
obs = _step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
assert isinstance(obs, RAGDebugObservation)
def test_step_observation_reward_in_unit_interval(self, env1):
obs = _step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.2})
assert obs.reward is not None
assert 0.0 <= obs.reward <= 1.0
def test_step_after_done_raises(self, env1):
_step(env1, ActionType.SUBMIT)
with pytest.raises(RuntimeError, match="already done"):
_step(env1, ActionType.ADJUST_TOP_K, {"value": 15})
def test_auto_terminate_at_max_steps(self):
e = RAGDebugEnvironment()
obs = e.reset(seed=0, task_id=1)
for _ in range(_MAX_STEPS - 1):
obs = _step(e, ActionType.ADJUST_TOP_K, {"value": 10})
assert not obs.done, "Episode should not be done before max_steps"
# Final step hits max_steps
obs = _step(e, ActionType.ADJUST_TOP_K, {"value": 10})
assert obs.done, "Episode should auto-terminate at max_steps"
def test_done_flag_propagates_to_observation(self, env1):
obs = _step(env1, ActionType.SUBMIT)
assert obs.done is True
def test_action_recorded_in_history(self, env1):
action = RAGDebugAction(action_type=ActionType.ADJUST_TOP_K, params={"value": 20})
env1.step(action)
assert len(env1._internal_state.action_history) == 1
assert env1._internal_state.action_history[0].action_type == ActionType.ADJUST_TOP_K
# ---------------------------------------------------------------------------
# Action routing — each action modifies the correct config field
# ---------------------------------------------------------------------------
class TestActionRouting:
def _get_config(self, env):
"""Grab a copy of the current config fields as a dict."""
cfg = env._config
return {
"chunk_size": cfg.chunk_size,
"chunk_overlap": cfg.chunk_overlap,
"threshold": cfg.similarity_threshold,
"top_k": cfg.top_k,
"model": cfg.embedding_model,
"reranking": cfg.use_reranking,
"context_limit": cfg.context_window_limit,
}
def test_adjust_chunk_size(self, env1):
_step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 256})
assert env1._config.chunk_size == 256
def test_adjust_chunk_overlap(self, env1):
_step(env1, ActionType.ADJUST_CHUNK_OVERLAP, {"value": 100})
assert env1._config.chunk_overlap == 100
def test_adjust_threshold(self, env1):
_step(env1, ActionType.ADJUST_THRESHOLD, {"value": 0.15})
assert env1._config.similarity_threshold == pytest.approx(0.15)
def test_adjust_top_k(self, env1):
_step(env1, ActionType.ADJUST_TOP_K, {"value": 25})
assert env1._config.top_k == 25
def test_swap_embedding_model(self, env1):
_step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "medical"})
assert env1._config.embedding_model == EmbeddingModel.MEDICAL
def test_toggle_reranking_on(self, env1):
assert env1._config.use_reranking is False
_step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
assert env1._config.use_reranking is True
def test_toggle_reranking_off(self, env1):
_step(env1, ActionType.TOGGLE_RERANKING, {"enabled": True})
_step(env1, ActionType.TOGGLE_RERANKING, {"enabled": False})
assert env1._config.use_reranking is False
def test_adjust_context_limit(self, env1):
_step(env1, ActionType.ADJUST_CONTEXT_LIMIT, {"value": 8192})
assert env1._config.context_window_limit == 8192
def test_invalid_chunk_size_sets_error(self, env1):
# Set chunk_size smaller than the current chunk_overlap (default 50)
# to trigger the model_validator "overlap must be < chunk_size".
obs = _step(env1, ActionType.ADJUST_CHUNK_SIZE, {"value": 10})
assert obs.last_action_error is not None
def test_invalid_model_sets_error(self, env1):
obs = _step(env1, ActionType.SWAP_EMBEDDING_MODEL, {"model": "nonexistent"})
assert obs.last_action_error is not None
def test_unrelated_fields_unchanged_after_action(self, env1):
before = self._get_config(env1)
_step(env1, ActionType.ADJUST_TOP_K, {"value": 20})
after = self._get_config(env1)
# Only top_k should change
assert after["chunk_size"] == before["chunk_size"]
assert after["threshold"] == before["threshold"]
assert after["model"] == before["model"]
assert after["reranking"] == before["reranking"]
assert after["context_limit"] == before["context_limit"]
# ---------------------------------------------------------------------------
# Bug fix: ADJUST_CHUNK_OVERLAP must trigger _recompute_S_faulted()
# ---------------------------------------------------------------------------
class TestChunkOverlapRecompute:
"""
Verifies the fix for the bug where ADJUST_CHUNK_OVERLAP did not call
_recompute_S_faulted(), meaning the overlap parameter had no effect on
retrieval scores until a different action happened to trigger recomputation.
"""
def _make_env_with_chunk_too_small(self, overlap_value):
"""
Set up an environment where CHUNK_TOO_SMALL is active, then set a
specific overlap, and return the S_faulted matrix.
Uses the default chunk_size (512) so that both overlap_value=0 and
overlap_value=450 are valid (450 < 512 satisfies overlap < chunk_size).
"""
from models import FaultConfig, FaultType as FT
e = RAGDebugEnvironment()
e.reset(seed=42, task_id=1)
# Force CHUNK_TOO_SMALL fault so overlap modulation is relevant.
e._injected_faults = [FaultConfig(fault_type=FT.CHUNK_TOO_SMALL)]
# Apply the overlap we want to test.
action = RAGDebugAction(
action_type=ActionType.ADJUST_CHUNK_OVERLAP,
params={"value": overlap_value},
)
e.step(action)
return e._S_faulted.copy()
def test_overlap_recompute_changes_s_faulted(self):
"""
Two environments identical except for chunk_overlap should have
different S_faulted matrices after ADJUST_CHUNK_OVERLAP, proving
the recomputation is happening.
"""
S_low_overlap = self._make_env_with_chunk_too_small(overlap_value=0)
S_high_overlap = self._make_env_with_chunk_too_small(overlap_value=450)
# With CHUNK_TOO_SMALL active, higher overlap reduces noise sigma,
# so the two matrices should differ.
assert not np.allclose(S_low_overlap, S_high_overlap), (
"ADJUST_CHUNK_OVERLAP should immediately recompute S_faulted; "
"different overlap values should yield different matrices."
)
def test_overlap_high_reduces_noise_magnitude(self):
"""
After fixing the bug: higher overlap should reduce the noise added by
CHUNK_TOO_SMALL, making the faulted matrix closer to S_true.
Uses chunk_size=512 (default) so both overlap values (0, 450) are valid.
"""
from models import FaultConfig, FaultType as FT
def _make_and_get_diff(overlap_value):
e = RAGDebugEnvironment()
e.reset(seed=7, task_id=1)
e._injected_faults = [FaultConfig(fault_type=FT.CHUNK_TOO_SMALL)]
# Capture S_true before overlap action (use default chunk_size=512)
model_key = "general"
S_true = e._s_true_episode[model_key].copy()
e.step(RAGDebugAction(
action_type=ActionType.ADJUST_CHUNK_OVERLAP,
params={"value": overlap_value},
))
return float(np.abs(e._S_faulted - S_true).mean())
diff_low = _make_and_get_diff(0)
diff_high = _make_and_get_diff(450)
assert diff_high < diff_low, (
"Higher overlap should reduce CHUNK_TOO_SMALL noise, "
"making S_faulted closer to S_true"
)
# ---------------------------------------------------------------------------
# SUBMIT grading
# ---------------------------------------------------------------------------
class TestSubmit:
def test_submit_sets_done(self, env1):
obs = _step(env1, ActionType.SUBMIT)
assert obs.done is True
def test_submit_success_reward_in_range(self):
"""After enough improvement, submit should yield a high reward."""
e = RAGDebugEnvironment()
e.reset(seed=0, task_id=1)
# Adjust threshold low to maximise coverage, then submit
_step(e, ActionType.ADJUST_THRESHOLD, {"value": 0.05})
_step(e, ActionType.ADJUST_TOP_K, {"value": 50})
obs = _step(e, ActionType.SUBMIT)
# Reward should be in [0.7, 1.0] or [0.0, 0.2] depending on success
assert obs.reward is not None
assert 0.0 <= obs.reward <= 1.0
def test_early_submit_penalty_reward_low(self, env1):
"""Submitting immediately (without fixing anything) should give a low reward."""
obs = _step(env1, ActionType.SUBMIT)
# Immediate submit without any fixes likely yields failure reward in [0, 0.2]
# This is not guaranteed to always be < 0.7 depending on episode, but
# it's the expected case for a fresh poorly-tuned environment.
assert obs.reward is not None
assert 0.0 <= obs.reward <= 1.0