|
|
|
|
|
|
|
|
|
|
| import itertools
|
| import math
|
|
|
| import torch
|
|
|
|
|
| class CenterPadding(torch.nn.Module):
|
| def __init__(self, multiple: int):
|
| super().__init__()
|
| self.multiple = multiple
|
|
|
| def _get_pad(self, size):
|
| new_size = math.ceil(size / self.multiple) * self.multiple
|
| pad_size = new_size - size
|
| pad_size_left = pad_size // 2
|
| pad_size_right = pad_size - pad_size_left
|
| return pad_size_left, pad_size_right
|
|
|
|
|
| def forward(self, x):
|
|
|
| pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:-3:-1]))
|
| output = torch.nn.functional.pad(x, pads)
|
| return output
|
|
|
| def __extra_repr__(self) -> str:
|
| return f"multiple={self.multiple}"
|
|
|
|
|
| class StretchToMultiple(torch.nn.Module):
|
| def __init__(self, multiple: int):
|
| super().__init__()
|
| self.multiple = multiple
|
|
|
| def forward(self, x):
|
|
|
| *shape, C, H, W = x.shape
|
| new_H = math.ceil(H / self.multiple) * self.multiple
|
| new_W = math.ceil(W / self.multiple) * self.multiple
|
| if new_H != H or new_W != W:
|
| x = x.reshape(-1, C, H, W)
|
| x = torch.nn.functional.interpolate(x, size=(new_H, new_W), mode="bilinear")
|
| x = x.reshape(*shape, C, new_H, new_W)
|
| return x
|
|
|
| def __extra_repr__(self) -> str:
|
| return f"multiple={self.multiple}"
|
|
|