|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import math |
|
|
|
|
|
# Only DCT2D is exported by ``from <this module> import *``; Learnable_DCT2D and
# Static_DCT2D have to be imported by name.
__all__ = ['DCT2D']


# Per-coefficient DCT statistics consumed by ``frequncy_normalize`` in
# Learnable_DCT2D and Static_DCT2D (each row is subtracted / divided from the
# matching plane of the coefficient tensor).
# - Outer index: colour plane in the order produced by ``rgb2ycbcr``
#   (0 = Y, 1 = Cb, 2 = Cr).
# - Inner index: DCT coefficient of one block in zigzag order; there are 64
#   entries per plane, so the tables only fit kernel_size=8.
# - ``mean`` holds means; ``var`` holds variances (the code divides by
#   ``var ** 0.5``, not by ``var``).
# NOTE(review): the layout matches the (mean_list, var_list) pair returned by
# DCT2D.forward, so these were presumably computed with it over some training
# set — confirm which dataset / preprocessing produced them.
mean=[[932.42657,-0.00260,0.33415,-0.02840,0.00003,-0.02792,-0.00183,0.00006,0.00032,0.03402,-0.00571,0.00020,0.00006,-0.00038,-0.00558,-0.00116,-0.00000,-0.00047,-0.00008,-0.00030,0.00942,0.00161,-0.00009,-0.00006,-0.00014,-0.00035,0.00001,-0.00220,0.00033,-0.00002,-0.00003,-0.00020,0.00007,-0.00000,0.00005,0.00293,-0.00004,0.00006,0.00019,0.00004,0.00006,-0.00015,-0.00002,0.00007,0.00010,-0.00004,0.00008,0.00000,0.00008,-0.00001,0.00015,0.00002,0.00007,0.00003,0.00004,-0.00001,0.00004,-0.00000,0.00002,-0.00000,-0.00008,-0.00000,-0.00003,0.00003],

[962.34735,-0.00428,0.09835,0.00152,-0.00009,0.00312,-0.00141,-0.00001,-0.00013,0.01050,0.00065,0.00006,-0.00000,0.00003,0.00264,0.00000,0.00001,0.00007,-0.00006,0.00003,0.00341,0.00163,0.00004,0.00003,-0.00001,0.00008,-0.00000,0.00090,0.00018,-0.00006,-0.00001,0.00007,-0.00003,-0.00001,0.00006,0.00084,-0.00000,-0.00001,0.00000,0.00004,-0.00001,-0.00002,0.00000,0.00001,0.00002,0.00001,0.00004,0.00011,0.00000,-0.00003,0.00011,-0.00002,0.00001,0.00001,0.00001,0.00001,-0.00007,-0.00003,0.00001,0.00000,0.00001,0.00002,0.00001,0.00000],

[1053.16101,-0.00213,-0.09207,0.00186,0.00013,0.00034,-0.00119,0.00002,0.00011,-0.00984,0.00046,-0.00007,-0.00001,-0.00005,0.00180,0.00042,0.00002,-0.00010,0.00004,0.00003,-0.00301,0.00125,-0.00002,-0.00003,-0.00001,-0.00001,-0.00001,0.00056,0.00021,0.00001,-0.00001,0.00002,-0.00001,-0.00001,0.00005,-0.00070,-0.00002,-0.00002,0.00005,-0.00004,-0.00000,0.00002,-0.00002,0.00001,0.00000,-0.00003,0.00004,0.00007,0.00001,0.00000,0.00013,-0.00000,0.00000,0.00002,-0.00000,-0.00001,-0.00004,-0.00003,0.00000,0.00001,-0.00001,0.00001,-0.00000,0.00000]]


