| import os |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import torch.nn.functional as F |
|
|
|
|
def DCT_mat(size):
    """Build the ``size`` x ``size`` orthonormal DCT-II basis matrix.

    Row ``i`` holds the i-th cosine basis vector sampled at half-integer
    points, scaled so the matrix is orthonormal. Returned as nested lists.
    """
    matrix = []
    for i in range(size):
        scale = np.sqrt(1. / size) if i == 0 else np.sqrt(2. / size)
        row = [scale * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)]
        matrix.append(row)
    return matrix
|
|
def generate_filter(start, end, size):
    """Binary band-pass mask for the DCT plane.

    Entry ``(i, j)`` is ``1.`` when the anti-diagonal index ``i + j`` lies in
    the closed band ``[start, end]`` and ``0.`` otherwise.
    """
    return [[1. if start <= i + j <= end else 0. for j in range(size)]
            for i in range(size)]
|
|
def norm_sigma(x):
    """Squash ``x`` into the open interval (-1, 1).

    Rescales a sigmoid to be zero-centred; algebraically equal to
    ``tanh(x / 2)``.
    """
    return torch.sigmoid(x).mul(2.).sub(1.)
|
|
class Filter(nn.Module):
    """Fixed DCT-band mask, optionally perturbed by a learnable residual.

    The mask passes coefficients whose anti-diagonal index falls inside
    ``[band_start, band_end]``. With ``use_learnable`` a trainable offset
    squashed to (-1, 1) is added to the fixed mask; with ``norm`` the masked
    output is divided by the number of pass-band entries.
    """

    def __init__(self, size, band_start, band_end, use_learnable=False, norm=False):
        super(Filter, self).__init__()
        self.use_learnable = use_learnable

        # Fixed binary mask, registered as a frozen parameter.
        mask = torch.tensor(generate_filter(band_start, band_end, size))
        self.base = nn.Parameter(mask, requires_grad=False)
        if self.use_learnable:
            self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
            # Re-draw the init with a tighter spread than the default randn.
            self.learnable.data.normal_(0., 0.1)
        self.norm = norm
        if norm:
            # Count of pass-band coefficients, used to normalise the response.
            self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)

    def forward(self, x):
        """Apply the (possibly learnable, possibly normalised) mask to ``x``."""
        if self.use_learnable:
            filt = self.base + norm_sigma(self.learnable)
        else:
            filt = self.base
        if self.norm:
            return x * filt / self.ft_num
        return x * filt
