|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
|
import subprocess |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import pytest |
|
|
|
|
|
from tests.fixtures.constants import DUMMY_REPO_ID |
|
|
from tests.utils import require_package |
|
|
|
|
|
|
|
|
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: |
|
|
for f, r in finds_and_replaces: |
|
|
assert f in text |
|
|
text = text.replace(f, r) |
|
|
return text |
|
|
|
|
|
|
|
|
|
|
|
def _run_script(path): |
|
|
subprocess.run([sys.executable, path], check=True) |
|
|
|
|
|
|
|
|
def _read_file(path): |
|
|
with open(path) as file: |
|
|
return file.read() |
|
|
|
|
|
|
|
|
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
def test_example_1(tmp_path, lerobot_dataset_factory):
    """Run ``examples/1_load_lerobot_dataset.py`` against a dummy dataset in ``tmp_path``.

    The example source is patched to use ``DUMMY_REPO_ID`` and to pass
    ``root=tmp_path`` to ``LeRobotDataset``. It is then exec'd in-process, and
    the test checks that the expected episode video exists under ``outputs/``
    (relative to the current working directory).
    """
    # Presumably builds a dummy dataset under tmp_path; only the side effect is used.
    _ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)

    path = "examples/1_load_lerobot_dataset.py"
    file_contents = _read_file(path)
    file_contents = _find_and_replace(
        file_contents,
        [
            ('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
            (
                "LeRobotDataset(repo_id",
                # !r emits a correctly escaped Python string literal, so the
                # patched source stays valid for paths containing backslashes
                # (e.g. Windows) or quotes.
                f"LeRobotDataset(repo_id, root={str(tmp_path)!r}",
            ),
        ],
    )
    # Fresh globals so the example cannot see (or clobber) this module's names.
    exec(file_contents, {})

    assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
|
|
|
|
|
|
|
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
    """
    Train a model with example 3, check the outputs.
    Evaluate the trained model with example 2, check the outputs.
    Calculate the validation loss with advanced example 2
    (``examples/advanced/2_calculate_validation_loss.py``), check the outputs.
    """
    from contextlib import redirect_stdout

    # Patch pairs shared by the scripts that should load the policy trained in
    # step 1 from outputs/ instead of calling snapshot_download.
    use_local_policy = [
        (
            'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
            "",
        ),
        (
            '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
        ),
    ]
    use_cpu = ('device = torch.device("cuda")', 'device = torch.device("cpu")')

    # 1) Train for a single tiny step on CPU and check the checkpoint files exist.
    file_contents = _read_file("examples/3_train_policy.py")
    file_contents = _find_and_replace(
        file_contents,
        [
            ("training_steps = 5000", "training_steps = 1"),
            ("num_workers=4", "num_workers=0"),
            use_cpu,
            ("batch_size=64", "batch_size=1"),
        ],
    )
    exec(file_contents, {})
    for file_name in ["model.safetensors", "config.json"]:
        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

    # 2) Evaluate the locally trained policy and check the rollout video exists.
    # Replacement order matters: _find_and_replace applies pairs sequentially.
    file_contents = _read_file("examples/2_evaluate_pretrained_policy.py")
    file_contents = _find_and_replace(
        file_contents,
        [
            *use_local_policy,
            use_cpu,
            # Presumably exits the example's rollout loop after one iteration.
            ("step += 1", "break"),
        ],
    )
    exec(file_contents, {})
    assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()

    # 3) Compute the validation loss and check it is reported on stdout.
    file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")
    file_contents = _find_and_replace(
        file_contents,
        [
            *use_local_policy,
            ("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
            ("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
            ("num_workers=4", "num_workers=0"),
            use_cpu,
            ("batch_size=64", "batch_size=1"),
        ],
    )
    output_buffer = io.StringIO()
    # redirect_stdout restores the *previous* sys.stdout (e.g. pytest's capture
    # stream, not sys.__stdout__) and does so even if the script raises.
    with redirect_stdout(output_buffer):
        exec(file_contents, {})
    printed_output = output_buffer.getvalue()
    assert "Average loss on validation set" in printed_output
|
|
|