File size: 2,789 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Tests for reward system."""
import pytest
from reward.tier_lock import EpisodeTierLock, RewardTier
from reward.failure_reward import (
    SpecialistResult, SpecialistStatus,
    compute_failure_penalty, compute_recovery_bonus,
)
from reward.consistency_tracker import PathConsistencyTracker


def test_tier_lock_same_for_atomic():
    """Locking an episode for an "atomic" task selects ``RewardTier.TIER_0``."""
    assert EpisodeTierLock.for_task("atomic").locked_tier == RewardTier.TIER_0


def test_tier_lock_same_for_complex():
    """Locking an episode for a "complex" task selects ``RewardTier.TIER_2``."""
    assert EpisodeTierLock.for_task("complex").locked_tier == RewardTier.TIER_2


def test_failure_penalty_with_fallback():
    """A fallback reduces the failure penalty for a timed-out specialist.

    The expected reduction is measured against the same timeout with
    ``fallback_used=False``. This checks the claimed behaviour directly
    instead of comparing against a hard-coded threshold.
    """
    def _timeout(fallback_used):
        # Identical timed-out result apart from the fallback flag.
        return SpecialistResult(
            "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=fallback_used,
        )

    with_fallback = compute_failure_penalty([_timeout(True)])
    without_fallback = compute_failure_penalty([_timeout(False)])
    assert with_fallback < without_fallback  # Reduced because fallback was used


def test_failure_penalty_no_fallback():
    """A timeout with no fallback yields a penalty of 0.3."""
    timed_out = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "", 8000, fallback_used=False,
    )
    assert compute_failure_penalty([timed_out]) == pytest.approx(0.3)


def test_consistency_nonzero_from_start():
    """Dirichlet prior ensures non-zero consistency from episode 1."""
    specialists = ["a", "b", "c"]
    tracker = PathConsistencyTracker(specialist_ids=specialists)
    # Nothing has been recorded, so a positive score can only come from the prior.
    assert tracker.consistency_score([], "simple") > 0.0


def test_recovery_bonus():
    """A timed-out specialist rescued by a fallback earns a positive bonus
    when the episode still completes."""
    recovered = SpecialistResult(
        "a", SpecialistStatus.TIMEOUT, "fallback output", 3000, fallback_used=True,
    )
    assert compute_recovery_bonus([recovered], episode_completed=True) > 0.0


def test_conflict_detection_no_registry():
    """detect_conflicts works without a registry (keyword fallback only).

    With neither a registry nor contradiction pairs supplied, there is
    nothing to match against. The call should therefore return an empty
    list, not merely some list.
    """
    from reward.conflict_reward import detect_conflicts
    results = [
        SpecialistResult("a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000),
        SpecialistResult("b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000),
    ]
    # No registry passed — should still work, returns empty list (no pairs provided)
    conflicts = detect_conflicts(results)
    assert isinstance(conflicts, list)
    # Pin the documented result: an isinstance check alone would also accept
    # spurious conflicts.
    assert conflicts == []


def test_conflict_detection_with_keyword_pairs():
    """detect_conflicts uses provided contradiction pairs correctly."""
    from reward.conflict_reward import detect_conflicts

    postgres = SpecialistResult("a", SpecialistStatus.SUCCESS, "Use PostgreSQL for storage", 1000)
    mongo = SpecialistResult("b", SpecialistStatus.SUCCESS, "Use MongoDB for storage", 1000)
    # The pair is lowercase while the outputs are mixed-case. A match here
    # therefore relies on case-insensitive comparison.
    pairs = [("postgresql", "mongodb")]

    found = detect_conflicts([postgres, mongo], contradiction_pairs=pairs)

    assert len(found) == 1
    only = found[0]
    assert only.agent_a == "a"
    assert only.agent_b == "b"