|
|
| import torch
|
|
|
| from densepose.structures.data_relative import DensePoseDataRelative
|
|
|
|
|
class DensePoseList:
    """
    Per-image collection of DensePose annotations aligned with their boxes.

    Entry ``i`` of ``densepose_datas`` is either a ``DensePoseDataRelative`` or
    ``None`` and corresponds to row ``i`` of ``boxes_xyxy_abs``; both are moved
    to ``device`` on construction. Indexing follows tensor-indexing semantics:
    a scalar index returns a single entry, while slices, boolean/uint8 masks and
    index sequences return a new ``DensePoseList`` restricted to the selected
    instances (same image size and device).
    """

    _TORCH_DEVICE_CPU = torch.device("cpu")
    # Dtypes whose 0-d tensors are accepted as a scalar instance index.
    _INTEGER_INDEX_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

    def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU):
        """
        Args:
            densepose_datas: sequence of ``DensePoseDataRelative`` or ``None``,
                one per box.
            boxes_xyxy_abs: tensor with one row per instance (presumably
                (N, 4) absolute XYXY coordinates, per the name -- confirm).
            image_size_hw: image size; index 0 is the height and index 1 the
                width (as used by ``__repr__``).
            device: device that the data entries and boxes are moved to.

        Raises:
            AssertionError: if the number of entries and boxes differ, or an
                entry is neither ``None`` nor a ``DensePoseDataRelative``.
        """
        assert len(densepose_datas) == len(boxes_xyxy_abs), (
            "Attempt to initialize DensePoseList with {} DensePose datas "
            "and {} boxes".format(len(densepose_datas), len(boxes_xyxy_abs))
        )
        self.densepose_datas = []
        for densepose_data in densepose_datas:
            # Cheap None check first; only real entries need the type check.
            assert densepose_data is None or isinstance(densepose_data, DensePoseDataRelative), (
                "Attempt to initialize DensePoseList with DensePose datas "
                "of type {}, expected DensePoseDataRelative".format(type(densepose_data))
            )
            self.densepose_datas.append(
                None if densepose_data is None else densepose_data.to(device)
            )
        self.boxes_xyxy_abs = boxes_xyxy_abs.to(device)
        self.image_size_hw = image_size_hw
        self.device = device

    def to(self, device):
        """Return ``self`` if already on ``device``, else a copy moved to ``device``."""
        if self.device == device:
            return self
        return DensePoseList(self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device)

    def __iter__(self):
        """Iterate over the per-instance entries (``DensePoseDataRelative`` or ``None``)."""
        return iter(self.densepose_datas)

    def __len__(self):
        """Return the number of instances."""
        return len(self.densepose_datas)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"num_instances={len(self.densepose_datas)}, "
            f"image_width={self.image_size_hw[1]}, "
            f"image_height={self.image_size_hw[0]})"
        )

    def __getitem__(self, item):
        """
        Select instances.

        Args:
            item: a Python ``int`` or 0-d integer tensor (selects one entry);
                a ``slice``; a 1-d ``bool`` or ``uint8`` tensor mask; or a
                sequence / integer tensor of indices.

        Returns:
            The selected entry (possibly ``None``) for scalar indices, otherwise
            a new ``DensePoseList`` with the matching entries and boxes.
        """
        if isinstance(item, torch.Tensor):
            if item.dim() == 0 and item.dtype in self._INTEGER_INDEX_DTYPES:
                # A 0-d integer tensor addresses a single instance like an int;
                # iterating it (as the sequence branch does) raises TypeError.
                item = int(item)
            elif item.dtype == torch.uint8:
                # Normalize legacy uint8 masks to bool so the boxes and the data
                # list are filtered with the same (mask) semantics.
                item = item.bool()

        if isinstance(item, int):
            return self.densepose_datas[item]

        # Index the boxes first so torch validates the index (e.g. mask length)
        # before the data list is built.
        boxes_xyxy_abs = self.boxes_xyxy_abs[item]
        if isinstance(item, slice):
            densepose_datas_rel = self.densepose_datas[item]
        elif isinstance(item, torch.Tensor) and item.dtype == torch.bool:
            # One tolist() transfer instead of a per-element tensor comparison.
            keep = item.tolist()
            densepose_datas_rel = [
                data for data, selected in zip(self.densepose_datas, keep) if selected
            ]
        else:
            indices = item.tolist() if isinstance(item, torch.Tensor) else item
            densepose_datas_rel = [self.densepose_datas[i] for i in indices]
        return DensePoseList(
            densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device
        )
|
|
|