| import dataclasses | |
| import os | |
| import pathlib | |
| import pytest | |
| os.environ["JAX_PLATFORMS"] = "cpu" | |
| from openpi.training import config as _config | |
| from . import train | |
| def test_train(tmp_path: pathlib.Path, config_name: str): | |
| config = dataclasses.replace( | |
| _config._CONFIGS_DICT[config_name], # noqa: SLF001 | |
| batch_size=2, | |
| checkpoint_base_dir=str(tmp_path / "checkpoint"), | |
| exp_name="test", | |
| overwrite=False, | |
| resume=False, | |
| num_train_steps=2, | |
| log_interval=1, | |
| ) | |
| train.main(config) | |
| # test resuming | |
| config = dataclasses.replace(config, resume=True, num_train_steps=4) | |
| train.main(config) | |