Spaces:
Sleeping
Sleeping
| """ | |
| test_env.py — Validates OnCallEnv works correctly. | |
| Run: python test_env.py | |
| Requires: pip install -r requirements.txt | |
| """ | |
| import sys | |
| import json | |
| from environment import OnCallEnvironment | |
| from models import Action | |
| from graders import grade_task | |
def test_easy_optimal():
    """Drive the easy memory-leak task: inspect, restart, then mark resolved.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    import re

    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    initial = env.reset("easy_memory_leak")
    assert initial.task_id == "easy_memory_leak"
    assert initial.step == 0
    assert len(initial.alerts) == 3
    print(" [PASS] Reset returns valid observation")

    # 1) Logs of the payment service must mention the OOM error.
    outcome = run("check_logs payment-service")
    assert not outcome.done
    assert "OutOfMemoryError" in outcome.observation.last_action_result
    print(" [PASS] check_logs shows OOM errors")

    # 2) Metrics must report a memory figure above 90%.
    outcome = run("check_metrics payment-service")
    metrics_text = outcome.observation.last_action_result
    assert "Memory usage:" in metrics_text
    found = re.search(r"Memory usage:\s+([\d.]+)%", metrics_text)
    assert found
    assert float(found.group(1)) > 90
    print(" [PASS] check_metrics shows high memory")

    # 3) Restarting must not finish the episode; mark_resolved is still due.
    outcome = run("restart_service payment-service")
    assert not outcome.done
    assert "healthy" in outcome.observation.last_action_result.lower()
    print(" [PASS] Restart fixes service, episode continues for mark_resolved")

    # 4) Declaring the incident resolved closes the episode.
    outcome = run("mark_resolved payment-service memory leak due to OOM kills")
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    # The grader's score must be a valid fraction and high for this run.
    final_state = env.state()
    graded = grade_task("easy_memory_leak", final_state)
    assert 0.0 <= graded <= 1.0
    assert graded >= 0.9
    print(f" [PASS] Grader returns valid score: {graded}")
    return graded
def test_medium_optimal():
    """Drive the medium cascading-failure task along its intended path.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    env.reset("medium_cascading_failure")

    # Follow the dependency chain before looking at the configuration.
    for probe in (
        "check_metrics api-gateway",
        "check_logs api-gateway",
        "check_metrics order-service",
        "check_logs order-service",
    ):
        run(probe)

    outcome = run("check_config order-service")
    config_text = outcome.observation.last_action_result
    assert "db_pool_size" in config_text
    assert "5" in config_text
    print(" [PASS] Config shows db_pool_size = 5")

    # Apply the fix; the episode stays open until mark_resolved.
    outcome = run("update_config order-service db_pool_size 50")
    assert not outcome.done
    assert "resolved" in outcome.observation.last_action_result.lower()
    print(" [PASS] Config update fixes the issue")

    # Close the incident with a root-cause statement.
    outcome = run(
        "mark_resolved order-service db_pool_size connection pool exhausted config changed to 5"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    final_state = env.state()
    graded = grade_task("medium_cascading_failure", final_state)
    assert graded >= 0.8
    print(f" [PASS] Grader score: {graded}")
    return graded
def test_hard_optimal():
    """Drive the hard cache-degradation task along its intended path.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    env.reset("hard_cache_degradation")

    # Survey several services before drilling into the cache.
    for probe in (
        "check_metrics api-gateway",
        "check_metrics order-service",
        "check_metrics product-service",
        "check_metrics cache-service",
        "check_logs cache-service",
    ):
        run(probe)

    outcome = run("check_deploy_history cache-service")
    history_text = outcome.observation.last_action_result
    assert "MurmurHash3" in history_text or "hashing" in history_text.lower()
    print(" [PASS] Deploy history reveals hashing change")

    run("check_metrics postgres-primary")

    # Roll the cache deployment back; the episode must stay open.
    outcome = run("rollback_deploy cache-service")
    assert not outcome.done
    print(" [PASS] Rollback fixes cache, episode continues")

    # Close the incident with a root-cause statement.
    outcome = run(
        "mark_resolved cache-service deployment changed key hashing algorithm causing cache miss"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    final_state = env.state()
    graded = grade_task("hard_cache_degradation", final_state)
    assert graded >= 0.8
    print(f" [PASS] Grader score: {graded}")
    return graded
def test_dns_optimal():
    """Drive the DNS-misconfiguration task along its intended path.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    initial = env.reset("medium_dns_misconfiguration")
    assert initial.task_id == "medium_dns_misconfiguration"
    print(" [PASS] Reset works")

    run("check_metrics order-service")
    run("check_logs order-service")

    # The order-service config must expose the suspicious hostname.
    outcome = run("check_config order-service")
    assert "inventory-service-v2.internal" in outcome.observation.last_action_result
    print(" [PASS] Config shows wrong hostname")

    run("check_metrics inventory-service")

    # Point the config at the other hostname; the episode must stay open.
    outcome = run("update_config order-service inventory_host inventory-service.internal")
    assert not outcome.done
    print(" [PASS] Config fix applied")

    outcome = run(
        "mark_resolved order-service dns hostname misconfiguration inventory_host pointed to wrong host"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] DNS scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    graded = grade_task("medium_dns_misconfiguration", final_state)
    assert graded >= 0.8
    print(f" [PASS] Grader score: {graded}")
    return graded
def test_replication_lag_optimal():
    """Drive the database replication-lag task along its intended path.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    initial = env.reset("hard_replication_lag")
    assert initial.task_id == "hard_replication_lag"
    print(" [PASS] Reset works")

    # Walk from the user-facing services down to both database nodes.
    for probe in (
        "check_metrics user-service",
        "check_logs user-service",
        "check_metrics order-service",
        "check_metrics postgres-primary",
        "check_logs postgres-primary",
        "check_config postgres-primary",
        "check_metrics postgres-replica",
    ):
        run(probe)
    print(" [PASS] Investigation chain complete")

    # Switch the batch job off; the episode must stay open.
    outcome = run("update_config postgres-primary batch_job_enabled false")
    assert not outcome.done
    print(" [PASS] Batch job disabled")

    outcome = run(
        "mark_resolved postgres-primary batch job nightly_aggregation causing replication lag"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.8
    print(f" [PASS] Replication lag scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    graded = grade_task("hard_replication_lag", final_state)
    assert graded >= 0.7
    print(f" [PASS] Grader score: {graded}")
    return graded
def test_wrong_actions():
    """Check that restarting an unrelated service is recorded as a penalty."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # user-service is not the service the optimal run restarts for this task.
    misfire = Action(command="restart_service user-service")
    outcome = env.step(misfire)
    assert not outcome.done
    print(" [PASS] Restarting wrong service doesn't resolve")

    # The reward breakdown must carry a negative "penalty" entry.
    penalty = env.state().reward_breakdown.get("penalty", 0)
    assert penalty < 0
    print(" [PASS] Penalty applied for wrong action")
def test_max_steps():
    """Check that the episode is reported done after ten filler steps."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # Spend ten steps on the same harmless metrics query.
    filler = "check_metrics api-gateway"
    for _ in range(10):
        outcome = env.step(Action(command=filler))

    # Only the result of the final step is inspected.
    assert outcome.done
    print(f" [PASS] Episode ends at max steps (score: {outcome.reward.total})")
def test_invalid_commands():
    """Check that an unknown verb and an unknown service both report an error."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # A verb the environment is expected to reject.
    bad_verb = env.step(Action(command="delete_everything"))
    assert bad_verb.observation.last_action_error
    print(" [PASS] Invalid command returns error")

    # A valid verb aimed at a service name the test expects to be unknown.
    bad_target = env.step(Action(command="check_metrics nonexistent-service"))
    assert bad_target.observation.last_action_error
    print(" [PASS] Unknown service returns error")
def test_list_tasks():
    """Check that six tasks are listed and all four difficulty levels appear."""
    catalog = OnCallEnvironment().list_tasks()
    assert len(catalog) == 6

    # Collect the distinct difficulty labels across the listing.
    seen_levels = set()
    for entry in catalog:
        seen_levels.add(entry["difficulty"])

    for level in ("easy", "medium", "hard", "expert"):
        assert level in seen_levels
    print(f" [PASS] {len(catalog)} tasks with difficulty range")
def test_state_endpoint():
    """Check the state snapshot after a single investigative step."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="check_logs payment-service"))

    snapshot = env.state()
    # One step taken: task id, step counter, action list and log must agree.
    assert snapshot.task_id == "easy_memory_leak"
    assert snapshot.step == 1
    assert len(snapshot.actions_taken) == 1
    assert "payment-service" in snapshot.investigation_log
    print(" [PASS] State endpoint returns correct data")
def test_score_range():
    """Check that the state score stays within [0.0, 1.0] for every task."""
    env = OnCallEnvironment()
    task_ids = (
        "easy_memory_leak",
        "medium_cascading_failure",
        "hard_cache_degradation",
        "medium_dns_misconfiguration",
        "hard_replication_lag",
        "expert_multi_root_cause",
    )
    for task_id in task_ids:
        env.reset(task_id)
        # Take five identical steps, then inspect the resulting score.
        steps_left = 5
        while steps_left:
            env.step(Action(command="check_metrics api-gateway"))
            steps_left -= 1
        state = env.state()
        assert 0.0 <= state.score <= 1.0, f"{task_id}: score {state.score} out of range"
    print(" [PASS] All scores in [0.0, 1.0]")
def test_mark_resolved_positive():
    """Check that mark_resolved with matching keywords identifies the root cause."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # Investigate and remediate before declaring the incident resolved.
    for command in (
        "check_logs payment-service",
        "restart_service payment-service",
    ):
        env.step(Action(command=command))

    closing = env.step(Action(command="mark_resolved payment-service memory leak OOM heap"))
    assert closing.done

    snapshot = env.state()
    assert snapshot.root_cause_identified
    print(f" [PASS] Correct mark_resolved (score: {snapshot.score})")
def test_mark_resolved_negative():
    """Check that a mark_resolved with unrelated wording is not accepted."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # No investigation, no fix, and a vague explanation.
    vague_claim = Action(command="mark_resolved everything is broken somewhere")
    outcome = env.step(vague_claim)
    assert not outcome.done

    snapshot = env.state()
    assert not snapshot.root_cause_identified
    print(" [PASS] Wrong mark_resolved rejected")
def test_mark_resolved_partial():
    """Check that a mark_resolved naming only 'memory' still flags the root cause."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # The step result itself is not inspected; only the state afterwards.
    env.step(Action(command="mark_resolved memory issue detected"))

    snapshot = env.state()
    assert snapshot.root_cause_identified  # partial: has 1 keyword
    print(" [PASS] Partial mark_resolved gives partial credit")
def test_remediation_without_mark_resolved():
    """Check that the episode closes by itself two steps after the fix."""
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    env.reset("easy_memory_leak")
    run("restart_service payment-service")

    # First step after the fix: the episode is still running.
    first_after = run("check_metrics api-gateway")
    assert not first_after.done

    # Second step after the fix: the episode is reported done.
    second_after = run("check_metrics api-gateway")
    assert second_after.done

    snapshot = env.state()
    # The test expects at least 0.3 even though mark_resolved was never sent.
    assert snapshot.score >= 0.3
    print(f" [PASS] Episode ends 2 steps after remediation (score: {snapshot.score})")
def test_expert_optimal():
    """Drive the expert task, which needs two separate fixes before resolving.

    Returns:
        The score that ``grade_task`` reports for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap a raw command string in an Action and advance one step.
        return env.step(Action(command=command))

    initial = env.reset("expert_multi_root_cause")
    assert initial.task_id == "expert_multi_root_cause"
    assert len(initial.alerts) >= 3
    print(" [PASS] Reset works")

    # Chain one: gateway -> search-service deployment history.
    for probe in (
        "check_metrics api-gateway",
        "check_logs api-gateway",
        "check_metrics search-service",
        "check_logs search-service",
    ):
        run(probe)
    outcome = run("check_deploy_history search-service")
    assert "v3.1.0" in outcome.observation.last_action_result
    print(" [PASS] Search deploy history shows broken deployment")

    # Chain two: order-service configuration.
    run("check_metrics order-service")
    run("check_logs order-service")
    outcome = run("check_config order-service")
    assert "db_pool_size" in outcome.observation.last_action_result
    print(" [PASS] Order config shows low pool size")

    run("check_metrics elasticsearch")

    # First fix: roll back search-service; progress must read 1/2.
    outcome = run("rollback_deploy search-service")
    assert not outcome.done
    assert "1/2" in outcome.observation.last_action_result
    print(" [PASS] First fix applied (1/2)")

    # Second fix: raise the order-service pool size.
    outcome = run("update_config order-service db_pool_size 50")
    assert not outcome.done
    fix_text = outcome.observation.last_action_result
    assert "resolved" in fix_text.lower() or "2/2" in fix_text
    print(" [PASS] Second fix applied (2/2)")

    # Close the incident, naming both causes in one statement.
    outcome = run(
        "mark_resolved search-service bad deployment v3.1.0 elasticsearch query AND order-service db_pool_size config drift both issues"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.8
    print(f" [PASS] Expert scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    graded = grade_task("expert_multi_root_cause", final_state)
    assert graded >= 0.7
    print(f" [PASS] Grader score: {graded}")
    return graded
def test_grader_independence():
    """Check that the environment score and the grader score are both high."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # Full easy-task run: two checks, the restart, then the resolution.
    for command in (
        "check_logs payment-service",
        "check_metrics payment-service",
        "restart_service payment-service",
        "mark_resolved payment-service memory leak OOM",
    ):
        env.step(Action(command=command))

    snapshot = env.state()
    # Read the environment's own score before handing the state to the grader.
    env_score = snapshot.score
    grader_score = grade_task("easy_memory_leak", snapshot)

    # The two numbers need not be equal; each only has to clear 0.8.
    assert grader_score >= 0.8
    assert env_score >= 0.8
    print(f" [PASS] Grader ({grader_score}) and env ({env_score}) both score high")
if __name__ == "__main__":
    import traceback

    # (label, callable) pairs, executed in the order listed.
    suite = (
        ("Easy optimal run", test_easy_optimal),
        ("Medium optimal run", test_medium_optimal),
        ("Hard optimal run", test_hard_optimal),
        ("DNS misconfiguration optimal", test_dns_optimal),
        ("DB replication lag optimal", test_replication_lag_optimal),
        ("Expert multi-root-cause optimal", test_expert_optimal),
        ("Wrong actions penalty", test_wrong_actions),
        ("Max steps termination", test_max_steps),
        ("Invalid commands", test_invalid_commands),
        ("Task listing", test_list_tasks),
        ("State endpoint", test_state_endpoint),
        ("Score range validation", test_score_range),
        ("mark_resolved positive", test_mark_resolved_positive),
        ("mark_resolved negative", test_mark_resolved_negative),
        ("mark_resolved partial", test_mark_resolved_partial),
        ("Remediation without mark_resolved", test_remediation_without_mark_resolved),
        ("Grader independence", test_grader_independence),
    )

    ok_count = 0
    bad_count = 0
    for label, test_fn in suite:
        print(f"\n{'─'*50}")
        print(f"TEST: {label}")
        try:
            test_fn()
        except Exception as exc:
            # Report the failure and keep going so every test gets a turn.
            print(f" [FAIL] {exc}")
            traceback.print_exc()
            bad_count += 1
        else:
            ok_count += 1

    print(f"\n{'═'*50}")
    print(f"Results: {ok_count} passed, {bad_count} failed")
    if bad_count:
        # Non-zero exit status signals failure to the caller.
        sys.exit(1)
    print("All tests passed!")