File size: 16,614 Bytes
34bd75f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
"""
test_env.py — Validates that OnCallEnvironment and its task graders work correctly.

Run: python test_env.py
Requires: pip install -r requirements.txt
"""

import sys
import json
from environment import OnCallEnvironment
from models import Action
from graders import grade_task


def test_easy_optimal():
    """Run the easy memory-leak task along its optimal path.

    Sequence: reset -> check_logs -> check_metrics -> restart_service ->
    mark_resolved, asserting on every observation, then on the final
    environment reward and the independent grader score.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    import re as _re

    environment = OnCallEnvironment()
    first_obs = environment.reset("easy_memory_leak")
    assert first_obs.task_id == "easy_memory_leak"
    assert first_obs.step == 0
    assert len(first_obs.alerts) == 3
    print("  [PASS] Reset returns valid observation")

    # Step 1: the payment-service logs should surface the OOM errors.
    result = environment.step(Action(command="check_logs payment-service"))
    assert not result.done
    assert "OutOfMemoryError" in result.observation.last_action_result
    print("  [PASS] check_logs shows OOM errors")

    # Step 2: metrics are dynamic (memory may drift from the 94.7% baseline),
    # so only require that the reported memory usage stays above 90%.
    result = environment.step(Action(command="check_metrics payment-service"))
    metrics_text = result.observation.last_action_result
    assert "Memory usage:" in metrics_text
    usage = _re.search(r"Memory usage:\s+([\d.]+)%", metrics_text)
    assert usage and float(usage.group(1)) > 90
    print("  [PASS] check_metrics shows high memory")

    # Step 3: restarting fixes the service but the episode stays open so
    # the agent can still call mark_resolved.
    result = environment.step(Action(command="restart_service payment-service"))
    assert not result.done
    assert "healthy" in result.observation.last_action_result.lower()
    print("  [PASS] Restart fixes service, episode continues for mark_resolved")

    # Step 4: closing the incident with the right diagnosis ends the episode.
    result = environment.step(Action(command="mark_resolved payment-service memory leak due to OOM kills"))
    assert result.done
    assert result.reward.total >= 0.9
    print(f"  [PASS] mark_resolved completes incident (score: {result.reward.total})")

    final_score = grade_task("easy_memory_leak", environment.state())
    assert 0.0 <= final_score <= 1.0
    assert final_score >= 0.9
    print(f"  [PASS] Grader returns valid score: {final_score}")
    return final_score


def test_medium_optimal():
    """Run the medium cascading-failure task along its optimal path.

    Walks the failure chain from api-gateway to order-service, confirms the
    order-service config has ``db_pool_size`` set to 5, raises it to 50,
    and marks the incident resolved.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    import re

    env = OnCallEnvironment()
    env.reset("medium_cascading_failure")

    # Investigate the chain
    env.step(Action(command="check_metrics api-gateway"))
    env.step(Action(command="check_logs api-gateway"))
    env.step(Action(command="check_metrics order-service"))
    env.step(Action(command="check_logs order-service"))
    r = env.step(Action(command="check_config order-service"))
    config_text = r.observation.last_action_result
    assert "db_pool_size" in config_text
    # A bare `"5" in text` check passes on any 5 anywhere in the dump
    # (ports, timeouts, "50", ...), so tie the value to the key itself.
    # `\b` also rejects values such as 50 or 500.
    assert re.search(r"db_pool_size\W*5\b", config_text), config_text
    print("  [PASS] Config shows db_pool_size = 5")

    # Fix it
    r = env.step(Action(command="update_config order-service db_pool_size 50"))
    assert not r.done
    assert "resolved" in r.observation.last_action_result.lower()
    print("  [PASS] Config update fixes the issue")

    # Mark resolved
    r = env.step(Action(command="mark_resolved order-service db_pool_size connection pool exhausted config changed to 5"))
    assert r.done
    assert r.reward.total >= 0.9
    print(f"  [PASS] mark_resolved completes incident (score: {r.reward.total})")

    state = env.state()
    score = grade_task("medium_cascading_failure", state)
    assert score >= 0.8
    print(f"  [PASS] Grader score: {score}")
    return score


