# sre-openenv / tests/test_models.py
# Dragonfire146 — "Initial commit" (9eb0831)
"""
Unit tests for SRE OpenEnv data models.
"""
import pytest
from models import SREAction, SREObservation, SREState
class TestSREAction:
    """Construction and validation checks for the ``SREAction`` model."""

    def test_run_shell_action(self):
        # A shell action keeps both its type and the command string as given.
        shell = SREAction(action_type="run_shell", command="ls -la")
        assert (shell.action_type, shell.command) == ("run_shell", "ls -la")

    def test_patch_file_action(self):
        # Every keyword supplied for a file patch must be readable back
        # from the instance unchanged.
        fields = {
            "action_type": "patch_file",
            "file_path": "/tmp/test.txt",
            "content": "hello world",
        }
        patch = SREAction(**fields)
        for attr, expected in fields.items():
            assert getattr(patch, attr) == expected

    def test_run_shell_requires_command(self):
        # An empty command is rejected with a ValueError whose message
        # mentions "command".
        with pytest.raises(ValueError, match="command"):
            SREAction(command="", action_type="run_shell")

    def test_patch_file_requires_path(self):
        # An empty target path is rejected with a ValueError whose message
        # mentions "file_path".
        with pytest.raises(ValueError, match="file_path"):
            SREAction(file_path="", action_type="patch_file")

    def test_default_action_type(self):
        # Omitting action_type falls back to "run_shell".
        assert SREAction(command="echo test").action_type == "run_shell"
class TestSREObservation:
    """Tests for SREObservation dataclass."""

    def test_default_values(self):
        # A bare observation has empty stdout/stderr/message, exit code 0,
        # and is not flagged as truncated.
        obs = SREObservation()
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.exit_code == 0
        assert obs.truncated is False
        assert obs.message == ""

    def test_with_output(self):
        # Every field passed to the constructor is asserted, not just the
        # output streams, so the whole keyword set is covered.
        obs = SREObservation(
            stdout="hello\n",
            stderr="warning: deprecated\n",
            exit_code=0,
            truncated=False,
        )
        assert obs.stdout == "hello\n"
        assert obs.stderr == "warning: deprecated\n"
        assert obs.exit_code == 0
        assert obs.truncated is False

    def test_nonzero_exit_code(self):
        # exit_code=0 in test_with_output coincides with the default, so it
        # cannot detect a constructor that ignores the keyword; use a
        # non-default value here.
        obs = SREObservation(stderr="boom\n", exit_code=1)
        assert obs.exit_code == 1
        assert obs.stderr == "boom\n"

    def test_truncated_flag(self):
        obs = SREObservation(stdout="...", truncated=True)
        assert obs.truncated is True
        assert obs.stdout == "..."
class TestSREState:
    """Tests for SREState dataclass."""

    def test_default_values(self):
        state = SREState()
        assert state.task_id == ""
        assert state.difficulty == ""
        assert state.max_steps == 30
        assert state.is_done is False
        assert state.current_reward == 0.0
        assert state.action_history == []

    def test_custom_state(self):
        # Every keyword passed to the constructor is asserted, including
        # task_name, which was previously set but never checked.
        state = SREState(
            task_id="easy_restart",
            task_name="Service Restart",
            difficulty="easy",
            max_steps=15,
        )
        assert state.task_id == "easy_restart"
        assert state.task_name == "Service Restart"
        assert state.difficulty == "easy"
        assert state.max_steps == 15

    def test_action_history_isolation(self):
        """Ensure action_history is not shared between instances."""
        s1 = SREState()
        s2 = SREState()
        # Each instance must own a distinct list object; a shared mutable
        # default would make these the same object.
        assert s1.action_history is not s2.action_history
        s1.action_history.append("test")
        # The append must land on s1 and must not leak into s2.
        assert s1.action_history == ["test"]
        assert len(s2.action_history) == 0