from __future__ import annotations

from pathlib import Path

import torch

from train.run_rlbench_experiment import _load_init_checkpoint


class _TinyModule(torch.nn.Module):
    """Minimal module with two zero-initialised parameters for checkpoint tests.

    Exposes a 2x2 ``weight`` and a length-2 ``bias``, both starting at zero,
    so a test can distinguish values loaded from a checkpoint from the
    untouched defaults.
    """

    def __init__(self) -> None:
        super().__init__()
        initial_weight = torch.zeros(2, 2)
        initial_bias = torch.zeros(2)
        # Registration order (weight, then bias) is kept so state_dict key
        # order matches the original definition.
        self.weight = torch.nn.Parameter(initial_weight)
        self.bias = torch.nn.Parameter(initial_bias)


def test_load_init_checkpoint_skips_shape_mismatch_when_not_strict(tmp_path: Path) -> None:
    """Non-strict init loading skips shape-mismatched tensors and loads the rest.

    The saved checkpoint holds a 3x3 ``weight`` (incompatible with the model's
    2x2 parameter) and a correctly shaped ``bias``. With ``strict=False`` the
    loader is expected to report ``weight`` as skipped, load exactly one key
    (``bias``), and leave the model's original ``weight`` untouched.
    """
    model = _TinyModule()
    checkpoint_path = tmp_path / "checkpoint.pt"
    torch.save(
        {
            "state_dict": {
                # Wrong shape on purpose: the model's weight is (2, 2).
                "weight": torch.ones(3, 3),
                # Matching shape, value distinct from the zero initialisation.
                "bias": torch.full((2,), 5.0),
            }
        },
        checkpoint_path,
    )

    info = _load_init_checkpoint(model, str(checkpoint_path), strict=False)

    assert info is not None
    assert info["loaded_keys"] == 1
    assert info["skipped_shape_mismatch_keys"] == ["weight"]
    # torch.equal, unlike torch.allclose, does not broadcast: it also pins the
    # parameter shapes. The expected values are exact, so no tolerance is needed.
    assert torch.equal(model.bias, torch.full((2,), 5.0))
    assert torch.equal(model.weight, torch.zeros(2, 2))