"""Unit tests for the SRE OpenEnv data models."""

import pytest

from models import SREAction, SREObservation, SREState


class TestSREAction:
    """Tests for the SREAction dataclass."""

    def test_run_shell_action(self):
        action = SREAction(action_type="run_shell", command="ls -la")
        assert action.action_type == "run_shell"
        assert action.command == "ls -la"

    def test_patch_file_action(self):
        action = SREAction(
            action_type="patch_file",
            file_path="/tmp/test.txt",
            content="hello world",
        )
        assert action.action_type == "patch_file"
        assert action.file_path == "/tmp/test.txt"
        assert action.content == "hello world"

    def test_run_shell_requires_command(self):
        # A run_shell action with an empty command must be rejected at
        # construction time; `match` is a regex searched in the message.
        with pytest.raises(ValueError, match="command"):
            SREAction(action_type="run_shell", command="")

    def test_patch_file_requires_path(self):
        # A patch_file action with an empty path must be rejected at
        # construction time.
        with pytest.raises(ValueError, match="file_path"):
            SREAction(action_type="patch_file", file_path="")

    def test_default_action_type(self):
        # Omitting action_type is expected to fall back to "run_shell".
        action = SREAction(command="echo test")
        assert action.action_type == "run_shell"


class TestSREObservation:
    """Tests for the SREObservation dataclass."""

    def test_default_values(self):
        obs = SREObservation()
        assert obs.stdout == ""
        assert obs.stderr == ""
        assert obs.exit_code == 0
        assert obs.truncated is False
        assert obs.message == ""

    def test_with_output(self):
        obs = SREObservation(
            stdout="hello\n",
            stderr="warning: deprecated\n",
            exit_code=0,
            truncated=False,
        )
        assert obs.stdout == "hello\n"
        assert obs.stderr == "warning: deprecated\n"
        # Every field passed explicitly should round-trip unchanged.
        assert obs.exit_code == 0
        assert obs.truncated is False

    def test_truncated_flag(self):
        obs = SREObservation(stdout="...", truncated=True)
        assert obs.truncated is True


class TestSREState:
    """Tests for the SREState dataclass."""

    def test_default_values(self):
        state = SREState()
        assert state.task_id == ""
        assert state.difficulty == ""
        assert state.max_steps == 30
        assert state.is_done is False
        assert state.current_reward == 0.0
        assert state.action_history == []

    def test_custom_state(self):
        state = SREState(
            task_id="easy_restart",
            task_name="Service Restart",
            difficulty="easy",
            max_steps=15,
        )
        assert state.task_id == "easy_restart"
        assert state.task_name == "Service Restart"
        assert state.difficulty == "easy"
        assert state.max_steps == 15

    def test_action_history_isolation(self):
        """Ensure action_history is not shared between instances."""
        # Guards against a shared mutable default (e.g. a bare `[]` instead
        # of a per-instance default factory) on the action_history field.
        s1 = SREState()
        s2 = SREState()
        s1.action_history.append("test")
        assert len(s2.action_history) == 0