s-shah4
Add V1 env
4433dc8
from __future__ import annotations
from env.adapt_env import AdaptEnvironment
from models import AdaptAction, AdaptObservation
def assert_hidden_tests_are_not_exposed(payload: dict) -> None:
    """Fail if ``payload`` appears to contain hidden-test data.

    The check is a plain substring search over ``str(payload)`` for two
    markers: the key name ``"hidden_tests"`` and the text ``"-1000000000"``.

    Args:
        payload: The mapping to inspect, e.g. an observation's ``model_dump()``.

    Raises:
        AssertionError: If either marker occurs in ``str(payload)``; the
            message names the marker that was found.
    """
    text = str(payload)
    # NOTE(review): "-1000000000" presumably is a value that only occurs in
    # the hidden test cases -- confirm against the problem definition.
    for marker in ("hidden_tests", "-1000000000"):
        # Raise explicitly rather than using ``assert``: ``python -O`` strips
        # assert statements, which would silently disable this guard.
        if marker in text:
            raise AssertionError(
                f"hidden test data leaked: found {marker!r} in payload"
            )
def main() -> None:
    """Run an end-to-end smoke test of ``AdaptEnvironment``.

    Resets the environment and then submits five programs:

    - a correct solution;
    - a wrong one;
    - one with a syntax error;
    - an infinite loop;
    - one that imports ``os``.

    It checks the reward/status fields of each returned observation. It also
    checks every observation (reset and all five steps) for hidden-test
    leakage.

    Raises:
        AssertionError: If any check fails.
        SystemExit: If run with ``python -O``. Under ``-O`` every ``assert``
            below is stripped, so the script would report success without
            testing anything.
    """
    if not __debug__:
        raise SystemExit("ADAPT smoke tests must be run without python -O")

    env = AdaptEnvironment()

    observation = env.reset()
    assert isinstance(observation, AdaptObservation)
    assert observation.visible_tests
    assert observation.problem_id == "easy_double"
    assert_hidden_tests_are_not_exposed(observation.model_dump())

    # Correct solution for "easy_double": prints twice the input.
    correct = env.step(AdaptAction(code="n=int(input())\nprint(n*2)"))
    print(correct)
    assert correct.reward == 1.0, correct.model_dump()
    assert correct.pass_rate == 1.0

    # Wrong solution: n+2 equals n*2 only when n == 2, so the reward may be
    # partial but must stay below 1.0.
    wrong = env.step(AdaptAction(code="n=int(input())\nprint(n+2)"))
    print(wrong)
    assert 0.0 <= float(wrong.reward) < 1.0
    assert wrong.pass_rate < 1.0
    assert "Failed" in wrong.feedback

    syntax = env.step(AdaptAction(code="def broken(:\n pass"))
    print(syntax)
    assert syntax.reward == 0.0
    assert syntax.execution_status == "syntax_error"

    # Infinite loop: must be cut off by the sandbox timeout.
    timeout = env.step(AdaptAction(code="while True:\n pass"))
    print(timeout)
    assert timeout.timeout_count > 0
    assert timeout.execution_status == "timeout"

    # Filesystem access via ``os`` must be rejected as unsafe.
    unsafe = env.step(AdaptAction(code="import os\nprint(os.listdir('.'))"))
    print(unsafe)
    assert unsafe.reward == 0.0
    assert unsafe.execution_status == "safety_violation"

    # One step per submission above.
    assert env.state.step_count == 5

    # Check every step, not just one. Failure feedback (e.g. ``wrong``) is
    # where hidden test inputs are most likely to leak.
    for step_observation in (correct, wrong, syntax, timeout, unsafe):
        assert_hidden_tests_are_not_exposed(step_observation.model_dump())

    print("ADAPT OpenEnv smoke tests passed")
# Run the smoke tests only when executed as a script, not when imported.
if __name__ == "__main__":
    main()