# NOTE(review): the four lines below were web-page artifacts (file-listing path,
# uploader, and commit hash) accidentally captured into the source file; they are
# not Python and broke parsing, so they are preserved here as comments.
# firewatch-env / tests / test_integration.py
# 10doshi12 — Upload folder using huggingface_hub — bb0313f verified
# tests/test_integration.py
# Phase 7 β€” Integration tests for OpenEnv wiring.
# Validates the acceptance criteria from PRD Β§12.6.
from __future__ import annotations
import sys
import os
# Ensure the firewatch_env package root is on the path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import FirewatchAction, SystemObservation
from simulation import generate_episode
from actions import ActionHandler
from rewards import RewardEngine, EpisodeResult, grade
from server.firewatch_env_environment import FirewatchEnvironment
# --------------------------------------------------------------------------
# Test 1: Deterministic reset
# Two calls to reset(easy, 42) return identical initial observations
# --------------------------------------------------------------------------
def test_reset_deterministic():
    """PRD §12.6: resetting two fresh environments with (easy, 42) must yield
    identical initial observations — services, metrics, topology, and budget."""
    obs1 = FirewatchEnvironment().reset(difficulty="easy", seed=42)
    obs2 = FirewatchEnvironment().reset(difficulty="easy", seed=42)
    # Identical service inventory.
    assert set(obs1.services.keys()) == set(obs2.services.keys()), \
        f"Service sets differ: {obs1.services.keys()} vs {obs2.services.keys()}"
    # Identical per-service metrics.
    for name, m1 in obs1.services.items():
        m2 = obs2.services[name]
        assert m1.http_server_error_rate == m2.http_server_error_rate, \
            f"Error rate mismatch on {name}: {m1.http_server_error_rate} vs {m2.http_server_error_rate}"
        assert m1.process_memory_utilization == m2.process_memory_utilization, \
            f"Memory util mismatch on {name}: {m1.process_memory_utilization} vs {m2.process_memory_utilization}"
        assert m1.http_server_request_duration_p99 == m2.http_server_request_duration_p99, \
            f"Latency mismatch on {name}"
    # Identical dependency topology and identical starting SLO budget.
    assert obs1.dependency_graph == obs2.dependency_graph
    assert obs1.slo_budget_remaining_pct == obs2.slo_budget_remaining_pct
    print("✓ test_reset_deterministic PASSED")
# --------------------------------------------------------------------------
# Test 2: Full episode flow
# reset β†’ step(fetch_logs) β†’ step(restart_service) β†’ step(declare_resolved)
# --------------------------------------------------------------------------
def test_full_episode_flow():
    """PRD §12.6: reset → fetch_logs → restart_service → declare_resolved
    runs end-to-end without error and grades the episode on completion."""
    env = FirewatchEnvironment()
    obs = env.reset(difficulty="easy", seed=42)
    # reset() advances the sim once so the seeded fault starts propagating.
    assert obs.sim_tick == 1
    assert obs.slo_budget_remaining_pct <= 30.0  # the initial tick may already burn budget
    assert len(obs.services) > 0
    assert obs.done is False
    # Investigate an arbitrary service.
    target = next(iter(obs.services))
    step1 = env.step(FirewatchAction(action_type="fetch_logs", target_service=target))
    assert step1.sim_tick == 2  # reset tick + 1 step tick
    assert step1.done is False
    assert step1.reward is not None
    # Attempt a remediation.
    step2 = env.step(FirewatchAction(action_type="restart_service", target_service=target))
    assert step2.sim_tick == 3  # reset tick + 2 step ticks
    assert step2.done is False
    # Declaring resolved ends the episode and attaches the grade.
    obs3 = env.step(FirewatchAction(action_type="declare_resolved"))
    assert obs3.done is True
    assert obs3.reward is not None
    assert "episode_score" in obs3.metadata, \
        f"episode_score not in metadata: {list(obs3.metadata.keys())}"
    print("✓ test_full_episode_flow PASSED")
# --------------------------------------------------------------------------
# Test 3: Invalid action handling
# step() with invalid input returns valid response, not crash
# --------------------------------------------------------------------------
def test_invalid_action_graceful():
    """PRD §12.6: stepping with an unknown target must not crash; the env
    answers normally and records the problem in the action history."""
    env = FirewatchEnvironment()
    env.reset(difficulty="easy", seed=42)
    bad = FirewatchAction(
        action_type="fetch_logs",
        target_service="nonexistent-service",
    )
    obs = env.step(bad)
    # The call must complete and leave the episode running.
    assert obs is not None
    assert obs.done is False
    # The rejection shows up as feedback on the most recent history entry.
    assert len(obs.action_history) > 0
    feedback = obs.action_history[-1].get("feedback_string", "")
    assert "Invalid target" in feedback or "not an active service" in feedback
    print("✓ test_invalid_action_graceful PASSED")
# --------------------------------------------------------------------------
# Test 4: Wrong action produces negative reward
# --------------------------------------------------------------------------
def test_wrong_action_negative_reward():
    """Restarting a service that is still healthy should incur a wrong-action penalty."""
    env = FirewatchEnvironment()
    obs = env.reset(difficulty="easy", seed=42)
    # Burn a couple of ticks with investigation so the fault has time to spread.
    probe = FirewatchAction(action_type="fetch_logs", target_service=next(iter(obs.services)))
    env.step(probe)
    env.step(probe)
    # Candidates: services whose error rate is still below the healthy threshold.
    healthy = [
        name for name, m in env._mesh.services.items()
        if m.http_server_error_rate < 0.05  # HEALTHY_ERROR_RATE_THRESHOLD
    ]
    if not healthy:
        print("⚠ test_wrong_action_negative_reward SKIPPED (no healthy services found at this seed)")
        return
    obs = env.step(FirewatchAction(action_type="restart_service", target_service=healthy[0]))
    # The penalty must appear (negative) in the reward breakdown.
    breakdown = obs.metadata.get("reward_breakdown", {})
    assert breakdown.get("wrong_action_penalty", 0.0) < 0.0, \
        f"Expected negative wrong_action_penalty, got {breakdown}"
    print("✓ test_wrong_action_negative_reward PASSED")
