|
|
import dataclasses |
|
|
import os |
|
|
import pathlib |
|
|
|
|
|
import pytest |
|
|
|
|
|
# Set JAX_PLATFORMS to "cpu" here, ahead of the openpi / train imports below.
# NOTE(review): presumably this must happen before anything imports JAX so the
# test runs on the CPU backend; keep this line above those imports.
os.environ["JAX_PLATFORMS"] = "cpu" |
|
|
|
|
|
from openpi.training import config as _config |
|
|
|
|
|
from . import train |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Run a tiny training job, then run it again with ``resume=True``.

    The named config is copied with small overrides (batch size 2, two
    training steps, logging every step) and its checkpoints are directed
    into pytest's per-test temporary directory. ``train.main`` is called
    once with that config, then a second time with ``resume=True`` and
    ``num_train_steps=4``.

    Args:
        tmp_path: Per-test temporary directory provided by pytest.
        config_name: Key of the training config to look up.
    """
    # The test reads the registry mapping directly to fetch the base config.
    base_config = _config._CONFIGS_DICT[config_name]

    first_run_overrides = {
        "batch_size": 2,
        "checkpoint_base_dir": str(tmp_path / "checkpoint"),
        "exp_name": "test",
        "overwrite": False,
        "resume": False,
        "num_train_steps": 2,
        "log_interval": 1,
    }
    first_run = dataclasses.replace(base_config, **first_run_overrides)
    train.main(first_run)

    # Second pass: same settings, but with resume enabled and a higher step
    # count (presumably continuing from the first run's checkpoint).
    resumed_run = dataclasses.replace(
        first_run,
        resume=True,
        num_train_steps=4,
    )
    train.main(resumed_run)
|
|
|