import torch
import torch.nn.functional as F
from numpy import ceil


def cat_k(input):
    """Concatenate the second dimension into the batch dimension.

    A tensor of shape ``(B, K, *rest)`` becomes ``(B * K, *rest)``.
    """
    return input.flatten(0, 1)


def split_k(input, size: int, dim: int = 0):
    """Split dimension ``dim`` of ``input`` into ``(size, -1)``.

    With ``dim=0`` and ``size`` equal to the original batch size, this undoes
    :func:`cat_k`: ``(B * K, *rest)`` -> ``(B, K, *rest)``. A negative ``dim``
    is counted from the end.
    """
    if dim < 0:
        dim = input.dim() + dim
    split_size = list(input.size())
    split_size[dim] = size
    split_size.insert(dim + 1, -1)
    return input.view(split_size)


class Alignment(torch.nn.Module):
    """Image alignment for a model's downsampling requirement.

    Adjusts the spatial size (the last two dimensions) of an image tensor so
    that both are multiples of ``divisor``. In ``mode='pad'`` the bottom and
    right edges are padded with ``F.pad(..., mode=padding_mode)``. In
    ``mode='resize'`` the tensor is resized bilinearly. :meth:`resume` maps a
    result back to the original spatial size.

    The pre-alignment shape is stored on the instance by :meth:`align` and
    read by :meth:`resume`. An align/resume pair therefore refers to the most
    recent :meth:`align` call, unless ``shape`` is passed to :meth:`resume`.
    """

    def __init__(self, divisor=64., mode='pad', padding_mode='replicate'):
        super().__init__()
        self.divisor = float(divisor)
        self.mode = mode
        self.padding_mode = padding_mode
        # Size of the last input that actually needed aligning; None when the
        # last input was already aligned (so resume is a no-op).
        self._tmp_shape = None

    def extra_repr(self):
        s = 'divisor={divisor}, mode={mode}'
        if self.mode == 'pad':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    @staticmethod
    def _resize(input, size):
        """Resize the last two dimensions of ``input`` to ``size`` (bilinear)."""
        return F.interpolate(input, size, mode='bilinear', align_corners=False)

    def _align(self, input):
        """Round the last two dims of a 4-D tensor up to multiples of ``divisor``.

        Records the original size in ``self._tmp_shape`` (or clears it when no
        change is needed).

        Raises:
            ValueError: if ``self.mode`` is neither ``'pad'`` nor ``'resize'``
                and the input needs aligning.
        """
        H, W = input.size()[-2:]
        H_ = int(ceil(H / self.divisor) * self.divisor)
        W_ = int(ceil(W / self.divisor) * self.divisor)
        pad_H, pad_W = H_ - H, W_ - W
        if pad_H == pad_W == 0:
            self._tmp_shape = None
            return input

        self._tmp_shape = input.size()
        if self.mode == 'pad':
            # F.pad pads the last dim first: (left, right, top, bottom).
            return F.pad(input, (0, pad_W, 0, pad_H), mode=self.padding_mode)
        if self.mode == 'resize':
            return self._resize(input, size=(H_, W_))
        raise ValueError(
            f"unsupported mode {self.mode!r}; expected 'pad' or 'resize'")

    def _resume(self, input, shape=None):
        """Restore a 4-D tensor to its pre-alignment spatial size.

        If ``shape`` is given, it replaces the stored shape before restoring.
        Only its last two entries are used.

        Raises:
            ValueError: if ``self.mode`` is neither ``'pad'`` nor ``'resize'``
                and there is a shape to restore.
        """
        if shape is not None:
            self._tmp_shape = shape
        if self._tmp_shape is None:
            return input
        if self.mode == 'pad':
            # Padding was added only at the bottom/right, so crop from the origin.
            return input[..., :self._tmp_shape[-2], :self._tmp_shape[-1]]
        if self.mode == 'resize':
            return self._resize(input, size=self._tmp_shape[-2:])
        raise ValueError(
            f"unsupported mode {self.mode!r}; expected 'pad' or 'resize'")

    def align(self, input):
        """Align a 4-D ``(N, C, H, W)`` or 5-D ``(B, K, C, H, W)`` tensor.

        A 5-D input is processed by folding ``K`` into the batch dimension.

        Raises:
            ValueError: if ``input`` is not 4-D or 5-D.
        """
        if input.dim() == 4:
            return self._align(input)
        if input.dim() == 5:
            return split_k(self._align(cat_k(input)), input.size(0))
        raise ValueError(f'expected a 4-D or 5-D tensor, got {input.dim()}-D')

    def resume(self, input, shape=None):
        """Undo :meth:`align` on a 4-D or 5-D tensor.

        Raises:
            ValueError: if ``input`` is not 4-D or 5-D.
        """
        if input.dim() == 4:
            return self._resume(input, shape)
        if input.dim() == 5:
            return split_k(self._resume(cat_k(input), shape), input.size(0))
        raise ValueError(f'expected a 4-D or 5-D tensor, got {input.dim()}-D')

    def forward(self, func, *args, **kwargs):
        # NOTE(review): unimplemented stub; it currently ignores its arguments
        # and returns None. Kept as-is for compatibility with existing callers.
        pass