Spaces:
Build error
Build error
| # Modified from: | |
| # https://github.com/anibali/pytorch-stacked-hourglass | |
| # https://github.com/bearpaw/pytorch-pose | |
| import torch | |
| from torch.nn.functional import interpolate | |
| def _resize(tensor, size, mode='bilinear'): | |
| """Resize the image. | |
| Args: | |
| tensor (torch.Tensor): The image tensor to be resized. | |
| size (tuple of int): Size of the resized image (height, width). | |
| mode (str): The pixel sampling interpolation mode to be used. | |
| Returns: | |
| Tensor: The resized image tensor. | |
| """ | |
| assert len(size) == 2 | |
| # If the tensor is already the desired size, return it immediately. | |
| if tensor.shape[-2] == size[0] and tensor.shape[-1] == size[1]: | |
| return tensor | |
| if not tensor.is_floating_point(): | |
| dtype = tensor.dtype | |
| tensor = tensor.to(torch.float32) | |
| tensor = _resize(tensor, size, mode) | |
| return tensor.to(dtype) | |
| out_shape = (*tensor.shape[:-2], *size) | |
| if tensor.ndimension() < 3: | |
| raise Exception('tensor must be at least 2D') | |
| elif tensor.ndimension() == 3: | |
| tensor = tensor.unsqueeze(0) | |
| elif tensor.ndimension() > 4: | |
| tensor = tensor.view(-1, *tensor.shape[-3:]) | |
| align_corners = None | |
| if mode in {'linear', 'bilinear', 'trilinear'}: | |
| align_corners = False | |
| resized = interpolate(tensor, size=size, mode=mode, align_corners=align_corners) | |
| return resized.view(*out_shape) | |
| def _crop(tensor, t, l, h, w, padding_mode='constant', fill=0): | |
| """Crop the image, padding out-of-bounds regions. | |
| Args: | |
| tensor (torch.Tensor): The image tensor to be cropped. | |
| t (int): Top pixel coordinate. | |
| l (int): Left pixel coordinate. | |
| h (int): Height of the cropped image. | |
| w (int): Width of the cropped image. | |
| padding_mode (str): Padding mode (currently "constant" is the only valid option). | |
| fill (float): Fill value to use with constant padding. | |
| Returns: | |
| Tensor: The cropped image tensor. | |
| """ | |
| # If the _crop region is wholly within the image, simply narrow the tensor. | |
| if t >= 0 and l >= 0 and t + h <= tensor.size(-2) and l + w <= tensor.size(-1): | |
| return tensor[..., t:t+h, l:l+w] | |
| if padding_mode == 'constant': | |
| result = torch.full((*tensor.size()[:-2], h, w), fill, | |
| device=tensor.device, dtype=tensor.dtype) | |
| else: | |
| raise Exception('_crop only supports "constant" padding currently.') | |
| sx1 = l | |
| sy1 = t | |
| sx2 = l + w | |
| sy2 = t + h | |
| dx1 = 0 | |
| dy1 = 0 | |
| if sx1 < 0: | |
| dx1 = -sx1 | |
| w += sx1 | |
| sx1 = 0 | |
| if sy1 < 0: | |
| dy1 = -sy1 | |
| h += sy1 | |
| sy1 = 0 | |
| if sx2 >= tensor.size(-1): | |
| w -= sx2 - tensor.size(-1) | |
| if sy2 >= tensor.size(-2): | |
| h -= sy2 - tensor.size(-2) | |
| # Copy the in-bounds sub-area of the _crop region into the result tensor. | |
| if h > 0 and w > 0: | |
| src = tensor.narrow(-2, sy1, h).narrow(-1, sx1, w) | |
| dst = result.narrow(-2, dy1, h).narrow(-1, dx1, w) | |
| dst.copy_(src) | |
| return result | |
| def calculate_fit_contain_output_area(in_height, in_width, out_height, out_width): | |
| ih, iw = in_height, in_width | |
| k = min(out_width / iw, out_height / ih) | |
| oh = round(k * ih) | |
| ow = round(k * iw) | |
| y_off = (out_height - oh) // 2 | |
| x_off = (out_width - ow) // 2 | |
| return y_off, x_off, oh, ow | |
| def fit(tensor, size, fit_mode='cover', resize_mode='bilinear', *, fill=0): | |
| """Fit the image within the given spatial dimensions. | |
| Args: | |
| tensor (torch.Tensor): The image tensor to be fit. | |
| size (tuple of int): Size of the output (height, width). | |
| fit_mode (str): 'fill', 'contain', or 'cover'. These behave in the same way as CSS's | |
| `object-fit` property. | |
| fill (float): padding value (only applicable in 'contain' mode). | |
| Returns: | |
| Tensor: The resized image tensor. | |
| """ | |
| if fit_mode == 'fill': | |
| return _resize(tensor, size, mode=resize_mode) | |
| elif fit_mode == 'contain': | |
| y_off, x_off, oh, ow = calculate_fit_contain_output_area(*tensor.shape[-2:], *size) | |
| resized = _resize(tensor, (oh, ow), mode=resize_mode) | |
| result = tensor.new_full((*tensor.size()[:-2], *size), fill) | |
| result[..., y_off:y_off + oh, x_off:x_off + ow] = resized | |
| return result | |
| elif fit_mode == 'cover': | |
| ih, iw = tensor.shape[-2:] | |
| k = max(size[-1] / iw, size[-2] / ih) | |
| oh = round(k * ih) | |
| ow = round(k * iw) | |
| resized = _resize(tensor, (oh, ow), mode=resize_mode) | |
| y_trim = (oh - size[-2]) // 2 | |
| x_trim = (ow - size[-1]) // 2 | |
| result = _crop(resized, y_trim, x_trim, size[-2], size[-1]) | |
| return result | |
| raise ValueError('Invalid fit_mode: ' + repr(fit_mode)) | |