import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def DCT_mat(size):
    """Return the ``size`` x ``size`` orthonormal DCT-II matrix as nested lists.

    Row ``i`` is the ``i``-th cosine basis vector, so for a square patch ``P``
    the forward 2-D transform is ``D @ P @ D.T`` and the inverse is
    ``D.T @ X @ D`` (this is how ``DCT_base_Rec_Module`` uses it).
    """
    return [
        [
            (np.sqrt(1. / size) if i == 0 else np.sqrt(2. / size))
            * np.cos((j + 0.5) * np.pi * i / size)
            for j in range(size)
        ]
        for i in range(size)
    ]


def generate_filter(start, end, size):
    """Return a ``size`` x ``size`` 0/1 mask over anti-diagonal bands.

    Entry ``[i][j]`` is ``1.`` when ``start <= i + j <= end`` (both bounds
    inclusive) and ``0.`` otherwise. In this file the mask is applied to DCT
    coefficients, where ``i + j`` serves as a rough frequency index.
    """
    return [[0. if i + j > end or i + j < start else 1. for j in range(size)] for i in range(size)]


def norm_sigma(x):
    """Map ``x`` elementwise into (-1, 1) via ``2 * sigmoid(x) - 1``."""
    return 2. * torch.sigmoid(x) - 1.


class Filter(nn.Module):
    """Elementwise mask built from :func:`generate_filter`.

    Args:
        size: side length of the square mask.
        band_start, band_end: inclusive bounds on ``i + j`` for entries kept
            by the fixed mask.
        use_learnable: if True, add a trainable offset ``norm_sigma(learnable)``
            (values in (-1, 1)) to the fixed mask.
        norm: if True, divide the masked input by the number of ones in the
            fixed mask (``ft_num``). The divisor is clamped to at least 1 so an
            empty band yields zeros instead of NaN/inf.
    """

    def __init__(self, size, band_start, band_end, use_learnable=False, norm=False):
        super().__init__()
        self.use_learnable = use_learnable
        mask = torch.tensor(generate_filter(band_start, band_end, size))
        self.base = nn.Parameter(mask, requires_grad=False)
        if self.use_learnable:
            self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
            # Overwrite the randn draw with a small-std normal init.
            self.learnable.data.normal_(0., 0.1)
        self.norm = norm
        if norm:
            # Number of kept entries in the fixed mask (frozen, not trained).
            self.ft_num = nn.Parameter(torch.sum(mask), requires_grad=False)

    def forward(self, x):
        if self.use_learnable:
            filt = self.base + norm_sigma(self.learnable)
        else:
            filt = self.base
        y = x * filt
        if self.norm:
            # ft_num is a count, so it is >= 1 unless the band is empty; clamping
            # only changes the empty-band case, which would otherwise be 0/0 = NaN.
            y = y / self.ft_num.clamp(min=1.)
        return y


