| """Quick test runner that writes output to a file for Windows compatibility.""" |
| import sys, os |
|
|
| |
# Parent of this script's directory. The log file is written there, and it is
# put on sys.path so the ``secureheal_arena`` imports below resolve.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

output_file = os.path.join(_project_root, 'test_results.log')
# Line-buffered (buffering=1) so the log already holds every completed line
# if a test hangs or the process is killed before the final close.
sys.stdout = open(output_file, 'w', encoding='utf-8', buffering=1)
# Tracebacks and warnings go to the same log file.
sys.stderr = sys.stdout

sys.path.insert(0, _project_root)
|
|
| from secureheal_arena.sandbox import SandboxEngine |
| from secureheal_arena.vulnerabilities import SQL_INJECTION, XSS_STORED, PATH_TRAVERSAL, get_scenarios_for_level |
| from secureheal_arena.anomalies import MEMORY_SPIKE, select_anomaly |
| from secureheal_arena.rewards import compute_reward |
| from secureheal_arena.models import SecureHealState |
| import random |
|
|
# One ``(name, passed, message)`` tuple per executed test, in run order.
results = []


def run_test(name, fn):
    """Run one test callable and record its outcome.

    Appends ``(name, passed, message)`` to the module-level ``results`` list
    and prints a ``[PASS]``/``[FAIL]`` line. On failure the full traceback is
    printed as well.

    Args:
        name: Label used in the printed line and stored in ``results``.
        fn: Zero-argument callable. Any ``Exception`` it raises counts as a
            failure. Returning normally counts as a pass.
    """
    try:
        fn()
    except Exception as e:
        import traceback

        # A bare ``assert`` raises AssertionError with an empty message. Fall
        # back to the exception type so the FAIL line is never blank.
        message = str(e) or type(e).__name__
        results.append((name, False, message))
        print(f"[FAIL] {name}: {message}")
        traceback.print_exc()
    else:
        results.append((name, True, ""))
        print(f"[PASS] {name}")
|
|
| |
def t1():
    """Smoke-test the sandbox: a trivial snippet succeeds and its printed output is captured."""
    sandbox = SandboxEngine()
    snippet = 'x = 1 + 1\nprint(x)'
    outcome = sandbox.execute(snippet)
    assert outcome.success, f"Sandbox failed: {outcome.error_message}"
    assert '2' in outcome.stdout


run_test("sandbox_basic", t1)
|
|
| |
def t2():
    """Check that the sandbox reports a timeout for an infinite loop.

    On Windows the check is skipped. The function then returns normally, so
    ``run_test`` records it as a pass.
    """
    if sys.platform != 'win32':
        sandbox = SandboxEngine(timeout=1)
        outcome = sandbox.execute('while True: pass')
        assert outcome.timeout, "Timeout should have triggered"
        return
    print(" [SKIP] Timeout test skipped on Windows (GIL limitation)")
    print(" Works correctly on Linux/HF Spaces via SIGALRM")


run_test("sandbox_timeout", t2)
|
|
| |
def t3():
    """Check that each import / eval / __import__ snippet is flagged as a forbidden operation."""
    sandbox = SandboxEngine()
    blocked_snippets = (
        'import os',
        'import subprocess',
        'from os import path',
        'eval("1+1")',
        '__import__("os")',
    )
    for snippet in blocked_snippets:
        outcome = sandbox.execute(snippet)
        assert outcome.forbidden_op, f"Should have blocked: {snippet}"


run_test("sandbox_forbidden_ops", t3)
|
|
| |
def t4():
    """Run the SQL-injection scenario's tests against both code versions.

    The vulnerable version's pass rate is only logged. The patched version
    must reach a pass rate of at least 0.9. A missing ``return_value`` is
    treated as 0.0.
    """
    sandbox = SandboxEngine()

    def pass_rate_of(source):
        outcome = sandbox.execute_tests(source, SQL_INJECTION.test_code)
        raw = outcome.return_value
        return 0.0 if raw is None else float(raw)

    vul_rate = pass_rate_of(SQL_INJECTION.vulnerable_code)
    print(f" SQL injection vulnerable pass_rate = {vul_rate}")

    pat_rate = pass_rate_of(SQL_INJECTION.patched_code)
    print(f" SQL injection patched pass_rate = {pat_rate}")
    assert pat_rate >= 0.9, f"Patched should pass, got {pat_rate}"


run_test("sql_injection_tests", t4)
|
|
| |
def t5():
    """Check that the stored-XSS scenario's patched code passes at least 90% of its tests.

    A missing ``return_value`` is treated as a pass rate of 0.0.
    """
    engine = SandboxEngine()
    r = engine.execute_tests(XSS_STORED.patched_code, XSS_STORED.test_code)
    rate = float(r.return_value) if r.return_value is not None else 0.0
    print(f" XSS patched pass_rate = {rate}")
    # Same failure message as the SQL-injection test, so the FAIL line is informative.
    assert rate >= 0.9, f"Patched should pass, got {rate}"


