Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from torch.autograd import Function | |
| from torch.nn import functional as F | |
class SigmoidGeometricMean(Function):
    """Forward and backward function of geometric mean of two sigmoid
    functions.

    This implementation with an analytical gradient function substitutes
    the autograd function of ``(x.sigmoid() * y.sigmoid()).sqrt()``. The
    original implementation produces NaN during gradient backpropagation
    if both x and y are very small values.
    """

    @staticmethod
    def forward(ctx, x, y):
        """Compute ``sqrt(sigmoid(x) * sigmoid(y))`` element-wise.

        Args:
            ctx: Autograd context; the sigmoids and the result are saved on
                it for use in ``backward``.
            x (Tensor): First input.
            y (Tensor): Second input.

        Returns:
            Tensor: The geometric mean of ``x.sigmoid()`` and
            ``y.sigmoid()``.
        """
        x_sigmoid = x.sigmoid()
        y_sigmoid = y.sigmoid()
        z = (x_sigmoid * y_sigmoid).sqrt()
        ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Return the analytic gradients w.r.t. ``x`` and ``y``.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_output (Tensor): Gradient of the loss w.r.t. the output.

        Returns:
            tuple[Tensor, Tensor]: Gradients w.r.t. ``x`` and ``y``.
        """
        x_sigmoid, y_sigmoid, z = ctx.saved_tensors
        # d/dx sqrt(sx * sy) = sy * sx * (1 - sx) / (2 * sqrt(sx * sy))
        #                    = z * (1 - sx) / 2.
        # Written this way, no division by z (or sqrt'(0) = inf) happens, so
        # the gradient stays finite even when both sigmoids underflow to 0.
        grad_x = grad_output * z * (1 - x_sigmoid) / 2
        grad_y = grad_output * z * (1 - y_sigmoid) / 2
        return grad_x, grad_y


sigmoid_geometric_mean = SigmoidGeometricMean.apply
def interpolate_as(source, target, mode='bilinear', align_corners=False):
    """Interpolate the `source` to the shape of the `target`.

    The `source` must be a Tensor, but the `target` can be a Tensor or a
    np.ndarray with the shape (..., target_h, target_w).

    Args:
        source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
            (N, C, H, W).
        target (Tensor | np.ndarray): The interpolation target with the shape
            (..., target_h, target_w).
        mode (str): Algorithm used for interpolation. The options are the
            same as those in F.interpolate(). Default: ``'bilinear'``.
        align_corners (bool): The same as the argument in F.interpolate().
            For ``'nearest'``, ``'nearest-exact'`` and ``'area'`` a falsy
            value is not forwarded, because F.interpolate() rejects any
            non-None value for those modes; passing True still raises.
            Default: False.

    Returns:
        Tensor: The interpolated source Tensor. A 3D `source` yields a 3D
        result and a 4D `source` yields a 4D result.
    """
    assert len(target.shape) >= 2

    # F.interpolate raises ValueError whenever align_corners is not None for
    # these modes, which would make them unusable with the default False.
    if mode in ('nearest', 'nearest-exact', 'area') and not align_corners:
        align_corners = None

    def _interpolate_as(source, target, mode='bilinear', align_corners=False):
        """Interpolate the `source` (4D) to the shape of the `target`."""
        target_h, target_w = target.shape[-2:]
        source_h, source_w = source.shape[-2:]
        # Skip the resize entirely when the spatial sizes already match.
        if target_h != source_h or target_w != source_w:
            source = F.interpolate(
                source,
                size=(target_h, target_w),
                mode=mode,
                align_corners=align_corners)
        return source

    if len(source.shape) == 3:
        # F.interpolate needs a channel dim: (N, H, W) -> (N, 1, H, W).
        source = source[:, None, :, :]
        source = _interpolate_as(source, target, mode, align_corners)
        return source[:, 0, :, :]
    else:
        return _interpolate_as(source, target, mode, align_corners)