Spaces:
Running
Running
# tests/test_integration.py
# Phase 7 — Integration tests for OpenEnv wiring.
# Validates the acceptance criteria from PRD §12.6.
| from __future__ import annotations | |
| import sys | |
| import os | |
| # Ensure the firewatch_env package root is on the path | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models import FirewatchAction, SystemObservation | |
| from simulation import generate_episode | |
| from actions import ActionHandler | |
| from rewards import RewardEngine, EpisodeResult, grade | |
| from server.firewatch_env_environment import FirewatchEnvironment | |
# --------------------------------------------------------------------------
# Test 1: Deterministic reset
# Two calls to reset(easy, 42) return identical initial observations
# --------------------------------------------------------------------------
def test_reset_deterministic():
    """PRD §12.6: Two calls to reset(easy, 42) return byte-identical initial observations."""
    obs1 = FirewatchEnvironment().reset(difficulty="easy", seed=42)
    obs2 = FirewatchEnvironment().reset(difficulty="easy", seed=42)

    # Both environments must expose exactly the same set of services.
    assert set(obs1.services.keys()) == set(obs2.services.keys()), \
        f"Service sets differ: {obs1.services.keys()} vs {obs2.services.keys()}"

    # Every per-service metric must match exactly across the two resets.
    for name in obs1.services:
        m1, m2 = obs1.services[name], obs2.services[name]
        assert m1.http_server_error_rate == m2.http_server_error_rate, \
            f"Error rate mismatch on {name}: {m1.http_server_error_rate} vs {m2.http_server_error_rate}"
        assert m1.process_memory_utilization == m2.process_memory_utilization, \
            f"Memory util mismatch on {name}: {m1.process_memory_utilization} vs {m2.process_memory_utilization}"
        assert m1.http_server_request_duration_p99 == m2.http_server_request_duration_p99, \
            f"Latency mismatch on {name}"

    # Topology and SLO budget must also be identical.
    assert obs1.dependency_graph == obs2.dependency_graph
    assert obs1.slo_budget_remaining_pct == obs2.slo_budget_remaining_pct
    print("β test_reset_deterministic PASSED")
# --------------------------------------------------------------------------
# Test 2: Full episode flow
# reset -> step(fetch_logs) -> step(restart_service) -> step(declare_resolved)
# --------------------------------------------------------------------------
def test_full_episode_flow():
    """PRD §12.6: Sequential calls complete without error."""
    env = FirewatchEnvironment()

    obs = env.reset(difficulty="easy", seed=42)
    assert obs.sim_tick == 1  # reset() ticks once to propagate initial fault
    assert obs.slo_budget_remaining_pct <= 30.0  # may be slightly reduced by initial tick
    assert len(obs.services) > 0
    assert obs.done is False

    # Investigate the first listed service.
    target = next(iter(obs.services))

    # Step 1: fetch_logs
    obs1 = env.step(FirewatchAction(action_type="fetch_logs", target_service=target))
    assert obs1.sim_tick == 2  # reset tick + 1 step tick
    assert obs1.done is False
    assert obs1.reward is not None

    # Step 2: restart_service
    obs2 = env.step(FirewatchAction(action_type="restart_service", target_service=target))
    assert obs2.sim_tick == 3  # reset tick + 2 step ticks
    assert obs2.done is False

    # Step 3: declare_resolved ends the episode.
    obs3 = env.step(FirewatchAction(action_type="declare_resolved"))
    assert obs3.done is True
    assert obs3.reward is not None

    # The grader's episode score must be surfaced in the final metadata.
    assert "episode_score" in obs3.metadata, \
        f"episode_score not in metadata: {list(obs3.metadata.keys())}"
    print("β test_full_episode_flow PASSED")
# --------------------------------------------------------------------------
# Test 3: Invalid action handling
# step() with invalid input returns valid response, not crash
# --------------------------------------------------------------------------
def test_invalid_action_graceful():
    """PRD §12.6: step() with invalid target returns HTTP 200 with error info."""
    env = FirewatchEnvironment()
    env.reset(difficulty="easy", seed=42)

    # Target a service name that does not exist in the mesh.
    obs = env.step(
        FirewatchAction(action_type="fetch_logs", target_service="nonexistent-service")
    )

    # The environment must answer gracefully rather than raise.
    assert obs is not None
    assert obs.done is False

    # The latest history entry should carry the validation feedback.
    assert len(obs.action_history) > 0
    feedback = obs.action_history[-1].get("feedback_string", "")
    assert "Invalid target" in feedback or "not an active service" in feedback
    print("β test_invalid_action_graceful PASSED")
# --------------------------------------------------------------------------
# Test 4: Wrong action produces negative reward
# --------------------------------------------------------------------------
def test_wrong_action_negative_reward():
    """Remediating a healthy service should produce a wrong-action penalty."""
    env = FirewatchEnvironment()
    obs = env.reset(difficulty="easy", seed=42)

    # Let the fault develop for two ticks via harmless investigation actions.
    noop_action = FirewatchAction(
        action_type="fetch_logs",
        target_service=next(iter(obs.services)),
    )
    env.step(noop_action)
    env.step(noop_action)

    # Collect services that still look healthy (low error rate).
    healthy_services = [
        name for name, m in env._mesh.services.items()
        if m.http_server_error_rate < 0.05  # HEALTHY_ERROR_RATE_THRESHOLD
    ]
    if not healthy_services:
        # Nothing healthy to mis-remediate at this seed; nothing to assert.
        print("β test_wrong_action_negative_reward SKIPPED (no healthy services found at this seed)")
        return

    # Restarting a healthy (non-root-cause) service should be penalized.
    obs = env.step(
        FirewatchAction(action_type="restart_service", target_service=healthy_services[0])
    )
    breakdown = obs.metadata.get("reward_breakdown", {})
    assert breakdown.get("wrong_action_penalty", 0.0) < 0.0, \
        f"Expected negative wrong_action_penalty, got {breakdown}"
    print("β test_wrong_action_negative_reward PASSED")
