| | |
| | from collections import OrderedDict |
| |
|
| | from detectron2.checkpoint import DetectionCheckpointer |
| |
|
| |
|
| | def _rename_HRNet_weights(weights): |
| | |
| | |
| | if ( |
| | len(weights["model"].keys()) == 1956 |
| | and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716 |
| | ): |
| | hrnet_weights = OrderedDict() |
| | for k in weights["model"].keys(): |
| | hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k] |
| | return {"model": hrnet_weights} |
| | else: |
| | return weights |
| |
|
| |
|
| | class DensePoseCheckpointer(DetectionCheckpointer): |
| | """ |
| | Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights |
| | """ |
| |
|
| | def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): |
| | super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables) |
| |
|
| | def _load_file(self, filename: str) -> object: |
| | """ |
| | Adding hrnet support |
| | """ |
| | weights = super()._load_file(filename) |
| | return _rename_HRNet_weights(weights) |
| |
|