File size: 1,678 Bytes
4433dc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from __future__ import annotations

from env.adapt_env import AdaptEnvironment
from models import AdaptAction, AdaptObservation


def assert_hidden_tests_are_not_exposed(payload: dict) -> None:
    """Fail if *payload* contains any trace of the hidden tests.

    The payload is stringified and searched for two markers: the literal
    ``"hidden_tests"`` and the value ``-1000000000``.
    NOTE(review): ``-1000000000`` is presumably a hidden-test input of the
    default problem; confirm against the problem bank.

    Args:
        payload: An observation dump (e.g. ``model_dump()`` output).

    Raises:
        AssertionError: Naming the first marker found in the payload.
    """
    text = str(payload)
    for marker in ("hidden_tests", "-1000000000"):
        # Name the marker so a failure pinpoints what leaked.
        assert marker not in text, (
            f"hidden-test marker {marker!r} leaked into agent-visible payload"
        )


def main() -> None:
    """Run an end-to-end smoke test of ``AdaptEnvironment``.

    Resets the environment, then submits five programs, one per grading
    outcome:
    - a correct solution,
    - a wrong answer,
    - a syntax error,
    - an infinite loop (timeout),
    - a forbidden ``os`` import (safety violation).

    For each one it asserts on the reward, pass rate and/or execution
    status reported. Every observation returned by ``reset``/``step`` is
    also checked for leaked hidden-test data.

    Prints each step observation and, on success, a final summary line.

    Raises:
        AssertionError: If any check fails.
    """
    env = AdaptEnvironment()

    observation = env.reset()
    assert isinstance(observation, AdaptObservation), type(observation)
    assert observation.visible_tests, observation.model_dump()
    assert observation.problem_id == "easy_double", observation.problem_id
    assert_hidden_tests_are_not_exposed(observation.model_dump())

    # A correct doubling program is expected to earn full reward.
    correct = env.step(AdaptAction(code="n=int(input())\nprint(n*2)"))
    print(correct)
    assert correct.reward == 1.0, correct.model_dump()
    assert correct.pass_rate == 1.0, correct.model_dump()

    # n+2 agrees with n*2 only at n == 2, so partial credit is possible but
    # never a full pass.
    wrong = env.step(AdaptAction(code="n=int(input())\nprint(n+2)"))
    print(wrong)
    assert 0.0 <= float(wrong.reward) < 1.0, wrong.model_dump()
    assert wrong.pass_rate < 1.0, wrong.model_dump()
    assert "Failed" in wrong.feedback, wrong.feedback

    syntax = env.step(AdaptAction(code="def broken(:\n    pass"))
    print(syntax)
    assert syntax.reward == 0.0, syntax.model_dump()
    assert syntax.execution_status == "syntax_error", syntax.model_dump()

    timeout = env.step(AdaptAction(code="while True:\n    pass"))
    print(timeout)
    assert timeout.timeout_count > 0, timeout.model_dump()
    assert timeout.execution_status == "timeout", timeout.model_dump()

    unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
    print(unsafe)
    assert unsafe.reward == 0.0, unsafe.model_dump()
    assert unsafe.execution_status == "safety_violation", unsafe.model_dump()

    # Five step() calls above.
    assert env.state.step_count == 5, env.state.step_count

    # Every observation the agent sees must be free of hidden-test data.
    # Failure feedback (``wrong``) is the likeliest place to echo test
    # inputs, so no step result is exempt.
    for step_result in (correct, wrong, syntax, timeout, unsafe):
        assert_hidden_tests_are_not_exposed(step_result.model_dump())

    print("ADAPT OpenEnv smoke tests passed")


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()