run_test("xss_tests", t5)
|
|
| |
def t6():
    """Check that the path-traversal scenario's patched code passes at least 90% of its tests.

    A missing ``return_value`` is treated as a pass rate of 0.0.
    """
    engine = SandboxEngine()
    r = engine.execute_tests(PATH_TRAVERSAL.patched_code, PATH_TRAVERSAL.test_code)
    rate = float(r.return_value) if r.return_value is not None else 0.0
    print(f" Path traversal patched pass_rate = {rate}")
    # Same failure message as the SQL-injection test, so the FAIL line is informative.
    assert rate >= 0.9, f"Patched should pass, got {rate}"


run_test("path_traversal_tests", t6)
|
|
| |
def t7():
    """Check that curriculum level N exposes exactly N scenarios, for N = 1, 2, 3."""
    for level in (1, 2, 3):
        scenarios = get_scenarios_for_level(level)
        # Name the level and the count actually seen; a bare assert would
        # leave the FAIL line empty.
        assert len(scenarios) == level, (
            f"Level {level} should have {level} scenario(s), got {len(scenarios)}"
        )


run_test("curriculum_levels", t7)
|
|
| |
def t8():
    """Check the sign and size of compute_reward for a bad state and a fixed one.

    An exploitable, unpatched, failing and unstable state must score below 0.
    After the same state object is mutated to patched, fully passing and fully
    stable, it must score above 0.5.
    """
    state = SecureHealState(
        episode_id="t",
        step_count=1,
        vulnerability_present=True,
        exploit_possible=True,
        patch_applied=False,
        test_pass_rate=0.0,
        system_stability=0.3,
        cascading_failures=["x"],
    )
    bad = compute_reward(state)
    print(f" Bad state reward: {bad.total:.3f}")
    assert bad.total < 0

    # Mutate the same state in place; vulnerability_present is left as-is.
    state.patch_applied = True
    state.exploit_possible = False
    state.test_pass_rate = 1.0
    state.system_stability = 1.0
    state.cascading_failures = []
    good = compute_reward(state)
    print(f" Good state reward: {good.total:.3f}")
    assert good.total > 0.5


run_test("reward_computation", t8)
|
|
| |
def t9():
    """Drive one curriculum-level-1 episode through the environment's handlers.

    If the environment module cannot be imported, the test returns early,
    which ``run_test`` records as a pass. Otherwise it calls the private
    ``_handle_*`` methods directly, in this fixed order: scan, attack, patch,
    run tests, classify, restart, reallocate. It then logs the final state.
    """
    try:
        from secureheal_arena.server.secureheal_environment import SecureHealEnvironment
    except ImportError as exc:
        print(f" Skipping full episode (openenv not installed): {exc}")
        return

    env = SecureHealEnvironment(curriculum_level=1)
    initial = env.reset(seed=42)
    print(f" Reset done={initial.done}")
    assert not initial.done

    # Security steps: find, attack, patch, then verify with the scenario tests.
    scan = env._handle_scan_code()
    print(f" scan_code: vuln_found={scan.get('vulnerability_found')}")
    assert scan['vulnerability_found']

    attack = env._handle_simulate_attack()
    print(f" simulate_attack: exploit={attack.get('exploit_succeeded')}")

    patch = env._handle_apply_patch(SQL_INJECTION.patched_code)
    print(f" apply_patch: applied={patch.get('patch_applied')}")
    assert patch['patch_applied']

    test_run = env._handle_run_tests()
    print(f" run_tests: pass_rate={test_run.get('pass_rate')}")
    assert test_run['pass_rate'] >= 0.9

    # Operational steps: classify the issue, restart, reallocate.
    diagnosis = env._handle_classify_issue("memory_spike")
    print(f" classify_issue: correct={diagnosis.get('correct')}")
    assert diagnosis['correct']

    restart = env._handle_restart_service("auth-service")
    print(f" restart_service: {restart.get('result')}")

    realloc = env._handle_reallocate_resources()
    print(f" reallocate: stability={realloc.get('stability')}")

    final_state = env.state
    print(f" Final: patch={final_state.patch_applied} test_rate={final_state.test_pass_rate} stability={final_state.system_stability:.2f} reward={final_state.total_reward:.3f}")


run_test("full_episode_loop", t9)
|
|
| |
# Summarise every recorded outcome at the end of the log.
passed = sum(1 for _, ok, _ in results if ok)
failed = len(results) - passed
print(f"\n{'='*50}")
print(f"RESULTS: {passed} passed, {failed} failed out of {len(results)} tests")
print(f"{'='*50}")

# sys.stdout and sys.stderr both point at the log file. Put the interpreter's
# original streams back before closing it, so later writes (atexit handlers,
# shutdown warnings) do not hit a closed file.
log_stream = sys.stdout
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
log_stream.close()

# The process exit status is not set here. The outcome is reported through
# this sidecar file: "0" if every test passed, "1" otherwise.
with open(output_file + '.exitcode', 'w', encoding='utf-8') as f:
    f.write(str(1 if failed > 0 else 0))
|
|