Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from torch import nn | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from tensormask import _C | |
class _SwapAlign2Nat(Function):
    """
    Autograd bridge to the compiled SwapAlign2Nat kernels.

    ``forward`` delegates to ``_C.swap_align2nat_forward`` and stashes on
    ``ctx`` what ``backward`` needs; ``backward`` delegates to
    ``_C.swap_align2nat_backward``.

    Fix: ``torch.autograd.Function`` requires ``forward``/``backward`` to be
    static methods — without ``@staticmethod``, ``_SwapAlign2Nat.apply``
    fails on modern PyTorch. ``backward`` is additionally wrapped in
    ``@once_differentiable`` (already imported above but previously unused),
    since the custom backward is not itself differentiable.
    """

    @staticmethod
    def forward(ctx, X, lambda_val, pad_val):
        ctx.lambda_val = lambda_val
        # Save the input shape so backward can size the gradient tensor.
        ctx.input_shape = X.size()
        Y = _C.swap_align2nat_forward(X, lambda_val, pad_val)
        return Y

    @staticmethod
    @once_differentiable
    def backward(ctx, gY):
        lambda_val = ctx.lambda_val
        bs, ch, h, w = ctx.input_shape
        gX = _C.swap_align2nat_backward(gY, lambda_val, bs, ch, h, w)
        # No gradients for the non-tensor args lambda_val and pad_val.
        return gX, None, None
| swap_align2nat = _SwapAlign2Nat.apply | |
| class SwapAlign2Nat(nn.Module): | |
| """ | |
| The op `SwapAlign2Nat` described in https://arxiv.org/abs/1903.12174. | |
| Given an input tensor that predicts masks of shape (N, C=VxU, H, W), | |
| apply the op, it will return masks of shape (N, V'xU', H', W') where | |
| the unit lengths of (V, U) and (H, W) are swapped, and the mask representation | |
| is transformed from aligned to natural. | |
| Args: | |
| lambda_val (int): the relative unit length ratio between (V, U) and (H, W), | |
| as we always have larger unit lengths for (V, U) than (H, W), | |
| lambda_val is always >= 1. | |
| pad_val (float): padding value for the values falling outside of the input | |
| tensor, default set to -6 as sigmoid(-6) is ~0, indicating | |
| that is no masks outside of the tensor. | |
| """ | |
| def __init__(self, lambda_val, pad_val=-6.0): | |
| super(SwapAlign2Nat, self).__init__() | |
| self.lambda_val = lambda_val | |
| self.pad_val = pad_val | |
| def forward(self, X): | |
| return swap_align2nat(X, self.lambda_val, self.pad_val) | |
| def __repr__(self): | |
| tmpstr = self.__class__.__name__ + "(" | |
| tmpstr += "lambda_val=" + str(self.lambda_val) | |
| tmpstr += ", pad_val=" + str(self.pad_val) | |
| tmpstr += ")" | |
| return tmpstr | |