| import os |
| import tempfile |
|
|
| from yarr.runners.weight_init_utils import resolve_initial_weight_state |
|
|
|
|
| def _make_weight_dir(root, step): |
| os.makedirs(os.path.join(root, str(step)), exist_ok=True) |
|
|
|
|
def test_anybimanual_prefers_existing_weights_when_resume_enabled():
    """With resume enabled, saved step directories take precedence over pretrained_dir.

    Step directories 0 and 200 are present, and the resolver is expected to
    resume from step 200 and to report ``"resume"`` as its load source.
    """
    with tempfile.TemporaryDirectory() as weightsdir, tempfile.TemporaryDirectory() as pretrained_dir:
        for step in (0, 200):
            _make_weight_dir(weightsdir, step)

        start_iter, load_dir, load_source = resolve_initial_weight_state(
            weightsdir=weightsdir,
            load_existing_weights=True,
            anybimanual=True,
            pretrained_dir=pretrained_dir,
        )

        expected_dir = os.path.join(weightsdir, "200")
        assert start_iter == 200
        assert load_dir == expected_dir
        assert load_source == "resume"
|
|
|
|
def test_anybimanual_falls_back_to_pretrained_without_existing_weights():
    """An empty weights directory makes the resolver fall back to pretrained_dir.

    Resume is requested, but no step directories exist, so the expected
    result is iteration 0 loaded from the pretrained directory.
    """
    with tempfile.TemporaryDirectory() as weightsdir, tempfile.TemporaryDirectory() as pretrained_dir:
        resolver_kwargs = {
            "weightsdir": weightsdir,
            "load_existing_weights": True,
            "anybimanual": True,
            "pretrained_dir": pretrained_dir,
        }
        start_iter, load_dir, load_source = resolve_initial_weight_state(**resolver_kwargs)

        assert start_iter == 0
        assert load_dir == pretrained_dir
        assert load_source == "pretrained"
|
|
|
|
def test_non_anybimanual_with_no_resume_starts_fresh():
    """Without resume and without AnyBimanual, existing step dirs are ignored.

    Step directories 0 and 200 are on disk. With ``load_existing_weights`` and
    ``anybimanual`` both off, the resolver should report a fresh start: no
    load directory and no load source.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        for step in (0, 200):
            _make_weight_dir(weightsdir, step)

        start_iter, load_dir, load_source = resolve_initial_weight_state(
            weightsdir=weightsdir,
            load_existing_weights=False,
            anybimanual=False,
        )

        assert start_iter == 0
        assert load_dir is None
        assert load_source is None
|
|
|
|
def test_anybimanual_raises_when_pretrained_dir_missing():
    """AnyBimanual mode must raise when its pretrained directory does not exist.

    The weights directory is left empty and ``load_existing_weights`` is off.
    ``pretrained_dir`` points at a path that is never created, so
    ``resolve_initial_weight_state`` is expected to raise
    ``FileNotFoundError`` instead of returning.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        missing_pretrained = os.path.join(weightsdir, "missing_pretrained")
        # Guard the precondition so the test cannot pass or fail for the wrong reason.
        assert not os.path.exists(missing_pretrained)

        try:
            resolve_initial_weight_state(
                weightsdir=weightsdir,
                load_existing_weights=False,
                anybimanual=True,
                pretrained_dir=missing_pretrained,
            )
        except FileNotFoundError:
            pass  # expected outcome
        else:
            # Raise explicitly (rather than `assert`) so the check survives `python -O`
            # and the failure message names the offending path.
            raise AssertionError(
                "resolve_initial_weight_state did not raise FileNotFoundError "
                f"for missing pretrained_dir {missing_pretrained!r}"
            )
|
|