|
|
| import logging
|
| import os
|
| import pickle
|
| from urllib.parse import parse_qs, urlparse
|
| import torch
|
| from fvcore.common.checkpoint import Checkpointer
|
| from torch.nn.parallel import DistributedDataParallel
|
|
|
| import detectron2.utils.comm as comm
|
| from detectron2.utils.file_io import PathManager
|
|
|
| from .c2_model_loading import align_and_update_state_dicts
|
|
|
|
|
class DetectionCheckpointer(Checkpointer):
    """
    Same as :class:`Checkpointer`, but is able to:

    1. handle models in detectron & detectron2 model zoo, and apply
       conversions for legacy models.
    2. correctly load checkpoints that are only available on the master worker
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        """
        Args:
            model: the model to checkpoint; forwarded to :class:`Checkpointer`.
            save_dir (str): forwarded to :class:`Checkpointer`.
            save_to_disk (bool or None): forwarded to :class:`Checkpointer`.
                When None, defaults to ``comm.is_main_process()``.
            checkpointables: extra keyword arguments forwarded to
                :class:`Checkpointer`.
        """
        is_main_process = comm.is_main_process()
        super().__init__(
            model,
            save_dir,
            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
            **checkpointables,
        )
        self.path_manager = PathManager
        # Set to the parsed URL only for the duration of `load`, so that
        # `_load_file` can read the query string of the requested path.
        self._parsed_url_during_load = None

    def load(self, path, *args, **kwargs):
        """
        Load a checkpoint from ``path``.

        Args:
            path (str): path or URL of the checkpoint. A URL query string
                (e.g. ``?matching_heuristics=True``) is stripped before the
                file is opened and is interpreted later by `_load_file`.
            *args, **kwargs: forwarded to :meth:`Checkpointer.load`.

        Returns:
            Whatever :meth:`Checkpointer.load` returns.

        Raises:
            OSError: when the model is wrapped in DistributedDataParallel and
                the first gathered worker cannot find the file.
        """
        assert self._parsed_url_during_load is None
        need_sync = False
        logger = logging.getLogger(__name__)
        logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))

        if path and isinstance(self.model, DistributedDataParallel):
            # NOTE(review): the path is resolved here before the query string
            # is stripped (that only happens further below), so a path with a
            # query may not be found in this branch -- confirm.
            path = self.path_manager.get_local_path(path)
            has_file = os.path.isfile(path)
            all_has_file = comm.all_gather(has_file)
            if not all_has_file[0]:
                raise OSError(f"File {path} not found on main worker.")
            if not all(all_has_file):
                logger.warning(
                    f"Not all workers can read checkpoint {path}. "
                    "Training may fail to fully resume."
                )
                # Workers that cannot read the file skip loading below, so the
                # model states are synchronized after loading instead.
                need_sync = True
            if not has_file:
                path = None  # don't load if not readable on this worker

        try:
            if path:
                parsed_url = urlparse(path)
                self._parsed_url_during_load = parsed_url
                # Remove the query from the filename before resolving it.
                path = parsed_url._replace(query="").geturl()
                path = self.path_manager.get_local_path(path)
            ret = super().load(path, *args, **kwargs)

            if need_sync:
                logger.info("Broadcasting model states from main worker ...")
                self.model._sync_params_and_buffers()
        finally:
            # Always reset, even when loading fails; otherwise the assertion
            # at the top would reject every later call to `load`.
            self._parsed_url_during_load = None
        return ret

    def _load_file(self, filename):
        """
        Read a checkpoint file and return it as a dict with a "model" entry.

        ``.pkl`` and ``.pyth`` files get dedicated handling; everything else
        goes through `_torch_load` and must be called from within `load`.

        Args:
            filename (str): local path of the checkpoint file.

        Returns:
            dict: contains at least the key "model". May also contain
            "__author__" and "matching_heuristics".

        Raises:
            ValueError: if the URL passed to `load` carries query parameters
                other than ``matching_heuristics``.
        """
        # SECURITY NOTE: `pickle.load` and `torch.load` can execute arbitrary
        # code while deserializing; only load checkpoints from trusted sources.
        if filename.endswith(".pkl"):
            with PathManager.open(filename, "rb") as f:
                data = pickle.load(f, encoding="latin1")
            if "model" in data and "__author__" in data:
                # Already has the "model" / "__author__" layout: return as-is.
                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
                return data
            else:
                # Otherwise treat the file as a Caffe2 checkpoint.
                if "blobs" in data:
                    # Weights may be nested under "blobs"; unwrap if present.
                    data = data["blobs"]
                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
        elif filename.endswith(".pyth"):
            # Treat ".pyth" files as pycls checkpoints.
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f)
            assert (
                "model_state" in data
            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
            model_state = {
                k: v
                for k, v in data["model_state"].items()
                if not k.endswith("num_batches_tracked")
            }
            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}

        loaded = self._torch_load(filename)
        if "model" not in loaded:
            # Wrap a bare state dict so callers can always index "model".
            loaded = {"model": loaded}
        assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
        parsed_url = self._parsed_url_during_load
        queries = parse_qs(parsed_url.query)
        # `parse_qs` maps each key to a list of values, hence the list compare.
        if queries.pop("matching_heuristics", "False") == ["True"]:
            loaded["matching_heuristics"] = True
        if queries:
            raise ValueError(
                f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
            )
        return loaded

    def _torch_load(self, f):
        """Load ``f`` with the parent class's `_load_file` implementation."""
        return super()._load_file(f)

    def _load_model(self, checkpoint):
        """
        Load ``checkpoint["model"]`` into the model.

        Args:
            checkpoint (dict): as produced by `_load_file`.

        Returns:
            The incompatible-keys object returned by the parent
            `_load_model`, with some known-benign entries removed.
        """
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # Align checkpoint weights with the model's state dict.
            checkpoint["model"] = align_and_update_state_dicts(
                self.model.state_dict(),
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )

        incompatible = super()._load_model(checkpoint)

        model_buffers = dict(self.model.named_buffers(recurse=False))
        for k in ["pixel_mean", "pixel_std"]:
            # Do not report these as missing when the model itself owns them
            # as top-level buffers.
            # NOTE(review): presumably they are initialized elsewhere (e.g.
            # from config) -- confirm.
            if k in model_buffers:
                try:
                    incompatible.missing_keys.remove(k)
                except ValueError:
                    # Key was not reported as missing; nothing to remove.
                    pass
        # Iterate over a copy because the list is mutated inside the loop.
        for k in incompatible.unexpected_keys[:]:
            # Do not report cell-anchor entries as unexpected.
            # NOTE(review): presumably present only in older checkpoints --
            # confirm.
            if "anchor_generator.cell_anchors" in k:
                incompatible.unexpected_keys.remove(k)
        return incompatible
|
|
|