Spaces:
Runtime error
Runtime error
File size: 2,789 Bytes
"""Tests for reward system."""
import pytest
from reward.tier_lock import EpisodeTierLock, RewardTier
from reward.failure_reward import (
SpecialistResult, SpecialistStatus,
compute_failure_penalty, compute_recovery_bonus,
)
from reward.consistency_tracker import PathConsistencyTracker
def test_tier_lock_same_for_atomic():
    """The lock built for an "atomic" task reports RewardTier.TIER_0."""
    tier_lock = EpisodeTierLock.for_task("atomic")
    observed_tier = tier_lock.locked_tier
    assert observed_tier == RewardTier.TIER_0
def test_tier_lock_same_for_complex():
    """The lock built for a "complex" task reports RewardTier.TIER_2."""
    tier_lock = EpisodeTierLock.for_task("complex")
    observed_tier = tier_lock.locked_tier
    assert observed_tier == RewardTier.TIER_2
def test_failure_penalty_with_fallback():
    """A single TIMEOUT result with fallback_used=True is penalised below 0.3."""
    timed_out_with_fallback = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=True
    )
    penalty = compute_failure_penalty([timed_out_with_fallback])
    # Reduced because fallback was used
    assert penalty < 0.3
def test_failure_penalty_no_fallback():
    """A single TIMEOUT result with fallback_used=False is penalised ~0.3."""
    timed_out_without_fallback = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False
    )
    penalty = compute_failure_penalty([timed_out_without_fallback])
    # Approximate comparison: the penalty is a float.
    assert penalty == pytest.approx(0.3)
def test_consistency_nonzero_from_start():
    """Dirichlet prior ensures non-zero consistency from episode 1."""
    fresh_tracker = PathConsistencyTracker(specialist_ids=["a", "b", "c"])
    # Nothing has been recorded on the tracker yet; scoring an empty path
    # for the "simple" task must still give a strictly positive value.
    initial_score = fresh_tracker.consistency_score([], "simple")
    assert initial_score > 0.0
def test_recovery_bonus():
    """A completed episode with a fallback-backed TIMEOUT earns a positive bonus."""
    recovered = SpecialistResult(
        "a",
        SpecialistStatus.TIMEOUT,
        "fallback output",
        3000,
        fallback_used=True,
    )
    bonus = compute_recovery_bonus([recovered], episode_completed=True)
    assert bonus > 0.0
def test_conflict_detection_no_registry():
    """detect_conflicts works without a registry (keyword fallback only)."""
    from reward.conflict_reward import detect_conflicts

    postgres_answer = SpecialistResult(
        "a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000
    )
    mongo_answer = SpecialistResult(
        "b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000
    )

    # Called with the results only: no registry and no contradiction pairs.
    # The result is presumably empty in that case, but only the return type
    # is checked here.
    detected = detect_conflicts([postgres_answer, mongo_answer])
    assert isinstance(detected, list)
def test_conflict_detection_with_keyword_pairs():
    """detect_conflicts uses provided contradiction pairs correctly."""
    from reward.conflict_reward import detect_conflicts

    specialist_outputs = [
        SpecialistResult(
            "a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000
        ),
        SpecialistResult(
            "b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000
        ),
    ]
    keyword_pairs = [("postgresql", "mongodb")]

    conflicts = detect_conflicts(
        specialist_outputs,
        contradiction_pairs=keyword_pairs,
    )

    # Exactly one conflict is expected, naming the two specialists in order.
    assert len(conflicts) == 1
    only_conflict = conflicts[0]
    assert only_conflict.agent_a == "a"
    assert only_conflict.agent_b == "b"
|