def test_hard_optimal():
    """Run the hard cache-degradation task along its optimal path.

    Surveys metrics across the request path, inspects the cache-service
    logs and deploy history, rolls back the cache deployment, then marks
    the incident resolved.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()
    env.reset("hard_cache_degradation")

    # Broad metrics sweep before zooming in on the cache tier.
    for service in ("api-gateway", "order-service", "product-service", "cache-service"):
        env.step(Action(command=f"check_metrics {service}"))
    env.step(Action(command="check_logs cache-service"))

    history = env.step(Action(command="check_deploy_history cache-service")).observation.last_action_result
    assert "MurmurHash3" in history or "hashing" in history.lower()
    print("  [PASS] Deploy history reveals hashing change")

    env.step(Action(command="check_metrics postgres-primary"))

    # Rolling back fixes the cache; the episode stays open for mark_resolved.
    rollback = env.step(Action(command="rollback_deploy cache-service"))
    assert not rollback.done
    print("  [PASS] Rollback fixes cache, episode continues")

    resolution = env.step(Action(command="mark_resolved cache-service deployment changed key hashing algorithm causing cache miss"))
    assert resolution.done
    assert resolution.reward.total >= 0.9
    print(f"  [PASS] mark_resolved completes incident (score: {resolution.reward.total})")

    grader_score = grade_task("hard_cache_degradation", env.state())
    assert grader_score >= 0.8
    print(f"  [PASS] Grader score: {grader_score}")
    return grader_score


def test_dns_optimal():
    """Run the DNS-misconfiguration task along its optimal path.

    The order-service config points ``inventory_host`` at
    ``inventory-service-v2.internal``; the fix sets it to
    ``inventory-service.internal`` and then marks the incident resolved.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    task = "medium_dns_misconfiguration"
    env = OnCallEnvironment()

    def run(command):
        # Thin helper: wrap a command string in an Action and step the env.
        return env.step(Action(command=command))

    assert env.reset(task).task_id == task
    print("  [PASS] Reset works")

    run("check_metrics order-service")
    run("check_logs order-service")
    config = run("check_config order-service").observation.last_action_result
    assert "inventory-service-v2.internal" in config
    print("  [PASS] Config shows wrong hostname")

    run("check_metrics inventory-service")

    fix = run("update_config order-service inventory_host inventory-service.internal")
    assert not fix.done
    print("  [PASS] Config fix applied")

    final = run("mark_resolved order-service dns hostname misconfiguration inventory_host pointed to wrong host")
    assert final.done
    assert final.reward.total >= 0.9
    print(f"  [PASS] DNS scenario completed (score: {final.reward.total})")

    grader_score = grade_task(task, env.state())
    assert grader_score >= 0.8
    print(f"  [PASS] Grader score: {grader_score}")
    return grader_score


def test_replication_lag_optimal():
    """Run the DB replication-lag task along its optimal path.

    Traces the symptoms from user-service down to postgres-primary and
    postgres-replica, disables the primary's ``batch_job_enabled`` setting,
    and marks the incident resolved.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()
    opening = env.reset("hard_replication_lag")
    assert opening.task_id == "hard_replication_lag"
    print("  [PASS] Reset works")

    investigation = [
        "check_metrics user-service",
        "check_logs user-service",
        "check_metrics order-service",
        "check_metrics postgres-primary",
        "check_logs postgres-primary",
        "check_config postgres-primary",
        "check_metrics postgres-replica",
    ]
    for command in investigation:
        env.step(Action(command=command))
    print("  [PASS] Investigation chain complete")

    disable = env.step(Action(command="update_config postgres-primary batch_job_enabled false"))
    assert not disable.done
    print("  [PASS] Batch job disabled")

    closing = env.step(Action(command="mark_resolved postgres-primary batch job nightly_aggregation causing replication lag"))
    assert closing.done
    assert closing.reward.total >= 0.8
    print(f"  [PASS] Replication lag scenario completed (score: {closing.reward.total})")

    grader_score = grade_task("hard_replication_lag", env.state())
    assert grader_score >= 0.7
    print(f"  [PASS] Grader score: {grader_score}")
    return grader_score


def test_wrong_actions():
    """Restarting a service other than the culprit is penalized.

    The step must not end the episode, and the state's reward breakdown
    must carry a negative ``penalty`` entry.
    """
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # The optimal path for this task restarts payment-service, not user-service.
    outcome = env.step(Action(command="restart_service user-service"))
    assert not outcome.done
    print("  [PASS] Restarting wrong service doesn't resolve")

    breakdown = env.state().reward_breakdown
    assert breakdown.get("penalty", 0) < 0
    print("  [PASS] Penalty applied for wrong action")


def test_max_steps():
    """Ten read-only actions on the easy task must leave the episode done.

    The count of 10 presumably mirrors the easy task's step budget;
    NOTE(review): confirm against the task definition.
    """
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # Burn through the step budget with an action that never resolves anything.
    for _ in range(10):
        last = env.step(Action(command="check_metrics api-gateway"))
    assert last.done
    print(f"  [PASS] Episode ends at max steps (score: {last.reward.total})")


def test_invalid_commands():
    """Unknown verbs and unknown service names must both set an action error."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")

    # (command, message printed once its error has been confirmed)
    cases = (
        ("delete_everything", "  [PASS] Invalid command returns error"),
        ("check_metrics nonexistent-service", "  [PASS] Unknown service returns error"),
    )
    for command, message in cases:
        assert env.step(Action(command=command)).observation.last_action_error
        print(message)


