"""
test_env.py — Validates that OnCallEnv works correctly.
Run: python test_env.py
Requires: pip install -r requirements.txt
"""
import sys
import json
from environment import OnCallEnvironment
from models import Action
from graders import grade_task
def test_easy_optimal():
    """Run the easy memory-leak task through its optimal command sequence.

    Verifies the initial observation, the output of each investigation and
    remediation command, the reward reported once ``mark_resolved`` ends the
    episode, and finally the grader's score.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    import re as _re

    env = OnCallEnvironment()

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    initial = env.reset("easy_memory_leak")
    assert initial.task_id == "easy_memory_leak"
    assert initial.step == 0
    assert len(initial.alerts) == 3
    print(" [PASS] Reset returns valid observation")

    # Step 1: the payment-service logs should surface the OOM errors.
    outcome = run("check_logs payment-service")
    assert not outcome.done
    assert "OutOfMemoryError" in outcome.observation.last_action_result
    print(" [PASS] check_logs shows OOM errors")

    # Step 2: metrics are dynamic (memory may drift from the 94.7% baseline),
    # so parse the reported percentage instead of matching a fixed string.
    outcome = run("check_metrics payment-service")
    metrics_text = outcome.observation.last_action_result
    assert "Memory usage:" in metrics_text
    reading = _re.search(r"Memory usage:\s+([\d.]+)%", metrics_text)
    # A leaking payment service is expected to sit above 90% memory.
    assert reading and float(reading.group(1)) > 90
    print(" [PASS] check_metrics shows high memory")

    # Step 3: restarting fixes the service but must not end the episode,
    # because the agent still gets steps to call mark_resolved.
    outcome = run("restart_service payment-service")
    assert not outcome.done
    assert "healthy" in outcome.observation.last_action_result.lower()
    print(" [PASS] Restart fixes service, episode continues for mark_resolved")

    # Step 4: declaring the root cause closes the incident.
    outcome = run("mark_resolved payment-service memory leak due to OOM kills")
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    # Grader: score the final state.
    final_state = env.state()
    score = grade_task("easy_memory_leak", final_state)
    assert 0.0 <= score <= 1.0
    assert score >= 0.9
    print(f" [PASS] Grader returns valid score: {score}")
    return score
def test_medium_optimal():
    """Run the medium cascading-failure task through its optimal sequence.

    Inspects api-gateway and order-service, checks that the order-service
    config output mentions ``db_pool_size`` and ``5``, sets the pool size to
    50, then marks the incident resolved.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()
    env.reset("medium_cascading_failure")

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    # Investigate the chain from the gateway down to order-service.
    for probe in (
        "check_metrics api-gateway",
        "check_logs api-gateway",
        "check_metrics order-service",
        "check_logs order-service",
    ):
        run(probe)

    outcome = run("check_config order-service")
    config_text = outcome.observation.last_action_result
    assert "db_pool_size" in config_text
    assert "5" in config_text
    print(" [PASS] Config shows db_pool_size = 5")

    # Apply the fix; the episode must stay open for mark_resolved.
    outcome = run("update_config order-service db_pool_size 50")
    assert not outcome.done
    assert "resolved" in outcome.observation.last_action_result.lower()
    print(" [PASS] Config update fixes the issue")

    # Declare the root cause to finish the episode.
    outcome = run(
        "mark_resolved order-service db_pool_size connection pool exhausted config changed to 5"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    final_state = env.state()
    score = grade_task("medium_cascading_failure", final_state)
    assert score >= 0.8
    print(f" [PASS] Grader score: {score}")
    return score
def test_hard_optimal():
    """Run the hard cache-degradation task through its optimal sequence.

    Surveys several services, checks that the cache-service deploy history
    mentions the hashing change, rolls the cache-service deploy back, then
    marks the incident resolved.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()
    env.reset("hard_cache_degradation")

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    # Broad investigation before drilling into the cache.
    for probe in (
        "check_metrics api-gateway",
        "check_metrics order-service",
        "check_metrics product-service",
        "check_metrics cache-service",
        "check_logs cache-service",
    ):
        run(probe)

    outcome = run("check_deploy_history cache-service")
    history_text = outcome.observation.last_action_result
    assert "MurmurHash3" in history_text or "hashing" in history_text.lower()
    print(" [PASS] Deploy history reveals hashing change")

    run("check_metrics postgres-primary")

    # Roll back the cache deploy; the episode must stay open afterwards.
    outcome = run("rollback_deploy cache-service")
    assert not outcome.done
    print(" [PASS] Rollback fixes cache, episode continues")

    # Declare the root cause to finish the episode.
    outcome = run(
        "mark_resolved cache-service deployment changed key hashing algorithm causing cache miss"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] mark_resolved completes incident (score: {outcome.reward.total})")

    final_state = env.state()
    score = grade_task("hard_cache_degradation", final_state)
    assert score >= 0.8
    print(f" [PASS] Grader score: {score}")
    return score
def test_dns_optimal():
    """Run the DNS misconfiguration task through its optimal sequence.

    Checks that the order-service config output contains the wrong hostname,
    points ``inventory_host`` at ``inventory-service.internal``, then marks
    the incident resolved.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    initial = env.reset("medium_dns_misconfiguration")
    assert initial.task_id == "medium_dns_misconfiguration"
    print(" [PASS] Reset works")

    run("check_metrics order-service")
    run("check_logs order-service")
    outcome = run("check_config order-service")
    assert "inventory-service-v2.internal" in outcome.observation.last_action_result
    print(" [PASS] Config shows wrong hostname")

    run("check_metrics inventory-service")
    outcome = run("update_config order-service inventory_host inventory-service.internal")
    assert not outcome.done
    print(" [PASS] Config fix applied")

    outcome = run(
        "mark_resolved order-service dns hostname misconfiguration inventory_host pointed to wrong host"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.9
    print(f" [PASS] DNS scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    score = grade_task("medium_dns_misconfiguration", final_state)
    assert score >= 0.8
    print(f" [PASS] Grader score: {score}")
    return score
def test_replication_lag_optimal():
    """Run the DB replication-lag task through its optimal sequence.

    Investigates the user, order and postgres services, disables the batch
    job on postgres-primary, then marks the incident resolved.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    initial = env.reset("hard_replication_lag")
    assert initial.task_id == "hard_replication_lag"
    print(" [PASS] Reset works")

    # Investigation chain; only the side effects on the episode matter here.
    for probe in (
        "check_metrics user-service",
        "check_logs user-service",
        "check_metrics order-service",
        "check_metrics postgres-primary",
        "check_logs postgres-primary",
        "check_config postgres-primary",
        "check_metrics postgres-replica",
    ):
        run(probe)
    print(" [PASS] Investigation chain complete")

    outcome = run("update_config postgres-primary batch_job_enabled false")
    assert not outcome.done
    print(" [PASS] Batch job disabled")

    outcome = run(
        "mark_resolved postgres-primary batch job nightly_aggregation causing replication lag"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.8
    print(f" [PASS] Replication lag scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    score = grade_task("hard_replication_lag", final_state)
    assert score >= 0.7
    print(f" [PASS] Grader score: {score}")
    return score
def test_wrong_actions():
    """Check that restarting the wrong service is penalized, not rewarded."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # user-service is not the service the optimal run restarts for this task.
    wrong_restart = Action(command="restart_service user-service")
    outcome = env.step(wrong_restart)
    assert not outcome.done
    print(" [PASS] Restarting wrong service doesn't resolve")

    # The reward breakdown should now carry a negative penalty entry.
    breakdown = env.state().reward_breakdown
    assert breakdown.get("penalty", 0) < 0
    print(" [PASS] Penalty applied for wrong action")
def test_max_steps():
    """Check that the episode is done after ten non-resolving steps."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # Burn through all steps with a command that never resolves the incident
    # (ten is presumably the step limit for this task).
    for _ in range(10):
        last = env.step(Action(command="check_metrics api-gateway"))

    assert last.done
    print(f" [PASS] Episode ends at max steps (score: {last.reward.total})")
def test_invalid_commands():
    """Check that an unknown command and an unknown service both report errors."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    unknown_command = env.step(Action(command="delete_everything"))
    assert unknown_command.observation.last_action_error
    print(" [PASS] Invalid command returns error")

    unknown_service = env.step(Action(command="check_metrics nonexistent-service"))
    assert unknown_service.observation.last_action_error
    print(" [PASS] Unknown service returns error")
def test_list_tasks():
    """Check that six tasks are listed and all four difficulty levels appear."""
    env = OnCallEnvironment()
    tasks = env.list_tasks()
    assert len(tasks) == 6

    seen_levels = {entry["difficulty"] for entry in tasks}
    for level in ("easy", "medium", "hard", "expert"):
        assert level in seen_levels
    print(f" [PASS] {len(tasks)} tasks with difficulty range")
def test_state_endpoint():
    """Check that ``state()`` reflects the task id and the single step taken."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="check_logs payment-service"))

    snapshot = env.state()
    assert snapshot.task_id == "easy_memory_leak"
    assert snapshot.step == 1
    assert len(snapshot.actions_taken) == 1
    assert "payment-service" in snapshot.investigation_log
    print(" [PASS] State endpoint returns correct data")
def test_score_range():
    """Check that every task's score stays in [0.0, 1.0] after five steps."""
    task_ids = [
        "easy_memory_leak",
        "medium_cascading_failure",
        "hard_cache_degradation",
        "medium_dns_misconfiguration",
        "hard_replication_lag",
        "expert_multi_root_cause",
    ]
    env = OnCallEnvironment()
    for task_id in task_ids:
        env.reset(task_id)
        # Five identical investigation steps are enough to accumulate a score.
        for _ in range(5):
            env.step(Action(command="check_metrics api-gateway"))
        state = env.state()
        assert 0.0 <= state.score <= 1.0, f"{task_id}: score {state.score} out of range"
    print(" [PASS] All scores in [0.0, 1.0]")
def test_mark_resolved_positive():
    """Check that mark_resolved with matching keywords identifies the root cause."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    for command in (
        "check_logs payment-service",
        "restart_service payment-service",
    ):
        env.step(Action(command=command))

    outcome = env.step(Action(command="mark_resolved payment-service memory leak OOM heap"))
    assert outcome.done

    state = env.state()
    assert state.root_cause_identified
    print(f" [PASS] Correct mark_resolved (score: {state.score})")
def test_mark_resolved_negative():
    """Check that mark_resolved with unrelated text neither ends the episode
    nor flags the root cause as identified."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    bogus_claim = Action(command="mark_resolved everything is broken somewhere")
    outcome = env.step(bogus_claim)
    assert not outcome.done

    assert not env.state().root_cause_identified
    print(" [PASS] Wrong mark_resolved rejected")
def test_mark_resolved_partial():
    """Check that mark_resolved with a partial keyword match still flags the
    root cause as identified."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="mark_resolved memory issue detected"))

    # Partial match: the explanation has 1 keyword.
    assert env.state().root_cause_identified
    print(" [PASS] Partial mark_resolved gives partial credit")
def test_remediation_without_mark_resolved():
    """Check that the episode ends two steps after a correct remediation even
    when mark_resolved is never issued."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    run("restart_service payment-service")

    # Two more steps are allowed after remediation.
    first_after = run("check_metrics api-gateway")
    assert not first_after.done  # step 1 after remediation
    second_after = run("check_metrics api-gateway")
    assert second_after.done  # step 2 after remediation — auto-ends

    state = env.state()
    # Remediation credit is earned, but not root cause or efficiency credit.
    assert state.score >= 0.3
    print(f" [PASS] Episode ends 2 steps after remediation (score: {state.score})")
def test_expert_optimal():
    """Run the expert multi-root-cause task, applying both required fixes.

    Investigates the search and order failure chains, rolls back
    search-service, raises order-service ``db_pool_size`` to 50, then marks
    the incident resolved with an explanation covering both causes.

    Returns:
        The score computed by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()

    def run(command):
        # Wrap the raw command string in an Action and advance the episode.
        return env.step(Action(command=command))

    initial = env.reset("expert_multi_root_cause")
    assert initial.task_id == "expert_multi_root_cause"
    assert len(initial.alerts) >= 3
    print(" [PASS] Reset works")

    # Investigate both failure chains, starting with the search side.
    for probe in (
        "check_metrics api-gateway",
        "check_logs api-gateway",
        "check_metrics search-service",
        "check_logs search-service",
    ):
        run(probe)
    outcome = run("check_deploy_history search-service")
    assert "v3.1.0" in outcome.observation.last_action_result
    print(" [PASS] Search deploy history shows broken deployment")

    # Then the order side.
    run("check_metrics order-service")
    run("check_logs order-service")
    outcome = run("check_config order-service")
    assert "db_pool_size" in outcome.observation.last_action_result
    print(" [PASS] Order config shows low pool size")

    run("check_metrics elasticsearch")

    # Fix 1: roll back search-service; one of the two fixes is then reported.
    outcome = run("rollback_deploy search-service")
    assert not outcome.done
    assert "1/2" in outcome.observation.last_action_result
    print(" [PASS] First fix applied (1/2)")

    # Fix 2: update the order-service config.
    outcome = run("update_config order-service db_pool_size 50")
    assert not outcome.done
    fix_text = outcome.observation.last_action_result
    assert "resolved" in fix_text.lower() or "2/2" in fix_text
    print(" [PASS] Second fix applied (2/2)")

    # Declare both root causes to finish the episode.
    outcome = run(
        "mark_resolved search-service bad deployment v3.1.0 elasticsearch query AND order-service db_pool_size config drift both issues"
    )
    assert outcome.done
    assert outcome.reward.total >= 0.8
    print(f" [PASS] Expert scenario completed (score: {outcome.reward.total})")

    final_state = env.state()
    score = grade_task("expert_multi_root_cause", final_state)
    assert score >= 0.7
    print(f" [PASS] Grader score: {score}")
    return score
def test_grader_independence():
    """Check that the grader score and the environment score are both high
    after an optimal easy-task run."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    for command in (
        "check_logs payment-service",
        "check_metrics payment-service",
        "restart_service payment-service",
        "mark_resolved payment-service memory leak OOM",
    ):
        env.step(Action(command=command))

    state = env.state()
    env_score = state.score
    grader_score = grade_task("easy_memory_leak", state)

    # The two scores may differ slightly, so each is checked against the
    # threshold on its own rather than compared to the other.
    assert grader_score >= 0.8
    assert env_score >= 0.8
    print(f" [PASS] Grader ({grader_score}) and env ({env_score}) both score high")
if __name__ == "__main__":
    import traceback

    # (label, test function) pairs, executed in the order listed.
    suite = [
        ("Easy optimal run", test_easy_optimal),
        ("Medium optimal run", test_medium_optimal),
        ("Hard optimal run", test_hard_optimal),
        ("DNS misconfiguration optimal", test_dns_optimal),
        ("DB replication lag optimal", test_replication_lag_optimal),
        ("Expert multi-root-cause optimal", test_expert_optimal),
        ("Wrong actions penalty", test_wrong_actions),
        ("Max steps termination", test_max_steps),
        ("Invalid commands", test_invalid_commands),
        ("Task listing", test_list_tasks),
        ("State endpoint", test_state_endpoint),
        ("Score range validation", test_score_range),
        ("mark_resolved positive", test_mark_resolved_positive),
        ("mark_resolved negative", test_mark_resolved_negative),
        ("mark_resolved partial", test_mark_resolved_partial),
        ("Remediation without mark_resolved", test_remediation_without_mark_resolved),
        ("Grader independence", test_grader_independence),
    ]

    passes = 0
    failures = 0
    for label, case in suite:
        print(f"\n{'─'*50}")
        print(f"TEST: {label}")
        try:
            case()
        except Exception as err:
            # Keep going after a failure so one broken test does not hide
            # the results of the rest.
            print(f" [FAIL] {err}")
            traceback.print_exc()
            failures += 1
        else:
            passes += 1

    print(f"\n{'═'*50}")
    print(f"Results: {passes} passed, {failures} failed")
    # A non-zero exit status signals failure to the caller.
    if failures:
        sys.exit(1)
    print("All tests passed!")
|