Spaces:
Runtime error
Runtime error
| """Tests for reward system.""" | |
| import pytest | |
| from reward.tier_lock import EpisodeTierLock, RewardTier | |
| from reward.failure_reward import ( | |
| SpecialistResult, SpecialistStatus, | |
| compute_failure_penalty, compute_recovery_bonus, | |
| ) | |
| from reward.consistency_tracker import PathConsistencyTracker | |
def test_tier_lock_same_for_atomic():
    """An "atomic" task must produce a tier lock fixed at RewardTier.TIER_0."""
    expected_tier = RewardTier.TIER_0
    assert EpisodeTierLock.for_task("atomic").locked_tier == expected_tier
def test_tier_lock_same_for_complex():
    """A "complex" task must produce a tier lock fixed at RewardTier.TIER_2."""
    expected_tier = RewardTier.TIER_2
    assert EpisodeTierLock.for_task("complex").locked_tier == expected_tier
def test_failure_penalty_with_fallback():
    """A timeout covered by a fallback is penalised less than an unmitigated one.

    The two inputs are identical except for ``fallback_used``, so any
    difference in penalty is attributable to the fallback alone.
    """
    with_fallback = [
        SpecialistResult("a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=True),
    ]
    without_fallback = [
        SpecialistResult("a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False),
    ]
    penalty = compute_failure_penalty(with_fallback)
    baseline = compute_failure_penalty(without_fallback)
    # Compare against the computed no-fallback penalty instead of a hard-coded
    # constant, so this test keeps checking "reduced because fallback was used"
    # even if the base penalty is retuned. The absolute value of the
    # no-fallback penalty is pinned separately in
    # test_failure_penalty_no_fallback.
    assert penalty < baseline
def test_failure_penalty_no_fallback():
    """A single timeout with no fallback yields a failure penalty of 0.3."""
    timed_out = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False
    )
    assert compute_failure_penalty([timed_out]) == pytest.approx(0.3)
def test_consistency_nonzero_from_start():
    """Dirichlet prior ensures non-zero consistency from episode 1."""
    tracker = PathConsistencyTracker(specialist_ids=["a", "b", "c"])
    # No paths have been recorded yet, so a positive score can only come from
    # the tracker's prior, not from observed history.
    assert tracker.consistency_score([], "simple") > 0.0
def test_recovery_bonus():
    """A timed-out specialist rescued by a fallback earns a positive bonus
    when the episode completes."""
    recovered = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "fallback output", 3000, fallback_used=True
    )
    bonus = compute_recovery_bonus([recovered], episode_completed=True)
    assert bonus > 0.0
def test_conflict_detection_no_registry():
    """detect_conflicts works without a registry (keyword fallback only)."""
    from reward.conflict_reward import detect_conflicts

    outputs = [
        SpecialistResult("a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000),
        SpecialistResult("b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000),
    ]
    # Called with neither a registry nor contradiction pairs. This only checks
    # that the call succeeds and returns a list; the contents are not asserted.
    # NOTE(review): an earlier comment expected an empty list here. Confirm
    # whether detect_conflicts applies default keyword pairs before tightening
    # this to `== []`.
    assert isinstance(detect_conflicts(outputs), list)
def test_conflict_detection_with_keyword_pairs():
    """detect_conflicts uses provided contradiction pairs correctly."""
    from reward.conflict_reward import detect_conflicts

    outputs = [
        SpecialistResult("a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000),
        SpecialistResult("b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000),
    ]
    # The pair is given in lowercase while the outputs are mixed-case, so this
    # presumably also exercises case-insensitive matching.
    pairs = [("postgresql", "mongodb")]
    found = detect_conflicts(outputs, contradiction_pairs=pairs)

    assert len(found) == 1
    first = found[0]
    assert first.agent_a == "a"
    assert first.agent_b == "b"