secureheal-trainer / tests /quick_test.py
Ravindra S
Initial commit: SecureHeal Arena Environment and Training Pipeline
366cf53
"""Quick test runner that writes output to a file for Windows compatibility."""
import sys, os
# Redirect output to file
output_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'test_results.log')
sys.stdout = open(output_file, 'w', encoding='utf-8')
sys.stderr = sys.stdout
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from secureheal_arena.sandbox import SandboxEngine
from secureheal_arena.vulnerabilities import SQL_INJECTION, XSS_STORED, PATH_TRAVERSAL, get_scenarios_for_level
from secureheal_arena.anomalies import MEMORY_SPIKE, select_anomaly
from secureheal_arena.rewards import compute_reward
from secureheal_arena.models import SecureHealState
import random
results = []
def run_test(name, fn):
try:
fn()
results.append((name, True, ""))
print(f"[PASS] {name}")
except Exception as e:
results.append((name, False, str(e)))
print(f"[FAIL] {name}: {e}")
import traceback
traceback.print_exc()
# Test 1: Sandbox basic
def t1():
engine = SandboxEngine()
result = engine.execute('x = 1 + 1\nprint(x)')
assert result.success, f"Sandbox failed: {result.error_message}"
assert '2' in result.stdout
run_test("sandbox_basic", t1)
# Test 2: Sandbox timeout (skip on Windows — tight loops hold the GIL)
def t2():
if sys.platform == 'win32':
print(" [SKIP] Timeout test skipped on Windows (GIL limitation)")
print(" Works correctly on Linux/HF Spaces via SIGALRM")
return
engine = SandboxEngine(timeout=1)
result = engine.execute('while True: pass')
assert result.timeout, "Timeout should have triggered"
run_test("sandbox_timeout", t2)
# Test 3: Sandbox forbidden ops
def t3():
engine = SandboxEngine()
for code in ['import os', 'import subprocess', 'from os import path', 'eval("1+1")', '__import__("os")']:
result = engine.execute(code)
assert result.forbidden_op, f"Should have blocked: {code}"
run_test("sandbox_forbidden_ops", t3)
# Test 4: SQL injection vulnerable fails, patched passes
def t4():
engine = SandboxEngine()
# Vulnerable should fail
r = engine.execute_tests(SQL_INJECTION.vulnerable_code, SQL_INJECTION.test_code)
vul_rate = float(r.return_value) if r.return_value is not None else 0.0
print(f" SQL injection vulnerable pass_rate = {vul_rate}")
# Patched should pass
r = engine.execute_tests(SQL_INJECTION.patched_code, SQL_INJECTION.test_code)
pat_rate = float(r.return_value) if r.return_value is not None else 0.0
print(f" SQL injection patched pass_rate = {pat_rate}")
assert pat_rate >= 0.9, f"Patched should pass, got {pat_rate}"
run_test("sql_injection_tests", t4)
# Test 5: XSS tests
def t5():
engine = SandboxEngine()
r = engine.execute_tests(XSS_STORED.patched_code, XSS_STORED.test_code)
rate = float(r.return_value) if r.return_value is not None else 0.0
print(f" XSS patched pass_rate = {rate}")
assert rate >= 0.9
run_test("xss_tests", t5)
# Test 6: Path traversal tests
def t6():
engine = SandboxEngine()
r = engine.execute_tests(PATH_TRAVERSAL.patched_code, PATH_TRAVERSAL.test_code)
rate = float(r.return_value) if r.return_value is not None else 0.0
print(f" Path traversal patched pass_rate = {rate}")
assert rate >= 0.9
run_test("path_traversal_tests", t6)
# Test 7: Curriculum levels
def t7():
l1 = get_scenarios_for_level(1)
l2 = get_scenarios_for_level(2)
l3 = get_scenarios_for_level(3)
assert len(l1) == 1
assert len(l2) == 2
assert len(l3) == 3
run_test("curriculum_levels", t7)
# Test 8: Reward computation
def t8():
# Bad state
state = SecureHealState(
episode_id="t", step_count=1,
vulnerability_present=True, exploit_possible=True,
patch_applied=False, test_pass_rate=0.0,
system_stability=0.3, cascading_failures=["x"],
)
b = compute_reward(state)
print(f" Bad state reward: {b.total:.3f}")
assert b.total < 0
# Good state
state.patch_applied = True
state.exploit_possible = False
state.test_pass_rate = 1.0
state.system_stability = 1.0
state.cascading_failures = []
b = compute_reward(state)
print(f" Good state reward: {b.total:.3f}")
assert b.total > 0.5
run_test("reward_computation", t8)
# Test 9: Full episode (needs openenv)
def t9():
try:
from secureheal_arena.server.secureheal_environment import SecureHealEnvironment
except ImportError as e:
print(f" Skipping full episode (openenv not installed): {e}")
return
env = SecureHealEnvironment(curriculum_level=1)
obs = env.reset(seed=42)
print(f" Reset done={obs.done}")
assert not obs.done
r = env._handle_scan_code()
print(f" scan_code: vuln_found={r.get('vulnerability_found')}")
assert r['vulnerability_found']
r = env._handle_simulate_attack()
print(f" simulate_attack: exploit={r.get('exploit_succeeded')}")
r = env._handle_apply_patch(SQL_INJECTION.patched_code)
print(f" apply_patch: applied={r.get('patch_applied')}")
assert r['patch_applied']
r = env._handle_run_tests()
print(f" run_tests: pass_rate={r.get('pass_rate')}")
assert r['pass_rate'] >= 0.9
r = env._handle_classify_issue("memory_spike")
print(f" classify_issue: correct={r.get('correct')}")
assert r['correct']
r = env._handle_restart_service("auth-service")
print(f" restart_service: {r.get('result')}")
r = env._handle_reallocate_resources()
print(f" reallocate: stability={r.get('stability')}")
s = env.state
print(f" Final: patch={s.patch_applied} test_rate={s.test_pass_rate} stability={s.system_stability:.2f} reward={s.total_reward:.3f}")
run_test("full_episode_loop", t9)
# Summary
passed = sum(1 for _, ok, _ in results if ok)
failed = sum(1 for _, ok, _ in results if not ok)
print(f"\n{'='*50}")
print(f"RESULTS: {passed} passed, {failed} failed out of {len(results)} tests")
print(f"{'='*50}")
sys.stdout.close()
# Also write exit code
with open(output_file + '.exitcode', 'w') as f:
f.write(str(1 if failed > 0 else 0))