| | """ Padding Helpers |
| | |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | import math |
| | from typing import List, Tuple |
| |
|
| | import torch.nn.functional as F |


# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding
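

# Worked example (illustrative, not from the original source): this is the
# PyTorch-style static approximation of TF 'SAME' padding.
#   get_padding(3)              -> 1  (3x3 conv, stride 1, size-preserving)
#   get_padding(3, dilation=2)  -> 2  (dilated 3x3 spans 5 inputs)
#   get_padding(5, stride=2)    -> 2  (roughly halves spatial size)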


# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
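

# Worked example (illustrative, not from the original source): this is the
# per-dimension total padding needed so a conv produces ceil(x / s) outputs,
# e.g. a 7x7/2 stem conv on a 224px input:
#   get_same_padding(224, 7, 2, 1)
#   = max((ceil(224 / 2) - 1) * 2 + (7 - 1) * 1 + 1 - 224, 0) = 229 - 224 = 5
# i.e. 5 total, split 2 left / 3 right by pad_same() below.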


# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
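

# Illustrative values (not from the original source): SAME padding can only be
# baked in at layer construction when it is input-size independent and splits
# evenly across both sides.
#   is_static_pad(3)            -> True   (pad 1 on each side)
#   is_static_pad(2)            -> False  (odd total pad, must be asymmetric)
#   is_static_pad(3, stride=2)  -> False  (pad depends on input size)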


# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    ih, iw = x.size()[-2:]
    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x
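

# Worked example (illustrative, not from the original source): for a 7x7 input
# and a 3x3/2 conv, get_same_padding(7, 3, 2, 1) == 2 per dim, so pad_same()
# pads (left=1, right=1, top=1, bottom=1) -> 9x9, and the conv then yields
# ceil(7 / 2) == 4 outputs per dim, matching TF 'SAME' behaviour.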


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding value will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic
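

# Minimal smoke test (illustrative sketch, not part of the original module);
# assumes only the helpers above and a working torch install.
if __name__ == '__main__':
    import torch

    # Static SAME: 3x3 at stride 1 resolves to fixed symmetric padding of 1.
    assert get_padding_value('same', 3) == (1, False)
    # Stride 2 makes SAME input-dependent, so padding stays 0 and dynamic=True.
    assert get_padding_value('same', 3, stride=2) == (0, True)
    # 'valid' is plain zero padding.
    assert get_padding_value('valid', 3) == (0, False)

    # Dynamic path: pad a 7x7 map for a 3x3/2 conv -> 9x9, output ceil(7/2)=4.
    x = torch.randn(1, 1, 7, 7)
    assert pad_same(x, k=[3, 3], s=[2, 2]).shape == (1, 1, 9, 9)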