| | |
| | import os |
| | import tempfile |
| | import unittest |
| | from collections import OrderedDict |
| | import torch |
| | from iopath.common.file_io import PathHandler, PathManager |
| | from torch import nn |
| |
|
| | from detectron2.checkpoint import DetectionCheckpointer |
| | from detectron2.checkpoint.c2_model_loading import ( |
| | _longest_common_prefix_str, |
| | align_and_update_state_dicts, |
| | ) |
| | from detectron2.utils.logger import setup_logger |
| |
|
| |
|
| | class TestCheckpointer(unittest.TestCase): |
| | def setUp(self): |
| | setup_logger() |
| |
|
| | def create_complex_model(self): |
| | m = nn.Module() |
| | m.block1 = nn.Module() |
| | m.block1.layer1 = nn.Linear(2, 3) |
| | m.layer2 = nn.Linear(3, 2) |
| | m.res = nn.Module() |
| | m.res.layer2 = nn.Linear(3, 2) |
| |
|
| | state_dict = OrderedDict() |
| | state_dict["layer1.weight"] = torch.rand(3, 2) |
| | state_dict["layer1.bias"] = torch.rand(3) |
| | state_dict["layer2.weight"] = torch.rand(2, 3) |
| | state_dict["layer2.bias"] = torch.rand(2) |
| | state_dict["res.layer2.weight"] = torch.rand(2, 3) |
| | state_dict["res.layer2.bias"] = torch.rand(2) |
| | return m, state_dict |
| |
|
| | def test_complex_model_loaded(self): |
| | for add_data_parallel in [False, True]: |
| | model, state_dict = self.create_complex_model() |
| | if add_data_parallel: |
| | model = nn.DataParallel(model) |
| | model_sd = model.state_dict() |
| |
|
| | sd_to_load = align_and_update_state_dicts(model_sd, state_dict) |
| | model.load_state_dict(sd_to_load) |
| | for loaded, stored in zip(model_sd.values(), state_dict.values()): |
| | |
| | self.assertFalse(id(loaded) == id(stored)) |
| | |
| | self.assertTrue(loaded.to(stored).equal(stored)) |
| |
|
| | def test_load_with_matching_heuristics(self): |
| | with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: |
| | model, state_dict = self.create_complex_model() |
| | torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth")) |
| | checkpointer = DetectionCheckpointer(model, save_dir=d) |
| |
|
| | with torch.no_grad(): |
| | |
| | model.block1.layer1.weight.fill_(1) |
| |
|
| | |
| | checkpointer.load(os.path.join(d, "checkpoint.pth")) |
| | self.assertTrue(model.block1.layer1.weight.equal(torch.ones(3, 2))) |
| |
|
| | |
| | checkpointer.load(os.path.join(d, "checkpoint.pth?matching_heuristics=True")) |
| | self.assertFalse(model.block1.layer1.weight.equal(torch.ones(3, 2))) |
| |
|
| | def test_custom_path_manager_handler(self): |
| | with tempfile.TemporaryDirectory(prefix="detectron2_test") as d: |
| |
|
| | class CustomPathManagerHandler(PathHandler): |
| | PREFIX = "detectron2_test://" |
| |
|
| | def _get_supported_prefixes(self): |
| | return [self.PREFIX] |
| |
|
| | def _get_local_path(self, path, **kwargs): |
| | name = path[len(self.PREFIX) :] |
| | return os.path.join(d, name) |
| |
|
| | def _open(self, path, mode="r", **kwargs): |
| | return open(self._get_local_path(path), mode, **kwargs) |
| |
|
| | pathmgr = PathManager() |
| | pathmgr.register_handler(CustomPathManagerHandler()) |
| |
|
| | model, state_dict = self.create_complex_model() |
| | torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth")) |
| | checkpointer = DetectionCheckpointer(model, save_dir=d) |
| | checkpointer.path_manager = pathmgr |
| | checkpointer.load("detectron2_test://checkpoint.pth") |
| | checkpointer.load("detectron2_test://checkpoint.pth?matching_heuristics=True") |
| |
|
| | def test_lcp(self): |
| | self.assertEqual(_longest_common_prefix_str(["class", "dlaps_model"]), "") |
| | self.assertEqual(_longest_common_prefix_str(["classA", "classB"]), "class") |
| | self.assertEqual(_longest_common_prefix_str(["classA", "classB", "clab"]), "cla") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | unittest.main() |
| |
|