File size: 2,858 Bytes
9eb0831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Unit tests for SRE OpenEnv data models.
"""

import pytest

from models import SREAction, SREObservation, SREState


class TestSREAction:
    """Construction and validation checks for the SREAction model."""

    def test_run_shell_action(self):
        """A run_shell action keeps the type and command it was built with."""
        shell_cmd = "ls -la"
        act = SREAction(action_type="run_shell", command=shell_cmd)
        assert act.action_type == "run_shell"
        assert act.command == shell_cmd

    def test_patch_file_action(self):
        """A patch_file action stores its type, target path and content."""
        expected_fields = {
            "action_type": "patch_file",
            "file_path": "/tmp/test.txt",
            "content": "hello world",
        }
        act = SREAction(**expected_fields)
        # Every keyword used to build the action must be readable back as-is.
        for attr_name, attr_value in expected_fields.items():
            assert getattr(act, attr_name) == attr_value

    def test_run_shell_requires_command(self):
        """An empty command is rejected with a ValueError mentioning 'command'."""
        bad_kwargs = {"action_type": "run_shell", "command": ""}
        with pytest.raises(ValueError, match="command"):
            SREAction(**bad_kwargs)

    def test_patch_file_requires_path(self):
        """An empty file_path is rejected with a ValueError mentioning it."""
        bad_kwargs = {"action_type": "patch_file", "file_path": ""}
        with pytest.raises(ValueError, match="file_path"):
            SREAction(**bad_kwargs)

    def test_default_action_type(self):
        """Omitting action_type yields a run_shell action."""
        expected_default = "run_shell"
        act = SREAction(command="echo test")
        assert act.action_type == expected_default


class TestSREObservation:
    """Tests for SREObservation dataclass."""

    def test_default_values(self):
        """A bare observation has empty output, exit code 0 and no truncation."""
        obs = SREObservation()
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.exit_code == 0
        assert obs.truncated is False
        assert obs.message == ""

    def test_with_output(self):
        """Every field passed to the constructor is stored as given."""
        obs = SREObservation(
            stdout="hello\n",
            stderr="warning: deprecated\n",
            exit_code=0,
            truncated=False,
        )
        assert obs.stdout == "hello\n"
        assert obs.stderr == "warning: deprecated\n"
        assert obs.exit_code == 0
        assert obs.truncated is False

    def test_nonzero_exit_code(self):
        """A non-default exit code is stored rather than replaced by the default."""
        # exit_code=0 equals the default, so it cannot tell "stored" apart from
        # "ignored"; a non-zero value makes that distinction observable.
        obs = SREObservation(stderr="command not found\n", exit_code=127)
        assert obs.exit_code == 127
        assert obs.stderr == "command not found\n"

    def test_truncated_flag(self):
        """Setting truncated=True is preserved alongside the supplied stdout."""
        obs = SREObservation(stdout="...", truncated=True)
        assert obs.truncated is True
        assert obs.stdout == "..."


class TestSREState:
    """Tests for SREState dataclass."""

    def test_default_values(self):
        """A bare state is empty, not done, with zero reward and 30 max steps."""
        state = SREState()
        assert state.task_id == ""
        assert state.difficulty == ""
        assert state.max_steps == 30
        assert state.is_done is False
        assert state.current_reward == 0.0
        assert state.action_history == []

    def test_custom_state(self):
        """Every field passed to the constructor is stored as given."""
        state = SREState(
            task_id="easy_restart",
            task_name="Service Restart",
            difficulty="easy",
            max_steps=15,
        )
        assert state.task_id == "easy_restart"
        assert state.task_name == "Service Restart"
        assert state.difficulty == "easy"
        assert state.max_steps == 15

    def test_action_history_isolation(self):
        """Ensure action_history is not shared between instances."""
        s1 = SREState()
        s2 = SREState()
        s1.action_history.append("test")
        # Confirm the append landed on s1; otherwise the emptiness check on
        # s2 below could pass without proving anything about isolation.
        assert len(s1.action_history) == 1
        assert len(s2.action_history) == 0