| | |
| |
|
import numbers
from dataclasses import dataclass
from typing import Union

import torch
| |
|
| |
|
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Predictor output that contains embedding and coarse segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    Here D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
         K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
    """

    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self) -> int:
        """
        Number of instances (N) in the output, read from the first
        dimension of `coarse_segm`
        """
        return self.coarse_segm.size(0)

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Get outputs for the selected instance(s)

        Args:
            item (int or slice or tensor): selected items

        Returns:
            A new DensePoseEmbeddingPredictorOutput holding the selected
            entries of both tensors. When `item` is a single integer index,
            the leading instance dimension is restored so that the result
            still holds exactly one instance.
        """
        coarse_segm = self.coarse_segm[item]
        embedding = self.embedding[item]
        # Indexing with a scalar integer drops the leading dimension; put it
        # back. `numbers.Integral` covers the built-in `int` as well as other
        # registered integer scalar types (e.g. NumPy integers), which would
        # otherwise silently produce tensors without the instance dimension.
        if isinstance(item, numbers.Integral):
            coarse_segm = coarse_segm.unsqueeze(0)
            embedding = embedding.unsqueeze(0)
        return DensePoseEmbeddingPredictorOutput(
            coarse_segm=coarse_segm, embedding=embedding
        )

    def to(self, device: torch.device) -> "DensePoseEmbeddingPredictorOutput":
        """
        Transfers all tensors to the given device

        Args:
            device (torch.device): target device, passed to `Tensor.to`

        Returns:
            A new DensePoseEmbeddingPredictorOutput built from the results
            of calling `.to(device)` on both tensors.
        """
        coarse_segm = self.coarse_segm.to(device)
        embedding = self.embedding.to(device)
        return DensePoseEmbeddingPredictorOutput(coarse_segm=coarse_segm, embedding=embedding)
| |
|