# Variances with the same (plane, zigzag coefficient) layout as ``mean``.
var=[[270372.37500,6287.10645,5974.94043,1653.10889,1463.91748,1832.58997,755.92468,692.41528,648.57184,641.46881,285.79288,301.62100,380.43405,349.84027,374.15891,190.30960,190.76746,221.64578,200.82646,145.87979,126.92046,62.14622,67.75562,102.42001,129.74922,130.04631,103.12189,97.76417,53.17402,54.81048,73.48712,81.04342,69.35100,49.06024,33.96053,37.03279,20.48858,24.94830,33.90822,44.54912,47.56363,40.03160,30.43313,22.63899,26.53739,26.57114,21.84404,17.41557,15.18253,10.69678,11.24111,12.97229,15.08971,15.31646,8.90409,7.44213,6.66096,6.97719,4.17834,3.83882,4.51073,2.36646,2.41363,1.48266],

[18839.21094,321.70932,300.15259,77.47830,76.02293,89.04748,33.99642,34.74807,32.12333,28.19588,12.04675,14.26871,18.45779,16.59588,15.67892,7.37718,8.56312,10.28946,9.41013,6.69090,5.16453,2.55186,3.03073,4.66765,5.85418,5.74644,4.33702,3.66948,1.95107,2.26034,3.06380,3.50705,3.06359,2.19284,1.54454,1.57860,0.97078,1.13941,1.48653,1.89996,1.95544,1.64950,1.24754,0.93677,1.09267,1.09516,0.94163,0.78966,0.72489,0.50841,0.50909,0.55664,0.63111,0.64125,0.38847,0.33378,0.30918,0.33463,0.20875,0.19298,0.21903,0.13380,0.13444,0.09554],

[17127.39844,292.81421,271.45209,66.64056,63.60253,76.35437,28.06587,27.84831,25.96656,23.60370,9.99173,11.34992,14.46955,12.92553,12.69353,5.91537,6.60187,7.90891,7.32825,5.32785,4.29660,2.13459,2.44135,3.66021,4.50335,4.38959,3.34888,2.97181,1.60633,1.77010,2.35118,2.69018,2.38189,1.74596,1.26014,1.31684,0.79327,0.92046,1.17670,1.47609,1.50914,1.28725,0.99898,0.74832,0.85736,0.85800,0.74663,0.63508,0.58748,0.41098,0.41121,0.44663,0.50277,0.51519,0.31729,0.27336,0.25399,0.27241,0.17353,0.16255,0.18440,0.11602,0.11511,0.08450]]
|
|
|
|
|
|
|
|
def _zigzag_permutation(rows: int, cols: int) -> List[int]:
    """Return the row-major flat indices of a ``rows x cols`` grid in zigzag order.

    Anti-diagonals ``i + j = s`` are visited for ``s = 0, 1, ...``. Even
    diagonals are walked with decreasing row index (bottom-left to top-right),
    and odd diagonals with increasing row index (top-right to bottom-left).
    For an 8x8 grid this is the JPEG coefficient scan order.

    Example:
        >>> _zigzag_permutation(2, 3)
        [0, 1, 3, 4, 2, 5]
    """
    order: List[int] = []
    for diag in range(rows + cols - 1):
        # Rows that actually intersect anti-diagonal ``diag`` inside the grid.
        row_span = range(max(0, diag - cols + 1), min(rows - 1, diag) + 1)
        if diag % 2 == 0:
            row_span = reversed(row_span)
        order.extend(row * cols + (diag - row) for row in row_span)
    return order
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dct_kernel_type_2(
    kernel_size: int,
    orthonormal: bool,
    device=None,
    dtype=None,
) -> torch.Tensor:
    """Build the DCT-II basis as a ``(kernel_size, kernel_size)`` real matrix.

    Row ``n`` is the DCT-II of the ``n``-th unit vector, so with
    ``N = kernel_size`` the result is
    ``V[n, k] = c_k * cos(pi * (2n + 1) * k / (2N))``. This is the *transpose*
    of the usual DCT matrix; callers transpose it back as needed.

    The values are obtained via the FFT of the symmetrically extended signal
    ``[x, reversed(x)]``, multiplied by the twiddle ``exp(-i*pi*k / (2N))``,
    keeping the real part.

    Args:
        kernel_size: Transform length ``N``.
        orthonormal: If True, ``c_0 = sqrt(1/N)`` and ``c_k = sqrt(2/N)`` for
            ``k > 0`` (orthonormal basis). If False, every ``c_k = 2``
            (unnormalized).
        device: Device of the returned tensor.
        dtype: Real floating dtype of the returned tensor (``None`` means the
            torch default dtype).

    Returns:
        A contiguous real tensor of shape ``(kernel_size, kernel_size)``.
    """
    eye = torch.eye(kernel_size, device=device, dtype=dtype)
    # Mirror every unit vector so that its FFT contains the cosine terms.
    extended = torch.cat([eye, eye.flip([1])], dim=-1)
    spectrum = torch.fft.fft(extended, dim=-1)[:, :kernel_size]
    # Twiddle exp(-i*pi*k / (2N)) built from a *real* arange times a Python
    # complex scalar. This keeps an explicit real ``dtype`` working, since
    # ``torch.tensor(-1j, dtype=<real>)`` cannot hold a complex value. It also
    # avoids relying on ``torch.pi``, which older PyTorch releases lack.
    freqs = torch.arange(kernel_size, device=eye.device, dtype=eye.dtype)
    twiddle = torch.exp(freqs[None, :] * (-1j * math.pi / (2 * kernel_size)))
    # ``.real`` is a strided view into complex storage; make it a dense copy.
    kernel = (spectrum * twiddle).real.contiguous()
    if orthonormal:
        kernel[:, 0] *= math.sqrt(1 / (kernel_size * 4))
        kernel[:, 1:] *= math.sqrt(1 / (kernel_size * 2))
    return kernel