# --------------------------------------------------------------------------
# Test 5: Grader appears in done info
# --------------------------------------------------------------------------
def test_grader_in_done_info():
    """PRD §12.6: episode_score appears in done=True step's info dict."""
    env = FirewatchEnvironment()
    env.reset(difficulty="easy", seed=42)

    # Worst-case agent: declare resolved immediately with no investigation.
    obs = env.step(FirewatchAction(action_type="declare_resolved"))
    assert obs.done is True
    assert "episode_score" in obs.metadata

    score = obs.metadata["episode_score"]
    assert 0.0 <= score <= 1.0, f"Score out of range: {score}"
    # Giving up on tick one should be graded near the bottom of the scale.
    assert score < 0.30, f"Zero-effort score too high: {score}"
    print("β test_grader_in_done_info PASSED")
# --------------------------------------------------------------------------
# Test 6: SLO breach terminates episode
# --------------------------------------------------------------------------
def test_slo_breach_terminates():
    """Running enough ticks to deplete SLO causes done=True."""
    env = FirewatchEnvironment()
    env.reset(difficulty="hard", seed=100)

    # Burn ticks with harmless investigation actions until the episode ends.
    target = next(iter(env._mesh.services))
    done = False
    tick = 0
    while not done and tick < 50:
        obs = env.step(FirewatchAction(action_type="fetch_logs", target_service=target))
        done = obs.done
        tick += 1

    assert done is True, f"Episode did not terminate after {tick} ticks"
    # Hard difficulty with 40 max ticks should terminate
    assert tick <= 41, f"Episode took too many ticks: {tick}"
    print("β test_slo_breach_terminates PASSED")
# --------------------------------------------------------------------------
# Test 7: Score variance (different agent behaviors yield different scores)
# --------------------------------------------------------------------------
def test_score_variance():
    """Grader must produce meaningfully different scores for different behaviors."""
    # Zero-effort agent: immediately gives up.
    env1 = FirewatchEnvironment()
    env1.reset(difficulty="easy", seed=42)
    score_zero = env1.step(
        FirewatchAction(action_type="declare_resolved")
    ).metadata["episode_score"]

    # Active agent: investigates, lets fault develop, remediates, then resolves.
    env2 = FirewatchEnvironment()
    obs2 = env2.reset(difficulty="easy", seed=42)
    root_cause = env2._fault_config.root_cause_service
    fault_type = env2._fault_config.fault_type

    # Investigate every service once while the fault develops.
    for svc in list(obs2.services.keys()):
        env2.step(FirewatchAction(action_type="fetch_logs", target_service=svc))

    # Correct remediation action for each known fault type.
    remediation = {
        "oom": "scale_replicas",
        "bad_deploy": "rollback_deploy",
        "config_drift": "revert_config",
        "memory_leak": "restart_service",
        "network_partition": "restart_service",
    }
    if fault_type in remediation:
        env2.step(FirewatchAction(
            action_type=remediation[fault_type],
            target_service=root_cause,
        ))

    # Give the system a few ticks to recover before declaring victory.
    for _ in range(3):
        env2.step(FirewatchAction(action_type="fetch_logs", target_service=root_cause))
    score_active = env2.step(
        FirewatchAction(action_type="declare_resolved")
    ).metadata["episode_score"]

    # Active agent should score higher than zero-effort.
    assert score_active > score_zero, \
        f"Active agent ({score_active:.4f}) should score higher than zero-effort ({score_zero:.4f})"
    print(f"β test_score_variance PASSED (zero={score_zero:.4f}, active={score_active:.4f})")
# --------------------------------------------------------------------------
# Test 8: No episode active -> graceful response
# --------------------------------------------------------------------------
def test_no_episode_step():
    """step() without prior reset() should return graceful error."""
    env = FirewatchEnvironment()
    obs = env.step(FirewatchAction(action_type="fetch_logs", target_service="test"))

    # Must not crash; either the history records the attempt or metadata
    # carries an explicit error.
    assert obs is not None
    assert len(obs.action_history) > 0 or obs.metadata.get("error")
    print("β test_no_episode_step PASSED")
# --------------------------------------------------------------------------
# Run all tests
# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Hoisted out of the except body so it is not re-imported per failure.
    import traceback

    # Acceptance tests, in PRD order.
    tests = [
        test_reset_deterministic,
        test_full_episode_flow,
        test_invalid_action_graceful,
        test_wrong_action_negative_reward,
        test_grader_in_done_info,
        test_slo_breach_terminates,
        test_score_variance,
        test_no_episode_step,
    ]
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except Exception as e:
            # Keep running the remaining tests; report failure with traceback.
            print(f"β {test.__name__} FAILED: {e}")
            traceback.print_exc()
            failed += 1
    print(f"\n{'='*60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests")
    if failed == 0:
        print("All Phase 7 acceptance criteria PASSED β")
    else:
        print(f"FAILED β {failed} test(s) need fixing")
    print(f"{'='*60}")
    # Bug fix: previously the script always exited 0, so CI could not
    # detect a failing run. Propagate the result as the process exit code.
    sys.exit(1 if failed else 0)