GR00T / tests /examples /test_simplerenv.py
yqi19's picture
add: source files (batch 4)
b88b79e verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
import platform
import subprocess
import pytest
from test_support.readme import extract_code_blocks, find_block, replace_once, run_bash_blocks
from test_support.runtime import (
DEFAULT_SERVER_STARTUP_SECONDS,
assert_port_available,
get_root,
has_rt_core_gpu,
run_subprocess_step,
start_server_process,
timed,
wait_for_server_ready,
)
logger = logging.getLogger(__name__)
REPO_ROOT = get_root()
README = REPO_ROOT / "examples/SimplerEnv/README.md"
# sapien==2.2.2 (required by ManiSkill2_real2sim) ships x86_64 wheels only.
pytestmark = pytest.mark.skipif(
platform.machine() != "x86_64",
reason="SimplerEnv depends on sapien which has no aarch64 wheels",
)
def _run_simplerenv_eval(
env: dict,
blocks: list,
server_model_key: str,
client_env_name_old: str,
client_env_name_new: str,
server_startup_env_var: str,
) -> None:
"""Shared helper: setup sim, start server, run rollout, assert results."""
# Step 1: Setup sim (shared across both benchmarks)
with timed("step 1: sim venv setup (setup_SimplerEnv.sh)"):
run_bash_blocks(
[find_block(blocks, "setup_SimplerEnv.sh", language="bash")],
cwd=REPO_ROOT,
env=env,
)
model_server_host = "127.0.0.1"
model_server_port = 5559
# Step 2: Server — inject test-specific flags
server_code = find_block(blocks, server_model_key, language="bash").code
server_code += f" --device cuda:0 --host {model_server_host} --port {model_server_port}"
# Step 3: Rollout — substitute test-safe values
rollout_code = replace_once(
replace_once(
replace_once(
replace_once(
replace_once(
find_block(blocks, client_env_name_old, language="bash").code,
"--n-episodes 10",
"--n-episodes 1",
),
"--policy-client-port 5555",
f"--policy-client-port {model_server_port}",
),
"--max-episode-steps 300",
"--max-episode-steps 2",
),
"--n-envs 5",
"--n-envs 1",
),
client_env_name_old,
client_env_name_new,
)
assert_port_available(model_server_host, model_server_port)
model_server_proc, server_log = start_server_process(server_code, cwd=REPO_ROOT, env=env)
with timed("step 2: server startup"):
wait_for_server_ready(
proc=model_server_proc,
host=model_server_host,
port=model_server_port,
timeout_s=float(os.getenv(server_startup_env_var, str(DEFAULT_SERVER_STARTUP_SECONDS))),
server_log=server_log,
)
try:
if has_rt_core_gpu():
with timed("step 3: rollout"):
simulation_result, _ = run_subprocess_step(
["bash", "-c", rollout_code],
step="simplerenv_rollout",
cwd=REPO_ROOT,
env=env,
log_prefix="simplerenv",
failure_prefix="SimplerEnv rollout failed",
output_tail_chars=4000,
)
simulation_output = (simulation_result.stdout or "") + (simulation_result.stderr or "")
assert "results:" in simulation_output, (
"Simulation output did not include expected 'results:' marker.\n"
f"output_tail=\n{simulation_output[-4000:]}"
)
assert "success rate:" in simulation_output, (
"Simulation output did not include expected 'success rate:' marker.\n"
f"output_tail=\n{simulation_output[-4000:]}"
)
finally:
if model_server_proc.poll() is None:
model_server_proc.terminate()
try:
model_server_proc.wait(timeout=15)
except subprocess.TimeoutExpired:
model_server_proc.kill()
model_server_proc.wait(timeout=15)
@pytest.mark.gpu
@pytest.mark.timeout(900)
def test_simplerenv_fractal_readme_eval_flow() -> None:
"""Run the SimplerEnv README server+client eval using the remote fractal (Google robot) checkpoint."""
env = {**os.environ}
blocks = extract_code_blocks(README)
_run_simplerenv_eval(
env=env,
blocks=blocks,
server_model_key="nvidia/GR00T-N1.7-SimplerEnv-Fractal",
client_env_name_old="simpler_env_google/google_robot_pick_coke_can",
client_env_name_new="simpler_env_google/google_robot_pick_coke_can",
server_startup_env_var="SIMPLERENV_SERVER_STARTUP_SECONDS",
)
@pytest.mark.gpu
@pytest.mark.timeout(900)
def test_simplerenv_bridge_readme_eval_flow() -> None:
"""Run the SimplerEnv README server+client eval using the remote bridge (WidowX robot) checkpoint."""
env = {**os.environ}
blocks = extract_code_blocks(README)
_run_simplerenv_eval(
env=env,
blocks=blocks,
server_model_key="nvidia/GR00T-N1.7-SimplerEnv-Bridge",
client_env_name_old="simpler_env_widowx/widowx_spoon_on_towel",
client_env_name_new="simpler_env_widowx/widowx_spoon_on_towel",
server_startup_env_var="SIMPLERENV_SERVER_STARTUP_SECONDS",
)