"""Block-wise 2-D discrete cosine transform (DCT) modules for image tensors.

Defines DCT-II / DCT-III kernel builders, a JPEG-style zigzag permutation and
modules that turn images into zigzag-ordered, per-block DCT coefficients.
"""
from typing import List
import numpy as np
import torch
import torch.nn as nn
import math

# Only DCT2D is exported by a star-import; Learnable_DCT2D and Static_DCT2D
# must be imported by name.
__all__ = ['DCT2D']

# Per-channel means of the zigzag-ordered 8x8 (64-entry) block-DCT coefficients,
# indexed ``mean[ch][p]`` for YCbCr channel ``ch`` (0=Y, 1=Cb, 2=Cr) and zigzag
# position ``p``. ``frequncy_normalize`` subtracts these from the coefficients
# before dividing by ``sqrt(var) + 1e-8``.
# NOTE(review): presumably gathered with ``DCT2D.forward`` (which returns [3, 64]
# means/variances) over a training set — confirm the data source before editing.
mean = [
    [
        932.42657, -0.00260, 0.33415, -0.02840, 0.00003, -0.02792, -0.00183, 0.00006,
        0.00032, 0.03402, -0.00571, 0.00020, 0.00006, -0.00038, -0.00558, -0.00116,
        -0.00000, -0.00047, -0.00008, -0.00030, 0.00942, 0.00161, -0.00009, -0.00006,
        -0.00014, -0.00035, 0.00001, -0.00220, 0.00033, -0.00002, -0.00003, -0.00020,
        0.00007, -0.00000, 0.00005, 0.00293, -0.00004, 0.00006, 0.00019, 0.00004,
        0.00006, -0.00015, -0.00002, 0.00007, 0.00010, -0.00004, 0.00008, 0.00000,
        0.00008, -0.00001, 0.00015, 0.00002, 0.00007, 0.00003, 0.00004, -0.00001,
        0.00004, -0.00000, 0.00002, -0.00000, -0.00008, -0.00000, -0.00003, 0.00003,
    ],
    [
        962.34735, -0.00428, 0.09835, 0.00152, -0.00009, 0.00312, -0.00141, -0.00001,
        -0.00013, 0.01050, 0.00065, 0.00006, -0.00000, 0.00003, 0.00264, 0.00000,
        0.00001, 0.00007, -0.00006, 0.00003, 0.00341, 0.00163, 0.00004, 0.00003,
        -0.00001, 0.00008, -0.00000, 0.00090, 0.00018, -0.00006, -0.00001, 0.00007,
        -0.00003, -0.00001, 0.00006, 0.00084, -0.00000, -0.00001, 0.00000, 0.00004,
        -0.00001, -0.00002, 0.00000, 0.00001, 0.00002, 0.00001, 0.00004, 0.00011,
        0.00000, -0.00003, 0.00011, -0.00002, 0.00001, 0.00001, 0.00001, 0.00001,
        -0.00007, -0.00003, 0.00001, 0.00000, 0.00001, 0.00002, 0.00001, 0.00000,
    ],
    [
        1053.16101, -0.00213, -0.09207, 0.00186, 0.00013, 0.00034, -0.00119, 0.00002,
        0.00011, -0.00984, 0.00046, -0.00007, -0.00001, -0.00005, 0.00180, 0.00042,
        0.00002, -0.00010, 0.00004, 0.00003, -0.00301, 0.00125, -0.00002, -0.00003,
        -0.00001, -0.00001, -0.00001, 0.00056, 0.00021, 0.00001, -0.00001, 0.00002,
        -0.00001, -0.00001, 0.00005, -0.00070, -0.00002, -0.00002, 0.00005, -0.00004,
        -0.00000, 0.00002, -0.00002, 0.00001, 0.00000, -0.00003, 0.00004, 0.00007,
        0.00001, 0.00000, 0.00013, -0.00000, 0.00000, 0.00002, -0.00000, -0.00001,
        -0.00004, -0.00003, 0.00000, 0.00001, -0.00001, 0.00001, -0.00000, 0.00000,
    ],
]
var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266], [18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554], [17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]] #torch.tensor(var) def _zigzag_permutation(rows: int, cols: int) -> List[int]: idx_matrix = np.arange(0, rows * cols, 1).reshape(rows, cols).tolist() dia = [[] for _ in range(rows + cols - 1)] zigzag = [] for i in range(rows): for j in range(cols): s = i + j if s % 2 == 0: dia[s].insert(0, idx_matrix[i][j]) else: 
dia[s].append(idx_matrix[i][j]) for d in dia: zigzag.extend(d) return zigzag # Kernels def _dct_kernel_type_2( kernel_size: int, orthonormal: bool, device=None, dtype=None, ) -> torch.Tensor: factory_kwargs = dict(device=device, dtype=dtype) x = torch.eye(kernel_size, **factory_kwargs) v = x.clone().contiguous().view(-1, kernel_size) v = torch.cat([v, v.flip([1])], dim=-1) v = torch.fft.fft(v, dim=-1)[:, :kernel_size] try: k = torch.tensor(-1j, **factory_kwargs) * torch.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] except: k = torch.tensor(-1j, **factory_kwargs) * math.pi * torch.arange(kernel_size, **factory_kwargs)[None, :] k = torch.exp(k / (kernel_size * 2)) v = v * k v = v.real if orthonormal: v[:, 0] = v[:, 0] * torch.sqrt(torch.tensor(1 / (kernel_size * 4), **factory_kwargs)) v[:, 1:] = v[:, 1:] * torch.sqrt(torch.tensor(1 / (kernel_size * 2), **factory_kwargs)) v = v.contiguous().view(*x.shape) return v def _dct_kernel_type_3( kernel_size: int, orthonormal: bool, device=None, dtype=None, ) -> torch.Tensor: return torch.linalg.inv(_dct_kernel_type_2(kernel_size, orthonormal, device, dtype)) # Modules class _DCT1D(nn.Module): def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: factory_kwargs = dict(device=device, dtype=dtype) super(_DCT1D, self).__init__() kernel = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3} self.weights = nn.Parameter(kernel[f'{kernel_type}'](kernel_size, orthonormal, **factory_kwargs).T, False) self.register_parameter('bias', None) def forward(self, x: torch.Tensor) -> torch.Tensor: return nn.functional.linear(x, self.weights, self.bias) class _DCT2D(nn.Module): def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: factory_kwargs = dict(device=device, dtype=dtype) super(_DCT2D, self).__init__() self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, **factory_kwargs) def forward(self, 
x: torch.Tensor) -> torch.Tensor: # [..., H, W] @ DCT_Kernel.T -> [..., W, H] @ DCT_Kernel.T -> [..., H, W] return self.transform(self.transform(x).transpose(-1, -2)).transpose(-1, -2) # Discrete Cosine Transforms (DCT) class Learnable_DCT2D(nn.Module): r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`. Args: kernel_size (int): Size of the coefficient kernel kernel_type (int): Type of the DCT (see Notes). Default: 2 orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True """ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: factory_kwargs = dict(device=device, dtype=dtype) super(Learnable_DCT2D, self).__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) self.Y_Conv = nn.Conv2d(kernel_size**2, 24, kernel_size=1, padding=0) self.Cb_Conv = nn.Conv2d(kernel_size**2, 4, kernel_size=1, padding=0) self.Cr_Conv = nn.Conv2d(kernel_size**2, 4, kernel_size=1, padding=0) self.mean = torch.tensor(mean, requires_grad=False) self.var = torch.tensor(var, requires_grad=False) self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) def denormalize(self, x): x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 # denormalize return x def rgb2ycbcr(self, x): y = (x[:,:,:,0] *0.299) + (x[:,:,:,1]* 0.587) + (x[:,:,:,2] * 0.114) #rgb2ycbcr cb = 0.564 * (x[:,:,:,2] - y) + 128 cr = 0.713 * (x[:,:,:,0] - y) + 128 x = torch.stack([y, cb, cr],dim=-1) return x def frequncy_normalize(self, x): x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0]**0.5+1e-8)) x[:, 1, 
].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1]**0.5+1e-8)) x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2]**0.5+1e-8)) return x def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape #b, c, h, w x = x.permute(0, 2, 3, 1)#b, h, w, c x = self.denormalize(x)#b, h, w,c x = self.rgb2ycbcr(x)#b, h, w, c x = x.permute(0, 3, 1, 2)#b, c, h, w x = self.unfold(x).transpose(-1, -2)#b,c, h, w -> b, c*block, blocks x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) x = self.transform(x) x = x.reshape(-1, c, self.k * self.k) x = x[:, :, self.permutation] x = self.frequncy_normalize(x) x = x.reshape(b, h // self.k, w // self.k, c, -1)#? b, block -> b, h, w, c, block x = x.permute(0, 3, 4, 1, 2).contiguous() # b, c, block, h, w x_Y = self.Y_Conv(x[:, 0, ]) x_Cb = self.Cb_Conv(x[:, 1, ]) x_Cr = self.Cr_Conv(x[:, 2, ]) x = torch.cat([x_Y, x_Cb, x_Cr], axis=1) return x class Static_DCT2D(nn.Module): r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`. Args: kernel_size (int): Size of the coefficient kernel kernel_type (int): Type of the DCT (see Notes). Default: 2 orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. 
Default: True """ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: factory_kwargs = dict(device=device, dtype=dtype) super(Static_DCT2D, self).__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) self.mean = torch.tensor(mean, requires_grad=False) self.var = torch.tensor(var, requires_grad=False) self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False) self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False) def denormalize(self, x): x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255 # denormalize return x def rgb2ycbcr(self, x): y = (x[:,:,:,0] *0.299) + (x[:,:,:,1]* 0.587) + (x[:,:,:,2] * 0.114) #rgb2ycbcr cb = 0.564 * (x[:,:,:,2] - y) + 128 cr = 0.713 * (x[:,:,:,0] - y) + 128 x = torch.stack([y, cb, cr],dim=-1) return x def frequncy_normalize(self, x): x[:, 0, ].sub_(self.mean.to(x.device)[0]).div_((self.var.to(x.device)[0]**0.5+1e-8)) x[:, 1, ].sub_(self.mean.to(x.device)[1]).div_((self.var.to(x.device)[1]**0.5+1e-8)) x[:, 2, ].sub_(self.mean.to(x.device)[2]).div_((self.var.to(x.device)[2]**0.5+1e-8)) return x def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape #b, c, h, w x = x.permute(0, 2, 3, 1)#b, h, w, c x = self.denormalize(x)#b, h, w,c x = self.rgb2ycbcr(x)#b, h, w, c x = x.permute(0, 3, 1, 2)#b, c, h, w x = self.unfold(x).transpose(-1, -2)#b,c, h, w -> b, c*block, blocks x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) x = self.transform(x) x = x.reshape(-1, c, self.k * self.k) x = x[:, :, self.permutation] x = self.frequncy_normalize(x) x = x.reshape(b, h // self.k, w // self.k, c, -1)#? 
b, block -> b, h, w, c, block x = x.permute(0, 3, 4, 1, 2).contiguous() # b, c, block, h, w x_Y = self.Y_Conv(x[:, 0, ]) x_Cb = self.Cb_Conv(x[:, 1, ]) x_Cr = self.Cr_Conv(x[:, 2, ]) x = torch.cat([x_Y, x_Cb, x_Cr], axis=1) return x class DCT2D(nn.Module): r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`. Args: kernel_size (int): Size of the coefficient kernel kernel_type (int): Type of the DCT (see Notes). Default: 2 orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True """ def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True, device=None, dtype=None) -> None: factory_kwargs = dict(device=device, dtype=dtype) super(DCT2D, self).__init__() self.k = kernel_size self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size)) self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs) self.permutation = _zigzag_permutation(kernel_size, kernel_size) def forward(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w = x.shape x = self.unfold(x).transpose(-1, -2)#b,c, h, w -> b, c*block, blocks x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k) x = self.transform(x) x = x.reshape(-1, c, self.k * self.k) x = x[:, :, self.permutation] x = x.reshape(b*(h // self.k)*(w // self.k), c, -1)#? b, block -> b, h, w, c, block #torch.max(x[:,0,],axis=0).values.detach().cpu().numpy() mean_list = torch.zeros([3,64]) var_list = torch.zeros([3, 64]) mean_list[0] = torch.mean(x[:, 0, ],axis=0) mean_list[1] = torch.mean(x[:, 1, ], axis=0) mean_list[2] = torch.mean(x[:, 2, ], axis=0) var_list[0] = torch.var(x[:, 0, ],axis=0) var_list[1] = torch.var(x[:, 1, ], axis=0) var_list[2] = torch.var(x[:, 2, ], axis=0) return mean_list, var_list