class DCT_base_Rec_Module(nn.Module):
    """Pick the patches with the lowest/highest DCT log-magnitude score.

    ``forward`` takes ``x`` of shape ``[C, H, W]`` and does the following:

    * cuts it into ``window_size`` x ``window_size`` patches with ``stride``;
    * applies a per-patch 2-D DCT;
    * reconstructs each patch through every selected level filter (DCT-domain
      mask followed by the inverse DCT), stacking levels along the channel axis;
    * scores each patch as ``sum_i 2**i * sum(grade_filter_i(log(1 + |DCT|)))``.

    It returns four reconstructed patches ``(x_minmin, x_maxmax, x_minmin1,
    x_maxmax1)``: the lowest, highest, second-lowest and second-highest scoring
    patches. Each has shape ``[C * len(level_fliter), window_size, window_size]``.

    Args:
        window_size: side length of each square patch.
        stride: step of the sliding window.
        output: must be a multiple of ``window_size``. The number of
            candidates considered at each end of the ranking is
            ``(output // window_size) ** 2``. If only one candidate remains,
            the "second" patch falls back to the first.
        grade_N: number of frequency bands used for scoring; band ``i`` is
            weighted by ``2 ** i``.
        level_fliter: indices into the available level filters. Only index 0
            exists: a mask over ``0 <= i + j <= 2 * window_size``, which keeps
            every coefficient.
    """

    def __init__(self, window_size=32, stride=16, output=256, grade_N=6, level_fliter=(0,)):
        super().__init__()
        assert output % window_size == 0, "output must be a multiple of window_size"
        assert len(level_fliter) > 0, "level_fliter must select at least one filter"
        self.window_size = window_size
        self.grade_N = grade_N
        self.level_N = len(level_fliter)
        self.N = (output // window_size) * (output // window_size)

        dct = torch.tensor(DCT_mat(window_size)).float()
        self._DCT_patch = nn.Parameter(dct, requires_grad=False)
        # Transpose an independent copy so the two Parameters do not share storage.
        self._DCT_patch_T = nn.Parameter(torch.transpose(dct.clone(), 0, 1), requires_grad=False)

        self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=stride)
        # Folding a single window_size x window_size block onto an equally sized
        # canvas turns one flattened patch column back into a [C', ws, ws] tensor.
        self.fold0 = nn.Fold(
            output_size=(window_size, window_size),
            kernel_size=(window_size, window_size),
            stride=window_size,
        )

        available_level_filters = [Filter(window_size, 0, window_size * 2)]
        self.level_filters = nn.ModuleList([available_level_filters[i] for i in level_fliter])
        band_width = window_size * 2. / grade_N
        self.grade_filters = nn.ModuleList([
            Filter(window_size, band_width * i, band_width * (i + 1), norm=True)
            for i in range(grade_N)
        ])

    def forward(self, x):
        """Return ``(x_minmin, x_maxmax, x_minmin1, x_maxmax1)`` for ``x`` of shape [C, H, W]."""
        window_size = self.window_size
        channels, _height, _width = x.shape  # unpacking also enforces a 3-D input

        # Unfold gives [C*ws*ws, L] columns; turn them into [L, C, ws, ws] patches.
        columns = self.unfold(x.unsqueeze(0)).squeeze(0)
        num_patches = columns.shape[1]
        patches = columns.transpose(0, 1).reshape(num_patches, channels, window_size, window_size)
        x_dct = self._DCT_patch @ patches @ self._DCT_patch_T

        # Per-level reconstruction (mask in DCT domain, inverse DCT), stacked on channels.
        level_x_unfold = torch.cat(
            [self._DCT_patch_T @ level_filter(x_dct) @ self._DCT_patch
             for level_filter in self.level_filters],
            dim=1,
        )

        # Patch score: band i contributes 2**i times its normalised log-magnitude sum.
        log_magnitude = torch.log(torch.abs(x_dct) + 1)  # loop-invariant, computed once
        grade = torch.zeros(num_patches, device=x.device)
        weight = 1
        for grade_filter in self.grade_filters:
            grade += weight * torch.sum(grade_filter(log_magnitude), dim=[1, 2, 3])
            weight *= 2

        ascending = torch.sort(grade).indices
        descending = torch.flip(ascending, dims=[0])
        max_candidates = descending[:self.N]
        min_candidates = ascending[:self.N]
        maxmax_idx = max_candidates[0]
        maxmax_idx1 = max_candidates[1] if len(max_candidates) > 1 else max_candidates[0]
        minmin_idx = min_candidates[0]
        minmin_idx1 = min_candidates[1] if len(min_candidates) > 1 else min_candidates[0]

        return (
            self._extract_patch(level_x_unfold, minmin_idx),
            self._extract_patch(level_x_unfold, maxmax_idx),
            self._extract_patch(level_x_unfold, minmin_idx1),
            self._extract_patch(level_x_unfold, maxmax_idx1),
        )

    def _extract_patch(self, patches, index):
        """Return patch ``index`` of ``patches`` ([L, C', ws, ws]) as a [C', ws, ws] tensor."""
        selected = torch.index_select(patches, 0, index)
        column = selected.reshape(1, -1).transpose(0, 1)
        return self.fold0(column)