"""Smoke test for the training entry point: a short fresh run, then a resumed run."""

import dataclasses
import os
import pathlib

import pytest

# Force JAX onto the CPU backend. This assignment deliberately precedes the
# project imports below, presumably so it is in effect before anything they
# pull in initializes JAX — TODO confirm which import triggers that.
os.environ["JAX_PLATFORMS"] = "cpu"

from openpi.training import config as _config  # noqa: E402

from . import train  # noqa: E402


@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Train the named config for two steps, then run it again with resume enabled.

    The first run uses a batch size of 2 and logs every step. It writes under
    ``tmp_path / "checkpoint"`` with experiment name ``"test"``, and sets
    ``overwrite=False`` and ``resume=False``. The second run reuses that same
    config but sets ``resume=True`` and raises ``num_train_steps`` to 4.
    Presumably that lets it pick up the state saved by the first run; this is
    not asserted here. The test only checks that neither call raises.

    Args:
        tmp_path: pytest-provided temporary directory that holds the checkpoints.
        config_name: key into the registered training configs.
    """
    registered = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    initial_run = dataclasses.replace(
        registered,
        batch_size=2,
        checkpoint_base_dir=str(tmp_path / "checkpoint"),
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(initial_run)

    # Same checkpoint dir and experiment name as above; only resume/step count change.
    resumed_run = dataclasses.replace(initial_run, resume=True, num_train_steps=4)
    train.main(resumed_run)