| | |
| |
|
| | from dataclasses import make_dataclass |
| | from functools import lru_cache |
| | from typing import Any, Optional |
| | import torch |
| |
|
| |
|
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
     - sigma_1 (tensor)
     - sigma_2 (tensor)
     - kappa_u (tensor)
     - kappa_v (tensor)
     - fine_segm_confidence (tensor)
     - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class; its
    name is composed of the name of the provided class and the
    "WithConfidences" suffix. Each confidence attribute is an optional
    tensor defaulting to None.

    The new class also overrides:
     - `__getitem__(item)`: slices the base fields via the base class
       `__getitem__` and slices every non-None confidence tensor with the
       same `item` (an int index keeps a leading dimension of size 1);
     - `to(device)`: moves the base fields via the base class `to` and moves
       every confidence attribute that is a `torch.Tensor` to `device`.
    Both return a new instance of the same class as `self`.

    Results are cached per base class, so repeated calls with the same
    `BasePredictorOutput` return the same class object.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass that provides
            `__getitem__` and `to` methods
    Returns:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """
    # Single source of truth for the confidence attributes; the order here
    # is the order of the generated dataclass fields.
    confidence_field_names = (
        "sigma_1",
        "sigma_2",
        "kappa_u",
        "kappa_v",
        "fine_segm_confidence",
        "coarse_segm_confidence",
    )

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[(name, Optional[torch.Tensor], None) for name in confidence_field_names],
        bases=(BasePredictorOutput,),
    )

    def slice_if_not_none(data, item):
        """Index `data` with `item`; None passes through unchanged."""
        if data is None:
            return None
        if isinstance(item, int):
            # An int index would drop the first dimension; re-add it so the
            # result keeps the same number of dimensions as the input.
            return data[item].unsqueeze(0)
        return data[item]

    def PredictorOutput_getitem(self, item):
        """Slice base fields and confidence tensors with the same `item`."""
        # Anchor super() on the generated class captured by this closure, not
        # on type(self): super(type(self), self) recurses forever when the
        # generated class is itself subclassed.
        base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
        return type(self)(
            **base_predictor_output_sliced.__dict__,
            **{
                name: slice_if_not_none(getattr(self, name), item)
                for name in confidence_field_names
            },
        )

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        # Same closure-based super() anchor as in PredictorOutput_getitem.
        base_predictor_output_to = super(PredictorOutput, self).to(device)

        def to_device_if_tensor(var: Any):
            # Non-tensor values (e.g. None) are passed through unchanged.
            if isinstance(var, torch.Tensor):
                return var.to(device)
            return var

        return type(self)(
            **base_predictor_output_to.__dict__,
            **{
                name: to_device_if_tensor(getattr(self, name))
                for name in confidence_field_names
            },
        )

    PredictorOutput.__getitem__ = PredictorOutput_getitem
    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput
| |
|