mimc_rl / util /alignment.py
wangyanhui666's picture
fine tune decoder with mask
9cf79cf
import torch
import torch.nn.functional as F
from numpy import ceil
def cat_k(input):
    """Merge the first two dimensions of ``input`` into a single leading one.

    The result has one dimension fewer than ``input``; its leading size is
    ``input.size(0) * input.size(1)``. ``split_k`` performs the inverse
    reshape.
    """
    merged = input.flatten(start_dim=0, end_dim=1)
    return merged
def split_k(input, size: int, dim: int = 0):
"""reshape input to original batch size"""
if dim < 0:
dim = input.dim() + dim
split_size = list(input.size())
split_size[dim] = size
split_size.insert(dim+1, -1)
return input.view(split_size)
class Alignment(torch.nn.Module):
    """Align image tensors so their spatial size is a multiple of ``divisor``.

    ``align`` grows the last two dimensions (H, W) up to the next multiple of
    ``divisor``, either by padding on the bottom/right (``mode='pad'``) or by
    bilinear resizing (``mode='resize'``). ``resume`` undoes the operation by
    cropping or resizing back to the remembered shape.

    The shape seen by the most recent ``align`` call is cached on the
    instance, so the module is stateful: call ``resume`` after the matching
    ``align``, or pass ``shape`` to ``resume`` explicitly.

    Args:
        divisor: Value the height and width must be a multiple of.
        mode: ``'pad'`` or ``'resize'``.
        padding_mode: Padding mode forwarded to ``F.pad`` when ``mode='pad'``.
    """

    _MODES = ('pad', 'resize')

    def __init__(self, divisor=64., mode='pad', padding_mode='replicate'):
        super().__init__()
        self.divisor = float(divisor)
        self.mode = mode
        self.padding_mode = padding_mode
        # Shape of the last input that actually needed alignment, else None.
        self._tmp_shape = None

    def extra_repr(self):
        """Return the configuration string shown in the module's repr."""
        text = f'divisor={self.divisor}, mode={self.mode}'
        if self.mode == 'pad':
            text += f', padding_mode={self.padding_mode}'
        return text

    def _check_mode(self):
        """Raise ``ValueError`` if ``self.mode`` is not a supported mode."""
        if self.mode not in self._MODES:
            raise ValueError(
                f'unsupported mode {self.mode!r}; expected one of {self._MODES}'
            )

    @staticmethod
    def _resize(input, size):
        """Resize ``input`` to ``size`` using bilinear interpolation."""
        return F.interpolate(input, size, mode='bilinear', align_corners=False)

    def _align(self, input):
        """Pad or resize ``input`` so H and W are multiples of ``divisor``.

        Caches the input shape for ``_resume``; the cache is cleared when the
        input is already aligned and is returned unchanged.

        Raises:
            ValueError: If alignment is needed and ``mode`` is unsupported.
        """
        height, width = input.size()[-2:]
        target_h = int(ceil(height / self.divisor) * self.divisor)
        target_w = int(ceil(width / self.divisor) * self.divisor)
        pad_h, pad_w = target_h - height, target_w - width
        if pad_h == 0 and pad_w == 0:
            self._tmp_shape = None
            return input
        # Fail loudly instead of silently returning None for an unknown mode.
        self._check_mode()
        self._tmp_shape = input.size()
        if self.mode == 'pad':
            # F.pad order is (left, right, top, bottom): pad right and bottom.
            return F.pad(input, (0, pad_w, 0, pad_h), mode=self.padding_mode)
        return self._resize(input, size=(target_h, target_w))

    def _resume(self, input, shape=None):
        """Restore ``input`` to the spatial size it had before alignment.

        Args:
            input: Tensor whose last two dimensions are to be restored.
            shape: Optional shape to restore to; when given it replaces the
                cached shape (and stays cached for later calls).

        Returns:
            ``input`` unchanged when no shape is cached, otherwise the cropped
            (``'pad'``) or resized (``'resize'``) tensor.

        Raises:
            ValueError: If a shape is available and ``mode`` is unsupported.
        """
        if shape is not None:
            self._tmp_shape = shape
        if self._tmp_shape is None:
            return input
        self._check_mode()
        if self.mode == 'pad':
            return input[..., :self._tmp_shape[-2], :self._tmp_shape[-1]]
        return self._resize(input, size=self._tmp_shape[-2:])

    def align(self, input):
        """Align a 4-D tensor, or a 5-D tensor via its merged first two dims.

        Raises:
            ValueError: If ``input`` is neither 4-D nor 5-D.
        """
        ndim = input.dim()
        if ndim == 4:
            return self._align(input)
        if ndim == 5:
            # Fold dim 1 into dim 0, align, then restore the leading size.
            return split_k(self._align(cat_k(input)), input.size(0))
        raise ValueError(f'expected a 4D or 5D input, got {ndim}D')

    def resume(self, input, shape=None):
        """Undo ``align`` for a 4-D or 5-D tensor.

        Raises:
            ValueError: If ``input`` is neither 4-D nor 5-D.
        """
        ndim = input.dim()
        if ndim == 4:
            return self._resume(input, shape)
        if ndim == 5:
            # Fold dim 1 into dim 0, resume, then restore the leading size.
            return split_k(self._resume(cat_k(input), shape), input.size(0))
        raise ValueError(f'expected a 4D or 5D input, got {ndim}D')

    def forward(self, func, *args, **kwargs):
        """Placeholder: not implemented, ignores its arguments, returns None."""
        return None