|
|
|
|
| from dataclasses import make_dataclass
|
| from functools import lru_cache
|
| from typing import Any, Optional
|
| import torch
|
|
|
|
|
# Confidence-related attributes added by
# ``decorate_predictor_output_class_with_confidences``, in constructor order.
_CONFIDENCE_FIELD_NAMES = (
    "sigma_1",
    "sigma_2",
    "kappa_u",
    "kappa_v",
    "fine_segm_confidence",
    "coarse_segm_confidence",
)


@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
     - sigma_1 (tensor)
     - sigma_2 (tensor)
     - kappa_u (tensor)
     - kappa_v (tensor)
     - fine_segm_confidence (tensor)
     - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", pp. 918-926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class;
    its name is composed of the name of the provided class and the
    "WithConfidences" suffix. All new attributes default to None.

    The generated class overrides `__getitem__` and `to`. Both delegate to the
    base class implementations (which `BasePredictorOutput` must provide) and
    then apply the same slicing / device transfer to the confidence tensors.
    Results are cached per base class, so repeated calls return the same class.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass
    Return:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """
    from dataclasses import fields as dataclass_fields, is_dataclass

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[(name, Optional[torch.Tensor], None) for name in _CONFIDENCE_FIELD_NAMES],
        bases=(BasePredictorOutput,),
    )

    def _init_kwargs(obj: Any) -> dict:
        # Constructor arguments recovered from a base-class instance.
        # Dataclass `fields` are used instead of `obj.__dict__` so that
        # `init=False` fields are skipped and slotted dataclasses work.
        if is_dataclass(obj):
            return {f.name: getattr(obj, f.name) for f in dataclass_fields(obj) if f.init}
        return dict(vars(obj))

    def slice_if_not_none(data: Optional[torch.Tensor], item: Any) -> Optional[torch.Tensor]:
        # Index `data` with `item`, passing None through unchanged.
        if data is None:
            return None
        if isinstance(item, int):
            # integer indexing drops dim 0; re-add it so the rank is preserved
            return data[item].unsqueeze(0)
        return data[item]

    def PredictorOutput_getitem(self, item):
        """
        Slice the base output and every confidence tensor with `item`.
        """
        # Bind `super` to the class created above, not `type(self)`: with
        # `type(self)` any subclass would resolve back to this very method
        # and recurse forever.
        base_sliced = super(PredictorOutput, self).__getitem__(item)
        return PredictorOutput(
            **_init_kwargs(base_sliced),
            **{
                name: slice_if_not_none(getattr(self, name), item)
                for name in _CONFIDENCE_FIELD_NAMES
            },
        )

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        # See PredictorOutput_getitem for why `super` uses the closure class.
        base_moved = super(PredictorOutput, self).to(device)

        def to_device_if_tensor(var: Any):
            if isinstance(var, torch.Tensor):
                return var.to(device)
            return var

        return PredictorOutput(
            **_init_kwargs(base_moved),
            **{name: to_device_if_tensor(getattr(self, name)) for name in _CONFIDENCE_FIELD_NAMES},
        )

    PredictorOutput.__getitem__ = PredictorOutput_getitem
    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput
|
|
|