# SpindleFlow-RL / tests/test_reward.py
# Author: garvitsachdeva
# Commit message: SpindleFlow RL — periodic push + log persistence
# Commit: 02ff91f
"""Tests for reward system."""
import pytest
from reward.tier_lock import EpisodeTierLock, RewardTier
from reward.failure_reward import (
SpecialistResult, SpecialistStatus,
compute_failure_penalty, compute_recovery_bonus,
)
from reward.consistency_tracker import PathConsistencyTracker
def test_tier_lock_same_for_atomic():
    """An "atomic" task must lock the episode to RewardTier.TIER_0."""
    locked = EpisodeTierLock.for_task("atomic").locked_tier
    assert locked == RewardTier.TIER_0
def test_tier_lock_same_for_complex():
    """A "complex" task must lock the episode to RewardTier.TIER_2."""
    locked = EpisodeTierLock.for_task("complex").locked_tier
    assert locked == RewardTier.TIER_2
def test_failure_penalty_with_fallback():
    """A timeout whose fallback was used is penalised less than one without a fallback.

    The baseline comes from an otherwise identical result with
    ``fallback_used=False``. The test therefore checks the relative claim
    ("reduced because fallback was used") directly. It does not compare
    against a hard-coded constant, which would let a regression through
    whenever both penalties happen to sit below that constant.
    """
    with_fallback = [
        SpecialistResult("a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=True),
    ]
    without_fallback = [
        SpecialistResult("a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False),
    ]
    penalty = compute_failure_penalty(with_fallback)
    baseline = compute_failure_penalty(without_fallback)
    # Reduced because fallback was used.
    assert penalty < baseline
def test_failure_penalty_no_fallback():
    """A single timeout with no fallback yields a penalty of 0.3."""
    timed_out = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False
    )
    assert compute_failure_penalty([timed_out]) == pytest.approx(0.3)
def test_consistency_nonzero_from_start():
    """With nothing recorded, the Dirichlet prior keeps consistency above zero."""
    tracker = PathConsistencyTracker(specialist_ids=["a", "b", "c"])
    # No paths have been recorded yet — the score must still be strictly positive.
    empty_path_score = tracker.consistency_score([], "simple")
    assert empty_path_score > 0.0
def test_recovery_bonus():
    """A completed episode whose timed-out specialist used a fallback earns a positive bonus."""
    recovered = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "fallback output", 3000, fallback_used=True
    )
    bonus = compute_recovery_bonus([recovered], episode_completed=True)
    assert bonus > 0.0
def test_conflict_detection_no_registry():
    """detect_conflicts can be called without a registry (keyword fallback only)."""
    from reward.conflict_reward import detect_conflicts

    outputs_by_agent = {
        "a": "Use PostgreSQL for storage",
        "b": "Use MongoDB for storage",
    }
    results = [
        SpecialistResult(agent, SpecialistStatus.SUCCESS, text, 1000)
        for agent, text in outputs_by_agent.items()
    ]
    # Neither a registry nor contradiction pairs are passed — the call must still
    # succeed. Only the return type is asserted, not the list's contents.
    assert isinstance(detect_conflicts(results), list)
def test_conflict_detection_with_keyword_pairs():
    """An explicit ("postgresql", "mongodb") pair flags exactly one a-vs-b conflict."""
    from reward.conflict_reward import detect_conflicts

    postgres_choice = SpecialistResult(
        "a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000
    )
    mongo_choice = SpecialistResult(
        "b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000
    )
    # The pair is lowercase while the outputs are mixed-case, so this test also
    # relies on detect_conflicts matching keywords case-insensitively.
    keyword_pairs = [("postgresql", "mongodb")]
    found = detect_conflicts(
        [postgres_choice, mongo_choice], contradiction_pairs=keyword_pairs
    )
    assert len(found) == 1
    conflict = found[0]
    assert conflict.agent_a == "a"
    assert conflict.agent_b == "b"