# preference-lab/tests/test_environment.py
# Sibam · f3f7bc4 · "fix: clamp grader rewards to strictly (0, 1) to pass OpenEnv validation bounds"
# (Raw · History · Blame · Contribute · Delete · 9.87 kB)
"""
Tests for PreferenceLab environment.
Run: pytest tests/ -v
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
import pytest
from models import (
PairwiseAction, LikertAction, ConsistencyAction,
PairwiseObservation, LikertObservation, ConsistencyObservation,
)
from server.environment import (
PreferenceLabEnvironment,
grade_pairwise, grade_likert, grade_consistency,
)
# ── Grader unit tests ─────────────────────────────────────────
class TestPairwiseGrader:
    """Unit tests for ``grade_pairwise``.

    Rewards are clamped to the open interval (0, 1): a correct choice maps to
    0.99 instead of 1.0 and an incorrect one to 0.01 instead of 0.0. The
    ``scores_1`` / ``scores_0`` test names predate that clamping.
    """

    # Rewards are floats, so they are compared with pytest.approx rather than
    # exact ``==``; the tests then don't hinge on how the grader clamps.

    def test_correct_choice_scores_1(self):
        """Matching the gold label earns the top (clamped) reward, 0.99."""
        action = PairwiseAction(choice="A")
        example = {"gold_label": "A", "source": "test"}
        reward, info = grade_pairwise(action, example)
        assert reward == pytest.approx(0.99)
        assert info["verdict"] == "correct"

    def test_wrong_choice_scores_0(self):
        """Picking the other response earns the bottom (clamped) reward, 0.01."""
        action = PairwiseAction(choice="B")
        example = {"gold_label": "A", "source": "test"}
        reward, info = grade_pairwise(action, example)
        assert reward == pytest.approx(0.01)
        assert info["verdict"] == "incorrect"

    def test_skip_scores_partial(self):
        """Skipping earns a fixed partial reward of 0.3."""
        action = PairwiseAction(choice="skip")
        example = {"gold_label": "A", "source": "test"}
        reward, info = grade_pairwise(action, example)
        assert reward == pytest.approx(0.3)

    def test_tie_scores_low(self):
        """Answering "tie" when the gold label is "A" earns a low reward of 0.1."""
        action = PairwiseAction(choice="tie")
        example = {"gold_label": "A", "source": "test"}
        reward, info = grade_pairwise(action, example)
        assert reward == pytest.approx(0.1)

    def test_reward_in_range(self):
        """Every allowed choice yields a reward strictly inside (0, 1)."""
        example = {"gold_label": "A", "source": "test"}
        for choice in ["A", "B", "tie", "skip"]:
            reward, _ = grade_pairwise(PairwiseAction(choice=choice), example)
            assert 0.0 < reward < 1.0, f"Reward out of range: {reward} for {choice!r}"
class TestLikertGrader:
    """Unit tests for ``grade_likert``.

    An action carries four integer rubric scores (helpfulness, honesty,
    harmlessness, instruction_following) in 1..5, graded against the example's
    ``gold_scores``. Rewards are clamped to (0, 1): a perfect match yields 0.99
    and the worst case 0.01.
    """

    def test_perfect_scores_reward_1(self):
        """Matching every gold score earns the top (clamped) reward, 0.99."""
        action = LikertAction(helpfulness=5, honesty=5, harmlessness=5, instruction_following=5)
        example = {
            "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
            "source": "test",
        }
        reward, info = grade_likert(action, example)
        assert reward == pytest.approx(0.99)

    def test_worst_scores_reward_0(self):
        """Maximal disagreement on every dimension earns the bottom reward, 0.01."""
        action = LikertAction(helpfulness=1, honesty=1, harmlessness=1, instruction_following=1)
        example = {
            "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
            "source": "test",
        }
        reward, info = grade_likert(action, example)
        assert reward == pytest.approx(0.01)

    def test_partial_error_gives_partial_reward(self):
        """Being one point off everywhere lands strictly between worst and perfect."""
        action = LikertAction(helpfulness=4, honesty=4, harmlessness=4, instruction_following=4)
        example = {
            "gold_scores": {"helpfulness": 5, "honesty": 5, "harmlessness": 5, "instruction_following": 5},
            "source": "test",
        }
        reward, info = grade_likert(action, example)
        # Stronger than the (0, 1) range check: a partial answer must beat the
        # clamped worst case (0.01) and fall short of the clamped perfect (0.99).
        assert 0.01 < reward < 0.99

    def test_reward_always_in_range(self):
        """Sampled actions/gold scores always yield a reward strictly inside (0, 1).

        A private, fixed-seed RNG makes the sampled cases identical on every
        run (a failure can be reproduced) and leaves the global ``random``
        state untouched.
        """
        import random

        rng = random.Random(0)
        dimensions = ("helpfulness", "honesty", "harmlessness", "instruction_following")
        for _ in range(20):
            action = LikertAction(**{dim: rng.randint(1, 5) for dim in dimensions})
            example = {
                "gold_scores": {dim: rng.randint(1, 5) for dim in dimensions},
                "source": "test",
            }
            reward, _ = grade_likert(action, example)
            assert 0.0 < reward < 1.0, f"Reward out of range: {reward}"
class TestConsistencyGrader:
    """Unit tests for ``grade_consistency`` (ranking responses "A".."D").

    Rewards are clamped to (0, 1): a perfect ranking yields 0.99 and a ranking
    with invalid ids yields 0.01.
    """

    def test_perfect_ranking_scores_1(self):
        """Reproducing the gold ranking earns the top (clamped) reward, 0.99."""
        action = ConsistencyAction(ranking=["A", "B", "C", "D"])
        example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
        reward, info = grade_consistency(action, example)
        assert reward == pytest.approx(0.99)

    def test_reversed_ranking_scores_low(self):
        """A fully reversed ranking scores below the perfect ranking.

        Transitivity score = 0.5 (the reversal is still a valid total order).
        Quality score = 0.0 (worst possible Kendall tau = -1 → normalised to 0).
        The total (≈ 0.5) is therefore below the clamped perfect score of 0.99.
        """
        example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
        perfect, _ = grade_consistency(ConsistencyAction(ranking=["A", "B", "C", "D"]), example)
        reward, info = grade_consistency(ConsistencyAction(ranking=["D", "C", "B", "A"]), example)
        # Compare against the actual perfect reward: `reward < 1.0` alone is
        # vacuous now that every reward is clamped into (0, 1).
        assert reward < perfect
        assert info["quality_score"] == pytest.approx(0.0)

    def test_invalid_ids_scores_low(self):
        """An id outside the gold set earns the bottom reward and is flagged."""
        action = ConsistencyAction(ranking=["A", "B", "C", "X"])
        example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
        reward, info = grade_consistency(action, example)
        assert reward == pytest.approx(0.01)
        assert info["has_invalid_ids"] is True

    def test_reward_always_in_range(self):
        """Every permutation of A..D yields a reward strictly inside (0, 1)."""
        import itertools

        gold = ["A", "B", "C", "D"]
        for perm in itertools.permutations(gold):
            action = ConsistencyAction(ranking=list(perm))
            example = {"gold_ranking": list(gold), "source": "test"}
            reward, _ = grade_consistency(action, example)
            assert 0.0 < reward < 1.0, f"Reward out of range: {reward} for {perm}"

    def test_graders_not_always_same_score(self):
        """Critical: graders must NOT always return the same score."""
        action_correct = ConsistencyAction(ranking=["A", "B", "C", "D"])
        action_wrong = ConsistencyAction(ranking=["D", "C", "B", "A"])
        example = {"gold_ranking": ["A", "B", "C", "D"], "source": "test"}
        r1, _ = grade_consistency(action_correct, example)
        r2, _ = grade_consistency(action_wrong, example)
        assert r1 != r2, "Grader must return different scores for different inputs!"
# ── Environment integration tests ─────────────────────────────
class TestPreferenceLabEnvironment:
    """Integration tests driving ``PreferenceLabEnvironment`` via reset/step/state."""

    def setup_method(self):
        # A fresh environment per test so episode state never leaks across tests.
        self.env = PreferenceLabEnvironment()

    def test_reset_returns_observation(self):
        """A default reset yields an observation exposing prompt/reward/done."""
        obs = self.env.reset()
        assert obs is not None
        assert hasattr(obs, "prompt")
        assert hasattr(obs, "reward")
        assert hasattr(obs, "done")

    def test_reset_pairwise_returns_pairwise_obs(self):
        """A pairwise reset yields a PairwiseObservation with both responses filled."""
        obs = self.env.reset(task_type="pairwise")
        assert isinstance(obs, PairwiseObservation)
        assert obs.response_a != ""
        assert obs.response_b != ""

    def test_reset_likert_returns_likert_obs(self):
        """A likert reset yields a LikertObservation with a response and rubric."""
        obs = self.env.reset(task_type="likert")
        assert isinstance(obs, LikertObservation)
        assert obs.response != ""
        assert obs.rubric != ""

    def test_reset_consistency_returns_consistency_obs(self):
        """A consistency reset yields a ConsistencyObservation with responses filled."""
        obs = self.env.reset(task_type="consistency")
        assert isinstance(obs, ConsistencyObservation)
        assert obs.response_a != ""
        assert obs.response_d != ""

    def test_step_pairwise(self):
        """Stepping a pairwise episode returns a clamped reward and a bool done flag."""
        self.env.reset(task_type="pairwise")
        obs = self.env.step(PairwiseAction(choice="A"))
        assert isinstance(obs, PairwiseObservation)
        assert 0.0 < obs.reward < 1.0
        assert isinstance(obs.done, bool)

    def test_step_likert(self):
        """Stepping a likert episode returns a reward strictly inside (0, 1)."""
        self.env.reset(task_type="likert")
        action = LikertAction(helpfulness=4, honesty=4, harmlessness=5, instruction_following=4)
        obs = self.env.step(action)
        assert isinstance(obs, LikertObservation)
        assert 0.0 < obs.reward < 1.0

    def test_step_consistency(self):
        """Stepping a consistency episode returns a reward strictly inside (0, 1)."""
        self.env.reset(task_type="consistency")
        obs = self.env.step(ConsistencyAction(ranking=["A", "B", "C", "D"]))
        assert isinstance(obs, ConsistencyObservation)
        assert 0.0 < obs.reward < 1.0

    def test_episode_terminates_after_max_steps(self):
        """Repeated steps must end the episode within 10 steps."""
        self.env.reset(task_type="pairwise")
        done = False
        steps = 0
        while not done:
            obs = self.env.step(PairwiseAction(choice="A"))
            done = obs.done
            steps += 1
            # Checked inside the loop so a never-ending episode fails fast
            # instead of hanging the test run.
            assert steps <= 10, "Episode ran too long!"
        assert done is True

    def test_state_returns_metadata(self):
        """``state`` exposes episode metadata and remembers the reset seed."""
        self.env.reset(seed=42, task_type="pairwise")
        state = self.env.state
        dumped = state.model_dump()
        for key in ("episode_id", "step_count", "task_type"):
            assert key in dumped, f"state is missing {key!r}"
        assert state.seed == 42

    def test_reproducible_with_seed(self):
        """Resetting twice with the same seed yields the same example."""
        obs1 = self.env.reset(seed=123, task_type="pairwise")
        obs2 = self.env.reset(seed=123, task_type="pairwise")
        assert obs1.prompt == obs2.prompt
        assert obs1.response_a == obs2.response_a

    def test_rewards_vary_across_actions(self):
        """Ensure graders do NOT always return the same reward (disqualification check)."""
        rewards = set()
        for seed in range(5):
            # Reset with the same seed before each action so "A" and "B" are
            # graded on the same example, and the test is reproducible.
            self.env.reset(seed=seed, task_type="pairwise")
            rewards.add(self.env.step(PairwiseAction(choice="A")).reward)
            self.env.reset(seed=seed, task_type="pairwise")
            rewards.add(self.env.step(PairwiseAction(choice="B")).reward)
        assert len(rewards) > 1, "Grader always returns the same score — DISQUALIFICATION!"

    def test_all_three_tasks_run(self):
        """Each task type can be reset and is reflected in ``state.task_type``."""
        for task in ["pairwise", "likert", "consistency"]:
            obs = self.env.reset(task_type=task)
            assert obs is not None
            assert self.env.state.task_type == task