def test_list_tasks():
    """Expect six registered tasks covering easy, medium, hard and expert."""
    tasks = OnCallEnvironment().list_tasks()
    assert len(tasks) == 6
    seen = {task["difficulty"] for task in tasks}
    for level in ("easy", "medium", "hard", "expert"):
        assert level in seen
    print(f"  [PASS] {len(tasks)} tasks with difficulty range")


def test_state_endpoint():
    """After a single check_logs step, state() must reflect exactly that step."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="check_logs payment-service"))

    snapshot = env.state()
    assert snapshot.task_id == "easy_memory_leak"
    assert snapshot.step == 1
    assert len(snapshot.actions_taken) == 1
    # The inspected service should appear in the recorded investigation log.
    assert "payment-service" in snapshot.investigation_log
    print("  [PASS] State endpoint returns correct data")


def test_score_range():
    """Each task's score stays within [0.0, 1.0] after five idle steps."""
    env = OnCallEnvironment()
    task_ids = (
        "easy_memory_leak",
        "medium_cascading_failure",
        "hard_cache_degradation",
        "medium_dns_misconfiguration",
        "hard_replication_lag",
        "expert_multi_root_cause",
    )
    for task_id in task_ids:
        env.reset(task_id)
        for _ in range(5):
            env.step(Action(command="check_metrics api-gateway"))
        score = env.state().score
        assert 0.0 <= score <= 1.0, f"{task_id}: score {score} out of range"
    print("  [PASS] All scores in [0.0, 1.0]")


def test_mark_resolved_positive():
    """A remediated incident closed with on-target keywords ends the episode
    and is flagged as having its root cause identified."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    for command in ("check_logs payment-service", "restart_service payment-service"):
        env.step(Action(command=command))

    closing = env.step(Action(command="mark_resolved payment-service memory leak OOM heap"))
    assert closing.done
    final_state = env.state()
    assert final_state.root_cause_identified
    print(f"  [PASS] Correct mark_resolved (score: {final_state.score})")


def test_mark_resolved_negative():
    """An off-target mark_resolved keeps the episode open and does not count
    as identifying the root cause."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    attempt = env.step(Action(command="mark_resolved everything is broken somewhere"))
    assert not attempt.done
    assert not env.state().root_cause_identified
    print("  [PASS] Wrong mark_resolved rejected")


def test_mark_resolved_partial():
    """A mark_resolved that mentions only part of the root cause ("memory")
    is still flagged as identifying the root cause.

    Only the ``root_cause_identified`` flag is asserted; the size of the
    partial credit is not checked here.
    """
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="mark_resolved memory issue detected"))
    # Partial match: the summary contains a single expected keyword.
    assert env.state().root_cause_identified
    print("  [PASS] Partial mark_resolved gives partial credit")


def test_remediation_without_mark_resolved():
    """Fixing the service without ever calling mark_resolved still ends the
    episode: two further steps are allowed after remediation, after which it
    auto-terminates with partial credit."""
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    env.step(Action(command="restart_service payment-service"))

    first = env.step(Action(command="check_metrics api-gateway"))
    assert not first.done  # one step after remediation: still open
    second = env.step(Action(command="check_metrics api-gateway"))
    assert second.done  # two steps after remediation: auto-ended

    final_score = env.state().score
    # Remediation credit only: no root-cause or efficiency bonus.
    assert final_score >= 0.3
    print(f"  [PASS] Episode ends 2 steps after remediation (score: {final_score})")


