# Source: VLAarchtests2 / VLAarchtests / tests / test_rlbench_init_checkpoint.py
# Author: lsnu — commit 9c74dfe ("Add files using upload-large-folder tool", verified)
from __future__ import annotations
from pathlib import Path
import torch
from train.run_rlbench_experiment import _load_init_checkpoint
class _TinyModule(torch.nn.Module):
    """Minimal module used as a checkpoint-loading target in tests.

    Exposes two zero-initialised parameters, registered in this order:
    ``weight`` with shape (2, 2) and ``bias`` with shape (2,). Because both
    start at zero, a test can tell which entries a checkpoint load overwrote.
    """

    def __init__(self) -> None:
        super().__init__()
        initial_weight = torch.zeros((2, 2))
        initial_bias = torch.zeros((2,))
        # Registration order matters for state_dict key order: weight, then bias.
        self.weight = torch.nn.Parameter(initial_weight)
        self.bias = torch.nn.Parameter(initial_bias)
def test_load_init_checkpoint_skips_shape_mismatch_when_not_strict(tmp_path: Path) -> None:
    """Non-strict init loading copies matching tensors and skips mis-shaped ones.

    The saved checkpoint holds a ``bias`` whose shape matches ``_TinyModule``
    and a ``weight`` whose shape (3x3) does not. With ``strict=False`` the
    loader is expected to report exactly one loaded key, list ``weight`` as
    skipped for a shape mismatch, copy the checkpoint ``bias`` into the model,
    and leave ``weight`` at its zero initialisation.
    """
    target = _TinyModule()
    ckpt_file = tmp_path / "checkpoint.pt"

    saved_tensors = {
        # Deliberately the wrong shape: the model's weight is 2x2.
        "weight": torch.ones(3, 3),
        # Matching shape, distinctive value so a successful load is visible.
        "bias": torch.full((2,), 5.0),
    }
    torch.save({"state_dict": saved_tensors}, ckpt_file)

    result = _load_init_checkpoint(target, str(ckpt_file), strict=False)

    assert result is not None
    assert result["loaded_keys"] == 1
    assert result["skipped_shape_mismatch_keys"] == ["weight"]
    # bias came from the checkpoint ...
    assert torch.allclose(target.bias, torch.full((2,), 5.0))
    # ... while the mismatched weight kept its zero initialisation.
    assert torch.allclose(target.weight, torch.zeros(2, 2))