|
|
class DCT_base_Rec_Module(nn.Module):
    """Patch-wise DCT analysis and extreme-window reconstruction.

    Splits an image into overlapping windows, applies a 2-D DCT to each
    window, scores every window with a frequency-graded energy measure, and
    returns the two lowest- and two highest-scoring windows folded back to
    the pixel domain.

    Args:
        x: [C, H, W] -> four tensors, each [C*level, window_size, window_size]
    """
    def __init__(self, window_size=32, stride=16, output=256, grade_N=6, level_fliter=[0]):
        # NOTE(review): `level_fliter=[0]` is a mutable default argument shared
        # across calls — harmless while it is only read, but worth confirming.
        super().__init__()

        assert output % window_size == 0
        assert len(level_fliter) > 0

        self.window_size = window_size
        self.grade_N = grade_N
        self.level_N = len(level_fliter)
        # Window count of a non-overlapping tiling of an `output`-sized image;
        # caps how many top/bottom windows are sorted out in forward().
        self.N = (output // window_size) * (output // window_size)

        # Frozen forward and transposed DCT bases for the per-window transform.
        self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False)
        self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False)

        # Extracts overlapping window_size x window_size patches.
        self.unfold = nn.Unfold(
            kernel_size=(window_size, window_size), stride=stride
        )
        # Folds one flattened window (a single block) back to spatial layout.
        self.fold0 = nn.Fold(
            output_size=(window_size, window_size),
            kernel_size=(window_size, window_size),
            stride=window_size
        )

        # NOTE(review): `lm` and `mh` are unused — presumably leftover band-split
        # constants from an earlier multi-level variant.
        lm, mh = 2.82, 2
        # Single all-pass level filter: band [0, 2*window_size] spans every
        # anti-diagonal of the DCT plane.
        level_f = [
            Filter(window_size, 0, window_size * 2)
        ]

        self.level_filters = nn.ModuleList([level_f[i] for i in level_fliter])
        # grade_N contiguous, normalised frequency bands covering the full
        # spectrum in increasing-frequency order.
        self.grade_filters = nn.ModuleList([Filter(window_size, window_size * 2. / grade_N * i, window_size * 2. / grade_N * (i+1), norm=True) for i in range(grade_N)])


    def forward(self, x):
        """Score all windows of ``x`` and return its two frequency-extreme
        minimum and maximum windows, each reconstructed to pixel space."""
        N = self.N
        grade_N = self.grade_N
        level_N = self.level_N
        window_size = self.window_size
        # NOTE(review): unpacked as C, W, H while the class docstring says
        # [C, H, W] — harmless here since W/H are unused, but worth confirming.
        C, W, H = x.shape
        # Unfold expects a batch dim; add and drop it around the call.
        x_unfold = self.unfold(x.unsqueeze(0)).squeeze(0)

        # L = number of extracted windows; reshape each flattened column back
        # to a [C, window_size, window_size] patch.
        _, L = x_unfold.shape
        x_unfold = x_unfold.transpose(0, 1).reshape(L, C, window_size, window_size)
        # 2-D DCT of every patch: D @ x @ D^T.
        x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T

        # Band-filter each patch in the DCT domain and invert the transform.
        y_list = []
        for i in range(self.level_N):
            x_pass = self.level_filters[i](x_dct)
            y = self._DCT_patch_T @ x_pass @ self._DCT_patch
            y_list.append(y)
        # Concatenate levels along the channel axis: [L, level_N*C, ws, ws].
        level_x_unfold = torch.cat(y_list, dim=1)

        # Per-window score: log-magnitude energy in each of grade_N bands,
        # weighted geometrically (w = 1, 2, 4, ...) so higher-frequency bands
        # dominate the grade.
        grade = torch.zeros(L).to(x.device)
        w, k = 1, 2
        for _ in range(grade_N):
            _x = torch.abs(x_dct)
            _x = torch.log(_x + 1)
            _x = self.grade_filters[_](_x)
            _x = torch.sum(_x, dim=[1,2,3])
            grade += w * _x
            w *= k

        # Sort windows by grade (ascending); take the top/bottom extremes.
        _, idx = torch.sort(grade)
        max_idx = torch.flip(idx, dims=[0])[:N]
        # maxmax_idx / minmin_idx are 0-dim index tensors fed to index_select
        # below — appears to rely on PyTorch accepting a 0-dim index there;
        # TODO confirm against the installed torch version.
        maxmax_idx = max_idx[0]
        # Fall back to the same window when only one exists.
        if len(max_idx) == 1:
            maxmax_idx1 = max_idx[0]
        else:
            maxmax_idx1 = max_idx[1]

        min_idx = idx[:N]
        minmin_idx = idx[0]
        if len(min_idx) == 1:
            minmin_idx1 = idx[0]
        else:
            minmin_idx1 = idx[1]

        # Pick the four extreme windows from the level-filtered patches.
        x_minmin = torch.index_select(level_x_unfold, 0, minmin_idx)
        x_maxmax = torch.index_select(level_x_unfold, 0, maxmax_idx)
        x_minmin1 = torch.index_select(level_x_unfold, 0, minmin_idx1)
        x_maxmax1 = torch.index_select(level_x_unfold, 0, maxmax_idx1)

        # Flatten each selected window to a single fold column: shape
        # [level_N*C*ws*ws, 1], i.e. one block for fold0.
        x_minmin = x_minmin.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
        x_maxmax = x_maxmax.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
        x_minmin1 = x_minmin1.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)
        x_maxmax1 = x_maxmax1.reshape(1, level_N*C*window_size* window_size).transpose(0, 1)

        # Fold back to [level_N*C, window_size, window_size].
        x_minmin = self.fold0(x_minmin)
        x_maxmax = self.fold0(x_maxmax)
        x_minmin1 = self.fold0(x_minmin1)
        x_maxmax1 = self.fold0(x_maxmax1)

        # Lowest, highest, second-lowest, second-highest graded windows.
        return x_minmin, x_maxmax, x_minmin1, x_maxmax1
|
|
|
|
| |
|
|
|
|
|
|