def test_expert_optimal():
    """Run the expert multi-root-cause task, applying both required fixes.

    Two failure chains are investigated (the search-service deployment and
    the order-service ``db_pool_size`` config). Each is remediated in turn,
    and the incident is closed with a summary covering both.

    Returns:
        The score reported by ``grade_task`` for the finished episode.
    """
    env = OnCallEnvironment()

    def act(command):
        # Thin helper: wrap a command string in an Action and step the env.
        return env.step(Action(command=command))

    opening = env.reset("expert_multi_root_cause")
    assert opening.task_id == "expert_multi_root_cause"
    assert len(opening.alerts) >= 3
    print("  [PASS] Reset works")

    # Chain 1: api-gateway -> search-service deployment.
    act("check_metrics api-gateway")
    act("check_logs api-gateway")
    act("check_metrics search-service")
    act("check_logs search-service")
    deploys = act("check_deploy_history search-service").observation.last_action_result
    assert "v3.1.0" in deploys
    print("  [PASS] Search deploy history shows broken deployment")

    # Chain 2: order-service connection-pool configuration.
    act("check_metrics order-service")
    act("check_logs order-service")
    config = act("check_config order-service").observation.last_action_result
    assert "db_pool_size" in config
    print("  [PASS] Order config shows low pool size")

    act("check_metrics elasticsearch")

    # Fix 1 of 2: roll back the search-service deployment.
    rollback = act("rollback_deploy search-service")
    assert not rollback.done
    assert "1/2" in rollback.observation.last_action_result
    print("  [PASS] First fix applied (1/2)")

    # Fix 2 of 2: restore the order-service pool size.
    pool_fix = act("update_config order-service db_pool_size 50")
    assert not pool_fix.done
    pool_text = pool_fix.observation.last_action_result
    assert "resolved" in pool_text.lower() or "2/2" in pool_text
    print("  [PASS] Second fix applied (2/2)")

    closing = act("mark_resolved search-service bad deployment v3.1.0 elasticsearch query AND order-service db_pool_size config drift both issues")
    assert closing.done
    assert closing.reward.total >= 0.8
    print(f"  [PASS] Expert scenario completed (score: {closing.reward.total})")

    grader_score = grade_task("expert_multi_root_cause", env.state())
    assert grader_score >= 0.7
    print(f"  [PASS] Grader score: {grader_score}")
    return grader_score


def test_grader_independence():
    """Compare ``grade_task``'s score with the environment's own score.

    The two are computed independently and may differ slightly, so the test
    only requires both to be at least 0.8 after an optimal easy-task run.
    """
    env = OnCallEnvironment()
    env.reset("easy_memory_leak")
    for command in (
        "check_logs payment-service",
        "check_metrics payment-service",
        "restart_service payment-service",
        "mark_resolved payment-service memory leak OOM",
    ):
        env.step(Action(command=command))

    final_state = env.state()
    env_score = final_state.score
    grader_score = grade_task("easy_memory_leak", final_state)

    assert grader_score >= 0.8
    assert env_score >= 0.8
    print(f"  [PASS] Grader ({grader_score}) and env ({env_score}) both score high")


if __name__ == "__main__":
    import traceback

    # (display name, test function) pairs, executed in this order.
    tests = [
        ("Easy optimal run", test_easy_optimal),
        ("Medium optimal run", test_medium_optimal),
        ("Hard optimal run", test_hard_optimal),
        ("DNS misconfiguration optimal", test_dns_optimal),
        ("DB replication lag optimal", test_replication_lag_optimal),
        ("Expert multi-root-cause optimal", test_expert_optimal),
        ("Wrong actions penalty", test_wrong_actions),
        ("Max steps termination", test_max_steps),
        ("Invalid commands", test_invalid_commands),
        ("Task listing", test_list_tasks),
        ("State endpoint", test_state_endpoint),
        ("Score range validation", test_score_range),
        ("mark_resolved positive", test_mark_resolved_positive),
        ("mark_resolved negative", test_mark_resolved_negative),
        ("mark_resolved partial", test_mark_resolved_partial),
        ("Remediation without mark_resolved", test_remediation_without_mark_resolved),
        ("Grader independence", test_grader_independence),
    ]

    passed = 0
    failed = 0
    for name, fn in tests:
        print(f"\n{'─'*50}")
        print(f"TEST: {name}")
        try:
            fn()
        except Exception as e:
            # Top-level boundary: report and keep running the remaining tests.
            # Bare `assert` failures carry an empty message, so include the
            # exception type to keep the [FAIL] line meaningful on its own.
            print(f"  [FAIL] {type(e).__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1

    print(f"\n{'═'*50}")
    print(f"Results: {passed} passed, {failed} failed")
    if failed:
        sys.exit(1)
    print("All tests passed!")