File size: 1,044 Bytes
9c74dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import annotations

from pathlib import Path

import torch

from train.run_rlbench_experiment import _load_init_checkpoint


class _TinyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2, 2))
        self.bias = torch.nn.Parameter(torch.zeros(2))


def test_load_init_checkpoint_skips_shape_mismatch_when_not_strict(tmp_path: Path) -> None:
    """Non-strict loading copies compatible tensors and skips mismatched ones.

    The saved checkpoint holds a (3, 3) ``weight`` that cannot fit the module's
    (2, 2) parameter, and a (2,) ``bias`` that can. With ``strict=False`` the
    loader should:

    * copy only ``bias``,
    * report ``weight`` as skipped because of the shape mismatch, and
    * leave the module's zero-initialised ``weight`` untouched.
    """
    module = _TinyModule()
    ckpt_file = tmp_path / "checkpoint.pt"

    mismatched_weight = torch.ones(3, 3)  # deliberately the wrong shape
    loadable_bias = torch.full((2,), 5.0)
    payload = {"state_dict": {"weight": mismatched_weight, "bias": loadable_bias}}
    torch.save(payload, ckpt_file)

    result = _load_init_checkpoint(module, str(ckpt_file), strict=False)

    assert result is not None
    assert result["loaded_keys"] == 1
    assert result["skipped_shape_mismatch_keys"] == ["weight"]
    # The shape-compatible tensor was copied into the module...
    assert torch.allclose(module.bias, torch.full((2,), 5.0))
    # ...while the incompatible one left the original parameter as-is.
    assert torch.allclose(module.weight, torch.zeros(2, 2))