# File: GR00T/tests/examples/test_droid.py
# Last commit: b88b79e (verified) by yqi19 - "add: source files (batch 4)"
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
import pathlib
import subprocess
import pytest
from test_support.readme import extract_code_blocks, find_block, replace_once, run_bash_blocks
from test_support.runtime import (
assert_port_available,
get_root,
start_server_process,
timed,
wait_for_server_ready,
)
# Module-level logger named after this test module.
logger = logging.getLogger(__name__)
# Root directory returned by test_support.runtime.get_root(); used below as the
# working directory for every command and as the base of the README path.
REPO_ROOT = get_root()
# Step count substituted into the README's MAX_STEPS and SAVE_STEPS settings,
# and therefore also the numeric suffix of the checkpoint directory below.
TRAINING_STEPS = 2
# The DROID example README whose bash code blocks these tests extract and run.
README = REPO_ROOT / "examples/DROID/README.md"
# Checkpoint directory the finetune test asserts exists after training; it
# lives under the same /tmp/droid_finetune path used to locate the README's
# finetune block.
MODEL_CHECKPOINT = pathlib.Path(f"/tmp/droid_finetune/checkpoint-{TRAINING_STEPS}")
# Fallback (in seconds) for the DROID_SERVER_STARTUP_SECONDS environment
# variable; passed as timeout_s to wait_for_server_ready.
DEFAULT_SERVER_STARTUP_SECONDS = 900.0
@pytest.mark.gpu
@pytest.mark.timeout(1800)
@pytest.mark.parametrize(
    "occurrence",
    [1, 2],
    ids=["base", "finetuned"],
)
def test_droid_readme_server_starts(occurrence: int) -> None:
    """Verify the DROID inference server starts and accepts connections.

    The command comes from the README bash block that ``find_block`` selects
    for ``run_gr00t_server.py``; explicit ``--device``, ``--host`` and
    ``--port`` flags are appended before it is launched.

    Args:
        occurrence: Forwarded unchanged to ``find_block`` to pick which
            matching README block is launched. The parametrize ids label
            ``1`` as "base" and ``2`` as "finetuned".
    """
    env = {**os.environ}
    blocks = extract_code_blocks(README)

    model_server_host = "127.0.0.1"
    model_server_port = 5557
    server_code = find_block(
        blocks, "run_gr00t_server.py", language="bash", occurrence=occurrence
    ).code
    # Strip trailing whitespace before appending: a trailing newline in the
    # README block would otherwise leave the flags on a line of their own
    # instead of on the server command.
    server_code = (
        server_code.rstrip()
        + f" --device cuda:0 --host {model_server_host} --port {model_server_port}"
    )

    # Check the port before launching the server on it.
    assert_port_available(model_server_host, model_server_port)
    model_server_proc, server_log = start_server_process(server_code, cwd=REPO_ROOT, env=env)
    try:
        # `or` covers both an unset variable and one that is set but empty
        # (float("") would raise ValueError).
        startup_timeout_s = float(
            os.getenv("DROID_SERVER_STARTUP_SECONDS") or DEFAULT_SERVER_STARTUP_SECONDS
        )
        wait_for_server_ready(
            proc=model_server_proc,
            host=model_server_host,
            port=model_server_port,
            timeout_s=startup_timeout_s,
            server_log=server_log,
        )
    finally:
        # Always stop the server: ask politely first, then force-kill if it
        # has not exited within 15 seconds.
        if model_server_proc.poll() is None:
            model_server_proc.terminate()
            try:
                model_server_proc.wait(timeout=15)
            except subprocess.TimeoutExpired:
                model_server_proc.kill()
                model_server_proc.wait(timeout=15)
@pytest.mark.gpu
@pytest.mark.timeout(1800)
def test_droid_finetune_and_finetuned_server() -> None:
    """Run a short DROID finetune, then verify server starts with the finetuned checkpoint.

    Steps:
        1. Take the README finetune block (the one containing
           ``--output-dir /tmp/droid_finetune``) and shrink its settings to a
           single GPU, ``TRAINING_STEPS`` steps and a global batch size of 2.
        2. Run it and assert that ``MODEL_CHECKPOINT`` exists afterwards.
        3. Launch the README server block with the published model id swapped
           for that checkpoint and wait until the server is ready.
    """
    env = {**os.environ}
    blocks = extract_code_blocks(README)

    finetune_code = find_block(blocks, "--output-dir /tmp/droid_finetune", language="bash").code
    # Each README setting is replaced exactly once, in this order, with a
    # value small enough for a smoke test.
    for readme_setting, test_setting in (
        ("NUM_GPUS=8", "NUM_GPUS=1"),
        ("MAX_STEPS=20000", f"MAX_STEPS={TRAINING_STEPS}"),
        ("SAVE_STEPS=1000", f"SAVE_STEPS={TRAINING_STEPS}"),
        ("GLOBAL_BATCH_SIZE=640", "GLOBAL_BATCH_SIZE=2"),
    ):
        finetune_code = replace_once(finetune_code, readme_setting, test_setting)
    # NOTE(review): presumably forwards --skip_weight_loading to the training
    # entry point so no pretrained weights are loaded - confirm against the
    # README script.
    finetune_code = finetune_code.rstrip() + " -- --skip_weight_loading"

    run_bash_blocks(
        [finetune_code],
        cwd=REPO_ROOT,
        # Overrides layered on top of the inherited environment; the names
        # suggest they disable W&B and minimise dataloader/shard work.
        env={
            **env,
            "USE_WANDB": "0",
            "DATALOADER_NUM_WORKERS": "0",
            "SHARD_SIZE": "64",
            "NUM_SHARDS_PER_EPOCH": "1",
        },
    )
    assert MODEL_CHECKPOINT.exists(), (
        f"Expected model checkpoint after finetune: {MODEL_CHECKPOINT}"
    )

    model_server_host = "127.0.0.1"
    model_server_port = 5558
    # Point the README's server command at the freshly written checkpoint
    # instead of the published model id.
    server_code = replace_once(
        find_block(blocks, "nvidia/GR00T-N1.7-DROID", language="bash").code,
        "nvidia/GR00T-N1.7-DROID",
        str(MODEL_CHECKPOINT),
    )
    # Same trailing-whitespace handling as finetune_code above: a trailing
    # newline would otherwise leave the flags on a line of their own.
    server_code = (
        server_code.rstrip()
        + f" --device cuda:0 --host {model_server_host} --port {model_server_port}"
    )

    # Check the port before launching the server on it.
    assert_port_available(model_server_host, model_server_port)
    model_server_proc, server_log = start_server_process(server_code, cwd=REPO_ROOT, env=env)
    try:
        # `or` covers both an unset variable and one that is set but empty
        # (float("") would raise ValueError).
        startup_timeout_s = float(
            os.getenv("DROID_SERVER_STARTUP_SECONDS") or DEFAULT_SERVER_STARTUP_SECONDS
        )
        with timed("finetuned server startup"):
            wait_for_server_ready(
                proc=model_server_proc,
                host=model_server_host,
                port=model_server_port,
                timeout_s=startup_timeout_s,
                server_log=server_log,
            )
    finally:
        # Always stop the server: ask politely first, then force-kill if it
        # has not exited within 15 seconds.
        if model_server_proc.poll() is None:
            model_server_proc.terminate()
            try:
                model_server_proc.wait(timeout=15)
            except subprocess.TimeoutExpired:
                model_server_proc.kill()
                model_server_proc.wait(timeout=15)