| |
|
|
| from dataclasses import dataclass |
| from typing import Union |
| import torch |
|
|
|
|
@dataclass
class DensePoseEmbeddingPredictorOutput:
    """
    Output of a DensePose predictor that carries embedding and coarse
    segmentation data:
     * embedding: float tensor of size [N, D, H, W], contains estimated embeddings
     * coarse_segm: float tensor of size [N, K, H, W]
    where D = MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
          K = MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
    """

    embedding: torch.Tensor
    coarse_segm: torch.Tensor

    def __len__(self):
        """
        Return the number of instances (N) held by this output, read from the
        leading dimension of ``coarse_segm``.
        """
        num_instances = self.coarse_segm.shape[0]
        return num_instances

    def __getitem__(
        self, item: Union[int, slice, torch.BoolTensor]
    ) -> "DensePoseEmbeddingPredictorOutput":
        """
        Select one or more instances from this output.

        Args:
            item (int or slice or tensor): index, slice or mask applied to the
                instance (first) dimension of both tensors
        Returns:
            A new DensePoseEmbeddingPredictorOutput holding the selection; an
            integer index yields a batch of size 1.
        """
        segm_sel = self.coarse_segm[item]
        embed_sel = self.embedding[item]
        if isinstance(item, int):
            # Integer indexing drops the instance dimension; add it back so the
            # result still has a leading batch dimension (of size 1).
            segm_sel = segm_sel[None]
            embed_sel = embed_sel[None]
        return DensePoseEmbeddingPredictorOutput(
            embedding=embed_sel, coarse_segm=segm_sel
        )

    def to(self, device: torch.device):
        """
        Return a new output whose tensors have been moved to ``device``.
        """
        moved_segm, moved_embedding = (
            self.coarse_segm.to(device),
            self.embedding.to(device),
        )
        return DensePoseEmbeddingPredictorOutput(
            embedding=moved_embedding, coarse_segm=moved_segm
        )
|
|