# --------------------------------------------------------------------------
# Test 5: Grader appears in done info
# --------------------------------------------------------------------------
def test_grader_in_done_info():
    """PRD §12.6: when done=True, the step's metadata carries episode_score."""
    env = FirewatchEnvironment()
    env.reset(difficulty="easy", seed=42)
    # Worst-case agent: gives up immediately.
    obs = env.step(FirewatchAction(action_type="declare_resolved"))
    assert obs.done is True
    assert "episode_score" in obs.metadata
    score = obs.metadata["episode_score"]
    assert 0.0 <= score <= 1.0, f"Score out of range: {score}"
    # A zero-effort agent must grade poorly.
    assert score < 0.30, f"Zero-effort score too high: {score}"
    print("✓ test_grader_in_done_info PASSED")
# --------------------------------------------------------------------------
# Test 6: SLO breach terminates episode
# --------------------------------------------------------------------------
def test_slo_breach_terminates():
    """Doing nothing useful on hard difficulty must end the episode
    (SLO exhaustion or tick cap) well within 50 steps."""
    env = FirewatchEnvironment()
    env.reset(difficulty="hard", seed=100)
    target = next(iter(env._mesh.services))
    tick = 0
    done = False
    # Investigation-only loop: never remediate, just keep fetching logs.
    while tick < 50:
        obs = env.step(FirewatchAction(action_type="fetch_logs", target_service=target))
        tick += 1
        if obs.done:
            done = True
            break
    assert done is True, f"Episode did not terminate after {tick} ticks"
    # Hard difficulty caps at 40 ticks, so termination must come by then.
    assert tick <= 41, f"Episode took too many ticks: {tick}"
    print("✓ test_slo_breach_terminates PASSED")
# --------------------------------------------------------------------------
# Test 7: Score variance (different agent behaviors yield different scores)
# --------------------------------------------------------------------------
def test_score_variance():
    """Different agent behaviors must receive meaningfully different grades."""
    # Baseline: an agent that declares victory immediately.
    env1 = FirewatchEnvironment()
    env1.reset(difficulty="easy", seed=42)
    obs_zero = env1.step(FirewatchAction(action_type="declare_resolved"))
    score_zero = obs_zero.metadata["episode_score"]
    # Competent agent: investigate everything, fix the root cause, wait, resolve.
    env2 = FirewatchEnvironment()
    obs2 = env2.reset(difficulty="easy", seed=42)
    root_cause = env2._fault_config.root_cause_service
    fault_type = env2._fault_config.fault_type
    # Investigation pass over every service lets the fault develop a few ticks.
    for svc in list(obs2.services.keys()):
        env2.step(FirewatchAction(action_type="fetch_logs", target_service=svc))
    # Each fault type maps to its correct remediation action.
    remediation = {
        "oom": "scale_replicas",
        "bad_deploy": "rollback_deploy",
        "config_drift": "revert_config",
        "memory_leak": "restart_service",
        "network_partition": "restart_service",
    }.get(fault_type)
    if remediation is not None:
        env2.step(FirewatchAction(action_type=remediation, target_service=root_cause))
    # Give the mesh a few ticks to recover before resolving.
    for _ in range(3):
        env2.step(FirewatchAction(action_type="fetch_logs", target_service=root_cause))
    obs_active = env2.step(FirewatchAction(action_type="declare_resolved"))
    score_active = obs_active.metadata["episode_score"]
    assert score_active > score_zero, \
        f"Active agent ({score_active:.4f}) should score higher than zero-effort ({score_zero:.4f})"
    print(f"✓ test_score_variance PASSED (zero={score_zero:.4f}, active={score_active:.4f})")
# --------------------------------------------------------------------------
# Test 8: No episode active -> graceful response
# --------------------------------------------------------------------------
def test_no_episode_step():
    """Calling step() before any reset() must fail gracefully, not crash."""
    env = FirewatchEnvironment()
    obs = env.step(FirewatchAction(action_type="fetch_logs", target_service="test"))
    assert obs is not None
    # Either the history records the attempt or the metadata carries an error.
    assert len(obs.action_history) > 0 or obs.metadata.get("error")
    print("✓ test_no_episode_step PASSED")
# --------------------------------------------------------------------------
# Run all tests
# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Minimal self-contained runner: execute every acceptance test, report a
    # summary, and — fix — exit non-zero on failure so CI can detect it (the
    # original always exited 0 regardless of results).
    import traceback

    tests = [
        test_reset_deterministic,
        test_full_episode_flow,
        test_invalid_action_graceful,
        test_wrong_action_negative_reward,
        test_grader_in_done_info,
        test_slo_breach_terminates,
        test_score_variance,
        test_no_episode_step,
    ]
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except Exception as e:
            # Keep going after a failure so one broken test doesn't hide the rest.
            print(f"✗ {test.__name__} FAILED: {e}")
            traceback.print_exc()
            failed += 1
    print(f"\n{'='*60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)} tests")
    if failed == 0:
        print("All Phase 7 acceptance criteria PASSED ✓")
    else:
        print(f"FAILED — {failed} test(s) need fixing")
    print(f"{'='*60}")
    # Propagate the outcome via the process exit status (sys imported at file top).
    sys.exit(1 if failed else 0)