# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_anybimanual_resume_logic.py
# Author: lsnu
# Commit message: "Add files using upload-large-folder tool"
# Commit: b14c4b7 (verified)
import os
import tempfile
from yarr.runners.weight_init_utils import resolve_initial_weight_state
def _make_weight_dir(root, step):
    """Create an empty ``<root>/<step>`` directory to stand in for saved weights.

    The subdirectory is named ``str(step)``. Calling this again for a step
    that already exists is a no-op because ``exist_ok=True`` is passed.
    """
    step_dir = os.path.join(root, str(step))
    os.makedirs(step_dir, exist_ok=True)
def test_anybimanual_prefers_existing_weights_when_resume_enabled():
    """With resume enabled, existing step directories win over pretrained weights.

    Step directories 0 and 200 are created under ``weightsdir``. The resolver
    is expected to resume from the highest step (200) instead of loading
    from ``pretrained_dir``.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        with tempfile.TemporaryDirectory() as pretrained_dir:
            for step in (0, 200):
                _make_weight_dir(weightsdir, step)

            first_iter, source_dir, origin = resolve_initial_weight_state(
                weightsdir=weightsdir,
                load_existing_weights=True,
                anybimanual=True,
                pretrained_dir=pretrained_dir,
            )

            expected_dir = os.path.join(weightsdir, "200")
            assert first_iter == 200
            assert source_dir == expected_dir
            assert origin == "resume"
def test_anybimanual_falls_back_to_pretrained_without_existing_weights():
    """With resume enabled but no step directories, use ``pretrained_dir``.

    ``weightsdir`` is left empty. The resolver is expected to start at
    iteration 0 and point at the pretrained directory as its load source.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        with tempfile.TemporaryDirectory() as pretrained_dir:
            outcome = resolve_initial_weight_state(
                weightsdir=weightsdir,
                load_existing_weights=True,
                anybimanual=True,
                pretrained_dir=pretrained_dir,
            )
            first_iter, source_dir, origin = outcome

            assert first_iter == 0
            assert source_dir == pretrained_dir
            assert origin == "pretrained"
def test_non_anybimanual_with_no_resume_starts_fresh():
    """Without AnyBimanual and with resume disabled, existing steps are ignored.

    Step directories 0 and 200 exist, but ``load_existing_weights=False``.
    The resolver is expected to report iteration 0 with no load directory
    and no load source.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        for step in (0, 200):
            _make_weight_dir(weightsdir, step)

        first_iter, source_dir, origin = resolve_initial_weight_state(
            weightsdir=weightsdir,
            load_existing_weights=False,
            anybimanual=False,
        )

        assert first_iter == 0
        assert source_dir is None
        assert origin is None
def test_anybimanual_raises_when_pretrained_dir_missing():
    """AnyBimanual with a nonexistent ``pretrained_dir`` must raise FileNotFoundError.

    ``weightsdir`` is empty and ``load_existing_weights=False``, and
    ``pretrained_dir`` points at a path that is never created.
    """
    with tempfile.TemporaryDirectory() as weightsdir:
        missing_pretrained = os.path.join(weightsdir, "missing_pretrained")
        # Guard the precondition so the test cannot pass for the wrong reason.
        assert not os.path.exists(missing_pretrained)

        try:
            resolve_initial_weight_state(
                weightsdir=weightsdir,
                load_existing_weights=False,
                anybimanual=True,
                pretrained_dir=missing_pretrained,
            )
        except FileNotFoundError:
            pass
        else:
            # Reached only if no exception was raised: fail with a clear message
            # instead of a bare ``assert raised``.
            raise AssertionError(
                "expected FileNotFoundError for missing pretrained_dir "
                f"{missing_pretrained!r}"
            )