| | |
| |
|
| | import torch |
| |
|
| |
|
class ImageResizeTransform:
    """
    Resizes a batch of images loaded from a dataset (BGR data in NCHW channel
    order, typically uint8) into the form consumed by DensePose training
    (BGR float32 data in NCHW channel order), preserving aspect ratio.
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        # Target bounds: the shorter image side is scaled toward `min_size`,
        # capped so the longer side never exceeds `max_size`.
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
                BGR data (typically in uint8)
        Returns:
            images (torch.Tensor): tensor of size [N, 3, H1, W1] where
                H1 and W1 are chosen to respect the specified min and max sizes
                and preserve the original aspect ratio, the data channels
                follow BGR order and the data type is `torch.float32`
        """
        images = images.float()
        h, w = images.shape[-2], images.shape[-1]
        shorter, longer = (h, w) if h <= w else (w, h)
        # Scale so the short side reaches `min_size`, unless that would push
        # the long side past `max_size` — then the cap wins.
        scale = min(self.min_size / shorter, self.max_size / longer)
        return torch.nn.functional.interpolate(
            images,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
| |
|