| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Union |
| | import torch |
| |
|
| |
|
@dataclass
class DensePoseChartPredictorOutput:
    """
    Chart-based DensePose predictor output for a batch of instances.

    Holds four tensors that share the instance dimension N and the spatial
    dimensions Hout x Wout:
     * coarse_segm: coarse segmentation, shape [N, K, Hout, Wout]
     * fine_segm: fine segmentation, shape [N, C, Hout, Wout]
     * u: U coordinates, shape [N, C, Hout, Wout]
     * v: V coordinates, shape [N, C, Hout, Wout]
    where
     - N is the number of instances
     - K is the number of coarse segmentation channels (
         2 = foreground / background,
         15 = one of 14 body parts / background)
     - C is the number of fine segmentation channels (
         24 fine body parts / background)
     - Hout and Wout are the height and width of the predictions
    """

    coarse_segm: torch.Tensor
    fine_segm: torch.Tensor
    u: torch.Tensor
    v: torch.Tensor

    def __len__(self):
        """
        Return the number of instances (N), i.e. the size of the first
        dimension of the coarse segmentation tensor.
        """
        num_instances = self.coarse_segm.size(0)
        return num_instances

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseChartPredictorOutput":
        """
        Build a new output restricted to the selected instance(s).

        Args:
            item (int or slice or tensor): index applied along the first
                dimension of every stored tensor

        Returns:
            DensePoseChartPredictorOutput: a new object holding the indexed
            tensors. For an `int` index a leading dimension of size 1 is
            re-inserted, so the result still carries an instance dimension.
        """
        # Plain indexing with an int drops the first dimension; every other
        # supported index type is passed through to the tensors unchanged.
        keeps_leading_dim = not isinstance(item, int)

        def select(tensor: torch.Tensor) -> torch.Tensor:
            picked = tensor[item]
            return picked if keeps_leading_dim else picked.unsqueeze(0)

        return DensePoseChartPredictorOutput(
            coarse_segm=select(self.coarse_segm),
            fine_segm=select(self.fine_segm),
            u=select(self.u),
            v=select(self.v),
        )

    def to(self, device: torch.device):
        """
        Return a new output whose four tensors are the results of calling
        `.to(device)` on the tensors stored in this object.

        Args:
            device (torch.device): target device
        """
        return DensePoseChartPredictorOutput(
            coarse_segm=self.coarse_segm.to(device),
            fine_segm=self.fine_segm.to(device),
            u=self.u.to(device),
            v=self.v.to(device),
        )
| |
|