"""
Unit tests for SRE OpenEnv data models.
"""
import pytest
from models import SREAction, SREObservation, SREState
class TestSREAction:
    """Construction and validation checks for ``SREAction``."""

    def test_run_shell_action(self):
        """A ``run_shell`` action exposes the type and command it was given."""
        shell = SREAction(action_type="run_shell", command="ls -la")
        expected = (("action_type", "run_shell"), ("command", "ls -la"))
        for attr, value in expected:
            assert getattr(shell, attr) == value

    def test_patch_file_action(self):
        """A ``patch_file`` action exposes its type, target path and content."""
        patch = SREAction(
            action_type="patch_file",
            file_path="/tmp/test.txt",
            content="hello world",
        )
        expected = (
            ("action_type", "patch_file"),
            ("file_path", "/tmp/test.txt"),
            ("content", "hello world"),
        )
        for attr, value in expected:
            assert getattr(patch, attr) == value

    def test_run_shell_requires_command(self):
        """An empty command on ``run_shell`` must raise ``ValueError``."""
        # The error text is expected to mention the offending field name.
        expect_error = pytest.raises(ValueError, match="command")
        with expect_error:
            SREAction(action_type="run_shell", command="")

    def test_patch_file_requires_path(self):
        """An empty file path on ``patch_file`` must raise ``ValueError``."""
        # The error text is expected to mention the offending field name.
        expect_error = pytest.raises(ValueError, match="file_path")
        with expect_error:
            SREAction(action_type="patch_file", file_path="")

    def test_default_action_type(self):
        """Omitting ``action_type`` yields a ``run_shell`` action."""
        implicit = SREAction(command="echo test")
        assert implicit.action_type == "run_shell"
class TestSREObservation:
    """Tests for SREObservation dataclass."""

    def test_default_values(self):
        """An observation built with no arguments has empty/zero fields."""
        obs = SREObservation()
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.exit_code == 0
        assert obs.truncated is False
        assert obs.message == ""

    def test_with_output(self):
        """Every value passed to the constructor is readable afterwards."""
        obs = SREObservation(
            stdout="hello\n",
            stderr="warning: deprecated\n",
            exit_code=0,
            truncated=False,
        )
        assert obs.stdout == "hello\n"
        assert obs.stderr == "warning: deprecated\n"
        # These two were supplied above but previously never verified.
        assert obs.exit_code == 0
        assert obs.truncated is False

    def test_truncated_flag(self):
        """Passing ``truncated=True`` is reflected on the instance."""
        obs = SREObservation(stdout="...", truncated=True)
        assert obs.truncated is True
class TestSREState:
    """Tests for SREState dataclass."""

    def test_default_values(self):
        """A state built with no arguments carries the expected defaults."""
        state = SREState()
        assert state.task_id == ""
        assert state.difficulty == ""
        assert state.max_steps == 30
        assert state.is_done is False
        assert state.current_reward == 0.0
        assert state.action_history == []

    def test_custom_state(self):
        """Every value passed to the constructor is readable afterwards."""
        state = SREState(
            task_id="easy_restart",
            task_name="Service Restart",
            difficulty="easy",
            max_steps=15,
        )
        assert state.task_id == "easy_restart"
        # task_name was supplied above but previously never verified.
        assert state.task_name == "Service Restart"
        assert state.difficulty == "easy"
        assert state.max_steps == 15

    def test_action_history_isolation(self):
        """Ensure action_history is not shared between instances."""
        s1 = SREState()
        s2 = SREState()
        s1.action_history.append("test")
        # Confirm the mutation actually reached s1, so the check on s2
        # cannot pass vacuously.
        assert s1.action_history == ["test"]
        assert s2.action_history == []
|