File size: 1,678 Bytes
4433dc8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | from __future__ import annotations
from env.adapt_env import AdaptEnvironment
from models import AdaptAction, AdaptObservation
def assert_hidden_tests_are_not_exposed(
    payload: dict,
    forbidden_markers: tuple[str, ...] = ("hidden_tests", "-1000000000"),
) -> None:
    """Assert that an observation payload contains no hidden-test material.

    The payload is rendered with ``str()`` and each forbidden marker is
    searched for as a plain substring, so a marker anywhere in that
    rendering (including nested keys or values) is detected.

    Args:
        payload: Observation data to inspect, e.g. the dict returned by
            ``AdaptObservation.model_dump()``.
        forbidden_markers: Substrings that must not appear in the rendered
            payload. Defaults to the literal key name ``"hidden_tests"`` and
            ``"-1000000000"`` (presumably a value that only occurs in hidden
            test data -- NOTE(review): confirm against the problem set).

    Raises:
        AssertionError: If any marker appears; the message lists every
            marker that was found.
    """
    text = str(payload)
    leaked = [marker for marker in forbidden_markers if marker in text]
    # Report all leaked markers at once so a failure is actionable.
    assert not leaked, f"hidden test data exposed in payload: {leaked}"
def _step_and_check(env: AdaptEnvironment, code: str) -> AdaptObservation:
    """Submit *code* to *env*, print the observation, and check it for leaks.

    Routing every ``env.step`` result through
    ``assert_hidden_tests_are_not_exposed`` means feedback from failing
    submissions is checked for hidden-test leakage too, not just one step.
    """
    observation = env.step(AdaptAction(code=code))
    print(observation)
    assert_hidden_tests_are_not_exposed(observation.model_dump())
    return observation


def main() -> None:
    """Run the ADAPT environment smoke test end to end.

    Resets the environment, then submits five programs: a correct solution,
    a wrong answer, a syntax error, an infinite loop, and a program that
    uses ``os`` (expected to be rejected as a safety violation). It asserts
    the reward, pass rate, and execution status reported for each, and
    checks every observation for hidden-test leakage.

    Raises:
        AssertionError: On the first expectation that does not hold.
    """
    env = AdaptEnvironment()

    observation = env.reset()
    assert isinstance(observation, AdaptObservation)
    assert observation.visible_tests
    assert observation.problem_id == "easy_double"
    assert_hidden_tests_are_not_exposed(observation.model_dump())

    # A doubling program is expected to pass every test of "easy_double".
    correct = _step_and_check(env, "n=int(input())\nprint(n*2)")
    assert correct.reward == 1.0, correct.model_dump()
    assert correct.pass_rate == 1.0

    wrong = _step_and_check(env, "n=int(input())\nprint(n+2)")
    assert 0.0 <= float(wrong.reward) < 1.0
    assert wrong.pass_rate < 1.0
    assert "Failed" in wrong.feedback

    syntax = _step_and_check(env, "def broken(:\n pass")
    assert syntax.reward == 0.0
    assert syntax.execution_status == "syntax_error"

    timeout = _step_and_check(env, "while True:\n pass")
    assert timeout.timeout_count > 0
    assert timeout.execution_status == "timeout"

    unsafe = _step_and_check(env, "import os\nprint(os.listdir('.'))")
    assert unsafe.reward == 0.0
    assert unsafe.execution_status == "safety_violation"

    # One env.step per submission above.
    assert env.state.step_count == 5

    print("ADAPT OpenEnv smoke tests passed")
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|