| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import tempfile |
| from pathlib import Path |
|
|
| import multistorageclient as msc |
|
|
| from nemo.lightning.resume import AutoResume |
|
|
|
|
def test_auto_resume_get_weights_path():
    """``get_weights_path`` appends a ``weights`` component to the given path.

    Checked for both a local ``pathlib.Path`` and an ``msc.Path`` URL.
    """
    resumer = AutoResume()

    # Local filesystem path.
    local_root = Path("test/checkpoints")
    expected_local = Path("test/checkpoints/weights")
    assert resumer.get_weights_path(local_root) == expected_local

    # Multi-storage-client URL path; the result is compared against an msc.Path.
    msc_root = msc.Path("msc://default/tmp/test/checkpoints")
    expected_msc = msc.Path("msc://default/tmp/test/checkpoints/weights")
    assert resumer.get_weights_path(msc_root) == expected_msc
|
|
|
|
def test_auto_resume_get_context_path():
    """Exercise ``AutoResume.get_context_path`` across several resume settings.

    Scenarios:
      1. resuming disabled -> ``None``;
      2. resuming enabled before any ``resume_from_directory`` is assigned -> ``None``;
      3. ``resume_from_directory`` set to a local directory holding one checkpoint
         -> that checkpoint's ``context`` subdirectory;
      4. the same on-disk layout addressed through an ``msc://`` URL -> an
         ``msc.Path`` whose string form is the local ``context`` directory.
    """
    auto_resume = AutoResume()
    last_ckpt = ("checkpoints", "step=10-epoch=0-last")

    def make_last_checkpoint(root):
        # Create <root>/checkpoints/step=10-epoch=0-last/{weights,context}.
        # NOTE(review): the "-last" suffix is presumably what lets AutoResume pick
        # this directory when scanning resume_from_directory — confirm.
        for leaf in ("weights", "context"):
            os.makedirs(os.path.join(root, *last_ckpt, leaf))

    # 1. Resuming disabled.
    auto_resume.resume_if_exists = False
    assert auto_resume.get_context_path() is None

    # 2. Resuming enabled, but no resume_from_directory has been assigned yet.
    auto_resume.resume_if_exists = True
    assert auto_resume.get_context_path() is None

    # 3. Local checkpoint directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        make_last_checkpoint(tmpdir)
        auto_resume.resume_from_directory = os.path.join(tmpdir, "checkpoints")
        context_path = auto_resume.get_context_path()
        assert str(context_path) == os.path.join(tmpdir, *last_ckpt, "context")

    # 4. Same layout, addressed via an msc:// URL rooted at the temp directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        make_last_checkpoint(tmpdir)
        auto_resume.resume_from_directory = f"msc://default{tmpdir}/checkpoints"
        context_path = auto_resume.get_context_path()
        assert isinstance(context_path, msc.Path)
        assert str(context_path) == os.path.join(tmpdir, *last_ckpt, "context")
|
|
|
|
def test_auto_resume_get_trainer_ckpt_path():
    """Exercise ``AutoResume.get_trainer_ckpt_path`` across several resume settings.

    Scenarios:
      1. resuming disabled -> ``None``;
      2. resuming enabled before any ``resume_from_path`` is assigned -> ``None``;
      3. ``resume_from_path`` set to a local checkpoint directory -> that
         checkpoint's ``weights`` subdirectory;
      4. the same on-disk layout addressed through an ``msc://`` URL -> an
         ``msc.Path`` whose string form is the local ``weights`` directory.
    """
    auto_resume = AutoResume()
    last_ckpt = ("checkpoints", "step=10-epoch=0-last")

    def make_last_checkpoint(root):
        # Create <root>/checkpoints/step=10-epoch=0-last/{weights,context}.
        for leaf in ("weights", "context"):
            os.makedirs(os.path.join(root, *last_ckpt, leaf))

    # 1. Resuming disabled.
    auto_resume.resume_if_exists = False
    assert auto_resume.get_trainer_ckpt_path() is None

    # 2. Resuming enabled, but no resume_from_path has been assigned yet.
    auto_resume.resume_if_exists = True
    assert auto_resume.get_trainer_ckpt_path() is None

    # 3. resume_from_path points directly at a local checkpoint directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        make_last_checkpoint(tmpdir)
        auto_resume.resume_from_path = os.path.join(tmpdir, *last_ckpt)
        ckpt_path = auto_resume.get_trainer_ckpt_path()
        assert str(ckpt_path) == os.path.join(tmpdir, *last_ckpt, "weights")

    # 4. Same checkpoint, addressed via an msc:// URL rooted at the temp directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        make_last_checkpoint(tmpdir)
        auto_resume.resume_from_path = f"msc://default{tmpdir}/checkpoints/step=10-epoch=0-last"
        ckpt_path = auto_resume.get_trainer_ckpt_path()
        assert isinstance(ckpt_path, msc.Path)
        assert str(ckpt_path) == os.path.join(tmpdir, *last_ckpt, "weights")
|
|