| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from pathlib import Path |
| from unittest.mock import Mock, patch |
|
|
| from lerobot.utils.constants import ( |
| CHECKPOINTS_DIR, |
| LAST_CHECKPOINT_LINK, |
| OPTIMIZER_PARAM_GROUPS, |
| OPTIMIZER_STATE, |
| RNG_STATE, |
| SCHEDULER_STATE, |
| TRAINING_STATE_DIR, |
| TRAINING_STEP, |
| ) |
| from lerobot.utils.train_utils import ( |
| get_step_checkpoint_dir, |
| get_step_identifier, |
| load_training_state, |
| load_training_step, |
| save_checkpoint, |
| save_training_state, |
| save_training_step, |
| update_last_checkpoint, |
| ) |
|
|
|
|
def test_get_step_identifier():
    """Step identifiers are zero-padded strings whose width depends on the total step count."""
    # (step, total_steps) -> expected identifier. The cases show a minimum
    # width of six digits that widens once total_steps needs more digits.
    cases = [
        ((5, 1000), "000005"),
        ((123, 100_000), "000123"),
        ((456789, 1_000_000), "0456789"),
    ]
    for (step, total_steps), expected in cases:
        assert get_step_identifier(step, total_steps) == expected
|
|
|
|
def test_get_step_checkpoint_dir():
    """A step's checkpoint dir is ``<output_dir>/CHECKPOINTS_DIR/<step identifier>``."""
    root = Path("/checkpoints")
    expected = root / CHECKPOINTS_DIR / "000005"
    assert get_step_checkpoint_dir(root, 1000, 5) == expected
|
|
|
|
def test_save_load_training_step(tmp_path):
    """Saving a step writes the TRAINING_STEP file and the value round-trips through load."""
    step = 5000
    save_training_step(step, tmp_path)
    assert (tmp_path / TRAINING_STEP).is_file()
    # The test name promises a save/load round trip, so check the stored value
    # as well, not just that a file was written.
    assert load_training_step(tmp_path) == step
|
|
|
|
def test_load_training_step(tmp_path):
    """``load_training_step`` returns the value previously written by ``save_training_step``."""
    expected_step = 5000
    save_training_step(expected_step, tmp_path)
    assert load_training_step(tmp_path) == expected_step
|
|
|
|
def test_update_last_checkpoint(tmp_path):
    """``update_last_checkpoint`` creates a LAST_CHECKPOINT_LINK symlink, in the checkpoint's parent, pointing at it."""
    checkpoint = tmp_path / "0005"
    checkpoint.mkdir()
    update_last_checkpoint(checkpoint)
    last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
    assert last_checkpoint.is_symlink()
    # Compare fully resolved paths on both sides: tmp_path may itself live
    # behind a symlink (e.g. /var -> /private/var on macOS). In that case
    # resolve() on the link yields the real path, which would not equal the
    # unresolved `checkpoint`.
    assert last_checkpoint.resolve() == checkpoint.resolve()
|
|
|
|
@patch("lerobot.utils.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
    """Non-PEFT checkpoint: the policy, the config and the training state are saved."""
    policy = Mock()
    cfg = Mock()
    # Mock auto-creates attributes and they are truthy, so without this the
    # PEFT branch (covered by test_save_checkpoint_peft) would run here too.
    cfg.use_peft = False
    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
    policy.save_pretrained.assert_called_once()
    cfg.save_pretrained.assert_called_once()
    # The separate policy-config save is expected only when use_peft is set.
    policy.config.save_pretrained.assert_not_called()
    mock_save_training_state.assert_called_once()
|
|
|
|
@patch("lerobot.utils.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
    """With ``cfg.use_peft`` set, saving also calls ``policy.config.save_pretrained``."""
    # Configure the mocks through constructor kwargs instead of attribute assignment.
    policy = Mock(config=Mock(save_pretrained=Mock()))
    cfg = Mock(use_peft=True)
    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
    expected_single_calls = (
        policy.save_pretrained,
        cfg.save_pretrained,
        policy.config.save_pretrained,
        mock_save_training_state,
    )
    for saver in expected_single_calls:
        saver.assert_called_once()
|
|
|
|
def test_save_training_state(tmp_path, optimizer, scheduler):
    """Saving training state creates TRAINING_STATE_DIR containing every expected file."""
    save_training_state(tmp_path, 10, optimizer, scheduler)
    state_dir = tmp_path / TRAINING_STATE_DIR
    assert state_dir.is_dir()
    expected_files = (
        TRAINING_STEP,
        RNG_STATE,
        OPTIMIZER_STATE,
        OPTIMIZER_PARAM_GROUPS,
        SCHEDULER_STATE,
    )
    for filename in expected_files:
        assert (state_dir / filename).is_file(), filename
|
|
|
|
def test_save_load_training_state(tmp_path, optimizer, scheduler):
    """Training state round-trips: the step is restored and the same objects come back."""
    save_training_state(tmp_path, 10, optimizer, scheduler)
    step, loaded_opt, loaded_sched = load_training_state(tmp_path, optimizer, scheduler)
    assert step == 10
    # load_training_state is expected to hand back the very instances it was given.
    assert loaded_opt is optimizer
    assert loaded_sched is scheduler
|
|