File size: 3,746 Bytes
b1603b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import ast
import time

import pytest

from shutdown_gym.restricted_python import (
    EXEC_TIMEOUT_SECONDS,
    MAX_CODE_CHARS,
    RestrictedPythonError,
    exec_restricted,
    validate_ast,
)

# ---- AST validation: allowed ----


_ALLOWED_SNIPPETS = (
    # binary/comparison operators and container literals
    "_result = 1 + 2",
    "_result = (1 < 2)",
    "_result = [1, 2, 3]",
    "_result = {'a': 1}",
    "_result = (1, 2)",
    "_result = {1, 2, 3}",
    # for / while loops
    "x = 0\nfor i in range(5):\n    x = x + i\n_result = x",
    "i = 0\nwhile i < 3:\n    i = i + 1\n_result = i",
    # calls to plain builtin names
    "_result = str(42)",
    "_result = float(1)",
    "_result = sum(range(5))",
    "_result = len([1, 2, 3])",
    "_result = abs(-7)",
    "_result = max(1, 2, 3)",
    "_result = min(1, 2, 3)",
    "_result = sorted([3, 1, 2])",
    # unary operators
    "_result = -5",
    "_result = not False",
)


@pytest.mark.parametrize("code", _ALLOWED_SNIPPETS)
def test_validate_ast_allowed(code):
    """Parse ``code`` in exec mode and check ``validate_ast`` does not raise.

    The test passes as long as no exception escapes; there is no explicit
    assertion on a return value.
    """
    validate_ast(ast.parse(code, mode="exec"))


# ---- AST validation: forbidden node types ----


# Each case pairs a snippet with the text expected inside the error message.
_FORBIDDEN_NODE_CASES = (
    ("import os", "Import"),
    ("from os import path", "ImportFrom"),
    ("a = b.c", "Attribute"),
    ("lambda x: x", "Lambda"),
    ("def f():\n    pass", "FunctionDef"),
    ("class X:\n    pass", "ClassDef"),
    ("[x for x in range(5)]", "ListComp"),
    ("try:\n    pass\nexcept Exception:\n    pass", "Try"),
    ("with open('x') as f:\n    pass", "With"),
    ("yield 1", "Yield"),
)


@pytest.mark.parametrize(("code", "fragment"), _FORBIDDEN_NODE_CASES)
def test_validate_ast_disallowed_node(code, fragment):
    """Check ``validate_ast`` rejects ``code`` and names the offending node.

    ``validate_ast`` must raise ``RestrictedPythonError`` for the parsed
    snippet, and ``fragment`` must appear somewhere in the error text.
    """
    parsed = ast.parse(code, mode="exec")
    with pytest.raises(RestrictedPythonError) as caught:
        validate_ast(parsed)
    message = str(caught.value)
    assert fragment in message


# ---- AST validation: forbidden builtin/free-name calls ----


_FORBIDDEN_CALL_SNIPPETS = (
    "open('/tmp/x')",
    "exec('print(1)')",
    "eval('1+1')",
    "getattr({}, 'keys')",
    "__import__('os')",
    "print('hi')",
)


@pytest.mark.parametrize("code", _FORBIDDEN_CALL_SNIPPETS)
def test_validate_ast_disallowed_call(code):
    """Check ``validate_ast`` rejects a call to each listed name.

    The raised ``RestrictedPythonError`` must contain the text
    ``Disallowed call``.
    """
    parsed = ast.parse(code, mode="exec")
    # ``match`` performs a regex search over str(exception); the pattern has
    # no regex metacharacters, so this is a plain substring check.
    with pytest.raises(RestrictedPythonError, match="Disallowed call"):
        validate_ast(parsed)


# ---- end-to-end via exec_restricted (subprocess) ----


def test_exec_returns_result_value():
    """Assigning ``_result`` makes ``exec_restricted`` return its string form."""
    output = exec_restricted("_result = 1 + 2")
    assert output == "3"


def test_exec_default_result_when_underscore_result_unset():
    """Code that never assigns ``_result`` yields the string ``"OK"``."""
    output = exec_restricted("x = 1")
    assert output == "OK"


def test_exec_oversized_code_rejected_without_subprocess():
    """Over-long source is answered with a message citing the size limit.

    The reply must mention ``exceeds`` and the numeric ``MAX_CODE_CHARS``.

    NOTE(review): the "without subprocess" part of the name is not asserted
    here; only the returned text is checked.
    """
    statement = "x = 1\n"  # 6 characters, so the input is 6x the limit
    oversized = statement * MAX_CODE_CHARS
    reply = exec_restricted(oversized)
    for expected in ("exceeds", str(MAX_CODE_CHARS)):
        assert expected in reply


def test_exec_syntax_error_returned_as_error():
    """Unparseable source comes back as a string starting with ``Error:``."""
    assert exec_restricted("1 +").startswith("Error:")


def test_exec_disallowed_node_returned_as_error():
    """An ``import`` statement is reported in the returned text as ``Disallowed``."""
    assert "Disallowed" in exec_restricted("import os")


def test_exec_disallowed_attribute_blocked_end_to_end():
    """Attribute access is refused through ``exec_restricted`` itself.

    This matters because attribute lookups are the easiest route back to
    the parent process, e.g. via ``__builtins__.__import__``.
    """
    snippet = "x = 1\n_result = x.real"
    reply = exec_restricted(snippet)
    assert "Disallowed" in reply


def test_exec_timeout_fires_on_busy_loop():
    """An infinite ``while True`` loop is cut off by the execution timeout.

    The call must return a "timed out" message rather than hanging the
    test process, and must do so within a bounded wall-clock time.
    """
    busy_loop = "while True:\n    x = 1"
    began = time.monotonic()
    reply = exec_restricted(busy_loop)
    duration = time.monotonic() - began
    assert "timed out" in reply
    # Spawn-context process startup is slow on macOS, so leave generous
    # headroom above the configured timeout.
    ceiling = EXEC_TIMEOUT_SECONDS + 5.0
    assert duration < ceiling