from __future__ import annotations from pathlib import Path import torch from train.run_rlbench_experiment import _load_init_checkpoint class _TinyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.nn.Parameter(torch.zeros(2, 2)) self.bias = torch.nn.Parameter(torch.zeros(2)) def test_load_init_checkpoint_skips_shape_mismatch_when_not_strict(tmp_path: Path) -> None: model = _TinyModule() checkpoint_path = tmp_path / "checkpoint.pt" torch.save( { "state_dict": { "weight": torch.ones(3, 3), "bias": torch.full((2,), 5.0), } }, checkpoint_path, ) info = _load_init_checkpoint(model, str(checkpoint_path), strict=False) assert info is not None assert info["loaded_keys"] == 1 assert info["skipped_shape_mismatch_keys"] == ["weight"] assert torch.allclose(model.bias, torch.full((2,), 5.0)) assert torch.allclose(model.weight, torch.zeros(2, 2))