File size: 4,841 Bytes
0f8f2c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Unit tests for DriftEngine's drift selection and application logic.

Run:
    docker exec <container> python -m pytest env/tests/test_drift_engine.py -v
"""

from unittest.mock import MagicMock

import pytest

from models import Task, TaskID, TaskDifficulty, SuccessCriteria, SetupCommand
from server.services.drift_engine import DriftEngine, _MIN_DRIFTS, _MAX_DRIFTS


def _task_with_drifts(n: int) -> Task:
    """Build an EXPERT-difficulty task whose drift pool holds ``n`` commands.

    Drift number ``k`` (0-based) runs ``aws cmd-k`` and is labelled
    ``drift-k``, so tests can map an executed command back to its label.
    """
    pool = []
    for idx in range(n):
        pool.append(
            SetupCommand(command=f"aws cmd-{idx}", description=f"drift-{idx}")
        )
    return Task(
        task_id=TaskID(1),
        difficulty=TaskDifficulty.EXPERT,
        description="test",
        success_criteria=SuccessCriteria(),
        possible_drifts=pool,
    )


@pytest.fixture
def mock_backend() -> MagicMock:
    """Provide a backend double whose ``execute_command`` always succeeds.

    Every call returns ``(True, "", "")``; individual tests may override
    ``return_value`` or ``side_effect`` to simulate failures.
    """
    return MagicMock(**{"execute_command.return_value": (True, "", "")})


@pytest.fixture
def engine(mock_backend: MagicMock) -> DriftEngine:
    """Provide a DriftEngine wired to the ``mock_backend`` double."""
    drift_engine = DriftEngine(mock_backend)
    return drift_engine


# ===================================================================
# apply_drift
# ===================================================================


class TestApplyDrift:
    """Behavioural tests for ``DriftEngine.apply_drift``.

    Each test drives the engine through the ``mock_backend`` fixture, whose
    ``execute_command`` returns ``(True, "", "")`` unless the test overrides it.
    """

    def test_no_drifts_returns_empty(
        self, engine: DriftEngine, mock_backend: MagicMock
    ) -> None:
        """A task with no possible drifts applies nothing and runs no commands."""
        task = Task(
            task_id=TaskID(1),
            difficulty=TaskDifficulty.EXPERT,
            description="t",
            success_criteria=SuccessCriteria(),
        )
        assert engine.apply_drift(task) == []
        # Nothing could be selected, so the backend must not have been touched.
        mock_backend.execute_command.assert_not_called()

    def test_single_drift_always_selected(
        self, engine: DriftEngine, mock_backend: MagicMock
    ) -> None:
        """A one-element pool is always applied, via exactly one command."""
        task = _task_with_drifts(1)
        applied = engine.apply_drift(task)
        assert len(applied) == 1
        assert applied[0] == "drift-0"
        mock_backend.execute_command.assert_called_once_with("aws cmd-0")

    def test_selects_between_min_and_max(self, engine: DriftEngine) -> None:
        """With a large pool, the applied count stays within the module bounds."""
        task = _task_with_drifts(10)
        # Selection is repeated because the count may vary between calls.
        for _ in range(20):
            applied = engine.apply_drift(task)
            assert _MIN_DRIFTS <= len(applied) <= _MAX_DRIFTS

    def test_never_exceeds_pool_size(self, engine: DriftEngine) -> None:
        """A small pool caps the number of applied drifts at its own size."""
        task = _task_with_drifts(2)
        for _ in range(20):
            applied = engine.apply_drift(task)
            assert len(applied) <= 2

    def test_selected_drifts_are_unique(self, engine: DriftEngine) -> None:
        """No drift label appears more than once in a single application."""
        task = _task_with_drifts(5)
        for _ in range(20):
            applied = engine.apply_drift(task)
            assert len(applied) == len(set(applied))

    def test_failed_drift_not_in_applied(
        self, engine: DriftEngine, mock_backend: MagicMock
    ) -> None:
        """A drift whose command fails is attempted but not reported as applied."""
        mock_backend.execute_command.return_value = (False, "", "error")
        task = _task_with_drifts(1)
        applied = engine.apply_drift(task)
        assert applied == []
        # Guard against a no-op implementation: the drift must have been tried.
        mock_backend.execute_command.assert_called_once_with("aws cmd-0")

    def test_partial_failure_only_returns_successful(
        self, engine: DriftEngine, mock_backend: MagicMock
    ) -> None:
        """With one success and one failure, only the successful drift is returned."""
        task = _task_with_drifts(2)
        mock_backend.execute_command.side_effect = [
            (True, "", ""),
            (False, "", "fail"),
        ]
        applied = engine.apply_drift(task)
        # A pool of 2 always yields 2 selections (see TestPickCount), so both
        # commands run. The order in which they run is not pinned down here,
        # so the succeeding drift is identified from the first executed call.
        calls = mock_backend.execute_command.call_args_list
        assert len(calls) == 2
        succeeded_command = calls[0].args[0]
        assert applied == [succeeded_command.replace("aws cmd-", "drift-")]

    def test_uses_description_as_label(self, engine: DriftEngine) -> None:
        """A drift's description is used as its reported label."""
        task = Task(
            task_id=TaskID(1),
            difficulty=TaskDifficulty.EXPERT,
            description="t",
            success_criteria=SuccessCriteria(),
            possible_drifts=[
                SetupCommand(command="aws test", description="My drift label"),
            ],
        )
        applied = engine.apply_drift(task)
        assert applied == ["My drift label"]

    def test_uses_command_as_fallback_label(self, engine: DriftEngine) -> None:
        """Without a description, the command string itself becomes the label."""
        task = Task(
            task_id=TaskID(1),
            difficulty=TaskDifficulty.EXPERT,
            description="t",
            success_criteria=SuccessCriteria(),
            possible_drifts=[SetupCommand(command="aws fallback-cmd")],
        )
        applied = engine.apply_drift(task)
        assert applied == ["aws fallback-cmd"]


# ===================================================================
# _pick_count
# ===================================================================


class TestPickCount:
    """Boundary and range checks for ``DriftEngine._pick_count``."""

    def test_zero_pool(self) -> None:
        """An empty pool yields a count of zero."""
        count = DriftEngine._pick_count(0)
        assert count == 0

    def test_one_pool(self) -> None:
        """A single-entry pool yields a count of one."""
        count = DriftEngine._pick_count(1)
        assert count == 1

    def test_two_pool_returns_two(self) -> None:
        """A two-entry pool always yields two."""
        # For pool_size=2 both bounds clamp to the pool size:
        # lo = min(2, 2) = 2 and hi = min(3, 2) = 2, so the count is fixed at 2.
        count = DriftEngine._pick_count(2)
        assert count == 2

    def test_large_pool_within_bounds(self) -> None:
        """Repeated draws from a large pool all fall within the module bounds."""
        sampled = [DriftEngine._pick_count(10) for _ in range(50)]
        out_of_range = [c for c in sampled if not _MIN_DRIFTS <= c <= _MAX_DRIFTS]
        assert not out_of_range