|
|
|
|
|
|
|
|
def _dct_kernel_type_3(
    kernel_size: int,
    orthonormal: bool,
    device=None,
    dtype=None,
) -> torch.Tensor:
    """Build the DCT-III kernel as the matrix inverse of the DCT-II kernel.

    Args:
        kernel_size: Transform length.
        orthonormal: Forwarded to ``_dct_kernel_type_2``.
        device: Forwarded to ``_dct_kernel_type_2``.
        dtype: Forwarded to ``_dct_kernel_type_2``.

    Returns:
        ``torch.linalg.inv`` of the ``(kernel_size, kernel_size)`` DCT-II kernel.
    """
    forward_basis = _dct_kernel_type_2(kernel_size, orthonormal, device, dtype)
    inverse_basis = torch.linalg.inv(forward_basis)
    return inverse_basis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DCT1D(nn.Module):
    """Fixed 1-D DCT applied along the last dimension of the input.

    Args:
        kernel_size: Length of the transformed (last) dimension.
        kernel_type: 2 selects ``_dct_kernel_type_2``, 3 selects
            ``_dct_kernel_type_3``. The lookup uses the string form of the
            value, so anything else raises ``KeyError``.
        orthonormal: Forwarded to the kernel builder.
        device: Forwarded to the kernel builder.
        dtype: Forwarded to the kernel builder.
    """

    def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        builders = {'2': _dct_kernel_type_2, '3': _dct_kernel_type_3}
        basis = builders[str(kernel_type)](kernel_size, orthonormal, device=device, dtype=dtype)
        # A Parameter with requires_grad=False: it is registered under "weights"
        # (parameters()/state_dict()) but never receives a gradient.
        self.weights = nn.Parameter(basis.T, requires_grad=False)
        # Explicit ``None`` bias so ``forward`` can hand ``self.bias`` to ``linear``.
        self.register_parameter('bias', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ``linear`` computes x @ weights.T, i.e. x @ basis.
        return nn.functional.linear(x, self.weights, self.bias)
|
|
|
|
|
|
|
|
class _DCT2D(nn.Module):
    """Separable 2-D DCT over the two trailing dimensions of the input.

    One ``_DCT1D`` is applied along the last axis. The two trailing axes are
    then swapped, the same transform is applied again, and the axes are
    swapped back, so the output has the input's shape.
    """

    def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        self.transform = _DCT1D(kernel_size, kernel_type, orthonormal, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        along_last = self.transform(x)
        along_both = self.transform(along_last.transpose(-1, -2))
        return along_both.transpose(-1, -2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Learnable_DCT2D(nn.Module):
    r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`
    and projects the coefficients of each YCbCr plane with a learnable 1x1 convolution.

    Pipeline of :meth:`forward`:

    1. undo the ImageNet mean/std normalization and multiply by 255 (assumes
       the original images were in ``[0, 1]``);
    2. convert RGB to YCbCr;
    3. split each plane into non-overlapping ``kernel_size x kernel_size``
       blocks and apply a 2-D DCT to every block;
    4. reorder each block's coefficients in zigzag order and standardize them
       with the module-level ``mean`` / ``var`` tables;
    5. map the ``kernel_size**2`` coefficients of the Y / Cb / Cr planes to
       24 / 4 / 4 channels with 1x1 convolutions and concatenate them.

    Args:
        kernel_size (int): Size of the coefficient kernel (block side length).
            The ``mean`` / ``var`` tables hold 64 coefficients per plane, so
            only ``kernel_size=8`` is compatible with step 4.
        kernel_type (int): Type of the DCT (2 or 3). Default: 2
        orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
        device: Forwarded to the DCT kernel and to the 1x1 convolutions.
        dtype: Forwarded to the DCT kernel and to the 1x1 convolutions.

    Shape:
        - Input: ``(b, 3, h, w)``, assumed to be ImageNet-normalized RGB.
        - Output: ``(b, 32, h // kernel_size, w // kernel_size)``. Rows and
          columns that do not fill a whole block are dropped by the unfold.
    """

    def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = dict(device=device, dtype=dtype)
        super(Learnable_DCT2D, self).__init__()
        self.k = kernel_size
        self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
        self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)
        # The learnable projections honour the requested device/dtype as well,
        # just like the DCT kernel above.
        self.Y_Conv = nn.Conv2d(kernel_size**2, 24, kernel_size=1, padding=0, **factory_kwargs)
        self.Cb_Conv = nn.Conv2d(kernel_size**2, 4, kernel_size=1, padding=0, **factory_kwargs)
        self.Cr_Conv = nn.Conv2d(kernel_size**2, 4, kernel_size=1, padding=0, **factory_kwargs)
        # Plain tensor attributes (not buffers): they are not in state_dict and
        # are moved to the input's device on every call.
        self.mean = torch.tensor(mean, requires_grad=False)
        self.var = torch.tensor(var, requires_grad=False)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
        # Despite the name these act as per-channel standard deviations:
        # ``denormalize`` multiplies by them.
        self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False)

    def denormalize(self, x):
        """Invert ImageNet normalization of a channels-last tensor and scale by 255."""
        x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255
        return x

    def rgb2ycbcr(self, x):
        """Convert a channels-last ``(b, h, w, 3)`` RGB tensor to YCbCr.

        Uses the BT.601 luma weights and a +128 offset on both chroma planes.
        """
        y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114)
        cb = 0.564 * (x[:, :, :, 2] - y) + 128
        cr = 0.713 * (x[:, :, :, 0] - y) + 128
        return torch.stack([y, cb, cr], dim=-1)

    def frequncy_normalize(self, x):
        """Standardize zigzag-ordered DCT coefficients in place, plane by plane.

        Args:
            x: Tensor whose dim 1 indexes the Y/Cb/Cr planes and whose last dim
                holds the zigzag-ordered coefficients (``(N, 3, 64)`` in ``forward``).

        Returns:
            ``x`` itself, modified in place to ``(x - mean) / (sqrt(var) + 1e-8)``.
        """
        # Transfer the statistics once per call instead of once per plane.
        plane_mean = self.mean.to(x.device)
        plane_std = self.var.to(x.device) ** 0.5 + 1e-8
        for plane in range(3):
            x[:, plane].sub_(plane_mean[plane]).div_(plane_std[plane])
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Colour transforms work on channels-last tensors.
        x = x.permute(0, 2, 3, 1)
        x = self.denormalize(x)
        x = self.rgb2ycbcr(x)
        x = x.permute(0, 3, 1, 2)
        # unfold -> (b, c*k*k, L); transpose -> (b, L, c*k*k); the channel axis
        # of unfold is ordered (plane, row, col), so it splits into (c, k, k).
        x = self.unfold(x).transpose(-1, -2)
        x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
        x = self.transform(x)
        x = x.reshape(-1, c, self.k * self.k)
        # Advanced indexing returns a copy, so the in-place normalization
        # below does not touch the transform output.
        x = x[:, :, self.permutation]
        x = self.frequncy_normalize(x)
        # (b, c, k*k, h//k, w//k): coefficients become conv input channels.
        x = x.reshape(b, h // self.k, w // self.k, c, -1)
        x = x.permute(0, 3, 4, 1, 2).contiguous()
        x_Y = self.Y_Conv(x[:, 0])
        x_Cb = self.Cb_Conv(x[:, 1])
        x_Cr = self.Cr_Conv(x[:, 2])
        return torch.cat([x_Y, x_Cb, x_Cr], dim=1)
|
|
|
|
|
class Static_DCT2D(nn.Module):
    r"""Computes the two-dimensional block-wise discrete cosine transform of :attr:`input`
    and keeps a fixed number of the lowest-frequency coefficients per YCbCr plane.

    Pipeline of :meth:`forward`:

    1. undo the ImageNet mean/std normalization and multiply by 255 (assumes
       the original images were in ``[0, 1]``);
    2. convert RGB to YCbCr;
    3. split each plane into non-overlapping ``kernel_size x kernel_size``
       blocks and apply a 2-D DCT to every block;
    4. reorder each block's coefficients in zigzag order and standardize them
       with the module-level ``mean`` / ``var`` tables;
    5. keep the first ``y_channels`` / ``cb_channels`` / ``cr_channels``
       zigzag-ordered coefficients of the Y / Cb / Cr planes and concatenate
       them. This step has no learnable parameters.

    Args:
        kernel_size (int): Size of the coefficient kernel (block side length).
            The ``mean`` / ``var`` tables hold 64 coefficients per plane, so
            only ``kernel_size=8`` is compatible with step 4.
        kernel_type (int): Type of the DCT (2 or 3). Default: 2
        orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
        device: Forwarded to the DCT kernel.
        dtype: Forwarded to the DCT kernel.
        y_channels (int): Leading zigzag coefficients kept for Y. Default: 24
        cb_channels (int): Leading zigzag coefficients kept for Cb. Default: 4
        cr_channels (int): Leading zigzag coefficients kept for Cr. Default: 4

    Raises:
        ValueError: If a channel count is outside ``[0, kernel_size**2]``.

    Shape:
        - Input: ``(b, 3, h, w)``, assumed to be ImageNet-normalized RGB.
        - Output: ``(b, y_channels + cb_channels + cr_channels, h // kernel_size,
          w // kernel_size)``. With the defaults that is 32 channels.
    """

    def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                 device=None, dtype=None, y_channels: int = 24, cb_channels: int = 4,
                 cr_channels: int = 4) -> None:
        factory_kwargs = dict(device=device, dtype=dtype)
        super(Static_DCT2D, self).__init__()
        n_coeffs = kernel_size * kernel_size
        for label, count in (('y_channels', y_channels), ('cb_channels', cb_channels),
                             ('cr_channels', cr_channels)):
            if not 0 <= count <= n_coeffs:
                raise ValueError(f'{label} must be between 0 and {n_coeffs}, got {count}')
        self.k = kernel_size
        self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
        self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)
        self.y_channels = y_channels
        self.cb_channels = cb_channels
        self.cr_channels = cr_channels
        # Plain tensor attributes (not buffers): they are not in state_dict and
        # are moved to the input's device on every call.
        self.mean = torch.tensor(mean, requires_grad=False)
        self.var = torch.tensor(var, requires_grad=False)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], requires_grad=False)
        # Despite the name these act as per-channel standard deviations:
        # ``denormalize`` multiplies by them.
        self.imagenet_var = torch.tensor([0.229, 0.224, 0.225], requires_grad=False)

    def denormalize(self, x):
        """Invert ImageNet normalization of a channels-last tensor and scale by 255."""
        x = x.multiply(self.imagenet_var.to(x.device)).add_(self.imagenet_mean.to(x.device)) * 255
        return x

    def rgb2ycbcr(self, x):
        """Convert a channels-last ``(b, h, w, 3)`` RGB tensor to YCbCr.

        Uses the BT.601 luma weights and a +128 offset on both chroma planes.
        """
        y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114)
        cb = 0.564 * (x[:, :, :, 2] - y) + 128
        cr = 0.713 * (x[:, :, :, 0] - y) + 128
        return torch.stack([y, cb, cr], dim=-1)

    def frequncy_normalize(self, x):
        """Standardize zigzag-ordered DCT coefficients in place, plane by plane.

        Args:
            x: Tensor whose dim 1 indexes the Y/Cb/Cr planes and whose last dim
                holds the zigzag-ordered coefficients (``(N, 3, 64)`` in ``forward``).

        Returns:
            ``x`` itself, modified in place to ``(x - mean) / (sqrt(var) + 1e-8)``.
        """
        plane_mean = self.mean.to(x.device)
        plane_std = self.var.to(x.device) ** 0.5 + 1e-8
        for plane in range(3):
            x[:, plane].sub_(plane_mean[plane]).div_(plane_std[plane])
        return x

    def _select_channels(self, x):
        """Keep the leading zigzag coefficients of each plane.

        Args:
            x: ``(b, 3, k*k, h', w')`` tensor of zigzag-ordered coefficients.

        Returns:
            ``(b, y_channels + cb_channels + cr_channels, h', w')`` tensor.
        """
        return torch.cat([x[:, 0, :self.y_channels],
                          x[:, 1, :self.cb_channels],
                          x[:, 2, :self.cr_channels]], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Colour transforms work on channels-last tensors.
        x = x.permute(0, 2, 3, 1)
        x = self.denormalize(x)
        x = self.rgb2ycbcr(x)
        x = x.permute(0, 3, 1, 2)
        # unfold -> (b, c*k*k, L); transpose -> (b, L, c*k*k); the channel axis
        # of unfold is ordered (plane, row, col), so it splits into (c, k, k).
        x = self.unfold(x).transpose(-1, -2)
        x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
        x = self.transform(x)
        x = x.reshape(-1, c, self.k * self.k)
        x = x[:, :, self.permutation]
        x = self.frequncy_normalize(x)
        # (b, c, k*k, h//k, w//k): coefficients become channels.
        x = x.reshape(b, h // self.k, w // self.k, c, -1)
        x = x.permute(0, 3, 4, 1, 2)
        # Fixed low-frequency selection instead of the undefined learnable
        # projections the previous implementation referenced.
        return self._select_channels(x)
|
|
|
|
|
class DCT2D(nn.Module):
    r"""Measures per-coefficient statistics of the block-wise 2-D DCT of :attr:`input`.

    :meth:`forward` does the following:

    1. splits every channel of the input into non-overlapping
       ``kernel_size x kernel_size`` blocks;
    2. applies a 2-D DCT to each block and reorders the coefficients in zigzag
       order;
    3. returns the mean and the (unbiased) variance of every coefficient of
       every channel, taken over all blocks of the batch.

    It does not return the transformed image itself.

    Args:
        kernel_size (int): Size of the coefficient kernel
        kernel_type (int): Type of the DCT (2 or 3). Default: 2
        orthonormal (bool): A boolean makes the corresponding matrix of coefficients orthonormal. Default: True
        device: Forwarded to the DCT kernel.
        dtype: Forwarded to the DCT kernel.

    Shape:
        - Input: ``(b, c, h, w)``. Rows and columns that do not fill a whole
          block are ignored.
        - Output: tuple ``(mean, var)``. Each is a float32 CPU tensor of shape
          ``(c, kernel_size ** 2)``.
    """

    def __init__(self, kernel_size: int, kernel_type: int = 2, orthonormal: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = dict(device=device, dtype=dtype)
        super(DCT2D, self).__init__()
        self.k = kernel_size
        self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
        self.transform = _DCT2D(kernel_size, kernel_type, orthonormal, **factory_kwargs)
        self.permutation = _zigzag_permutation(kernel_size, kernel_size)

    @staticmethod
    def _coefficient_statistics(coeffs: torch.Tensor):
        """Return (mean, unbiased var) over dim 0 of an ``(N, c, k*k)`` tensor.

        The results are written into freshly allocated float32 CPU tensors of
        shape ``(c, k*k)``. This keeps the previous output device/dtype
        regardless of where ``coeffs`` lives.
        """
        n_channels, n_coeffs = coeffs.shape[1], coeffs.shape[2]
        mean_list = torch.zeros([n_channels, n_coeffs])
        var_list = torch.zeros([n_channels, n_coeffs])
        mean_list[:] = torch.mean(coeffs, dim=0)
        var_list[:] = torch.var(coeffs, dim=0)
        return mean_list, var_list

    def forward(self, x: torch.Tensor):
        b, c, h, w = x.shape
        # unfold -> (b, c*k*k, L); transpose -> (b, L, c*k*k); the channel axis
        # of unfold is ordered (channel, row, col), so it splits into (c, k, k).
        x = self.unfold(x).transpose(-1, -2)
        x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
        x = self.transform(x)
        # One row per block: (b * h//k * w//k, c, k*k), zigzag-ordered.
        x = x.reshape(-1, c, self.k * self.k)
        x = x[:, :, self.permutation]
        # Sized from the input instead of a hard-coded (3, 64), so any channel
        # count and kernel size work.
        return self._coefficient_statistics(x)