File size: 5,150 Bytes
dd33601 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
def DCT_mat(size):
    """Build the ``size`` x ``size`` orthonormal DCT-II basis matrix.

    Row ``i`` is the ``i``-th cosine basis vector sampled at the pixel
    centres ``j + 0.5``.  Row 0 is scaled by ``sqrt(1/size)`` and every
    other row by ``sqrt(2/size)``.

    Args:
        size: Side length of the square matrix.

    Returns:
        A list of ``size`` rows, each a list of ``size`` NumPy float64 scalars.
    """
    matrix = []
    for freq in range(size):
        scale = np.sqrt(2. / size) if freq else np.sqrt(1. / size)
        row = [scale * np.cos((col + 0.5) * np.pi * freq / size) for col in range(size)]
        matrix.append(row)
    return matrix
def generate_filter(start, end, size):
    """Return a ``size`` x ``size`` 0/1 mask selecting a band of anti-diagonals.

    Entry ``(i, j)`` is ``0.`` when the index sum ``i + j`` lies above ``end``
    or below ``start`` and ``1.`` otherwise, so both bounds are inclusive.

    Args:
        start: Lower bound on ``i + j`` (may be a float).
        end: Upper bound on ``i + j`` (may be a float).
        size: Side length of the mask.

    Returns:
        A list of ``size`` rows, each a list of ``size`` Python floats.
    """
    mask = []
    for row in range(size):
        values = []
        for col in range(size):
            diagonal = row + col
            outside_band = diagonal > end or diagonal < start
            values.append(0. if outside_band else 1.)
        mask.append(values)
    return mask
def norm_sigma(x):
    """Squash tensor ``x`` element-wise into [-1, 1].

    Rescales the logistic sigmoid's (0, 1) output as ``2 * sigmoid(x) - 1``.
    """
    squashed = torch.sigmoid(x)
    return squashed * 2. - 1.
class Filter(nn.Module):
    """Element-wise band mask for square coefficient maps.

    The fixed mask is ``generate_filter(band_start, band_end, size)``: entry
    ``(i, j)`` is 1 when ``band_start <= i + j <= band_end`` and 0 otherwise.

    Args:
        size: Side length of the square mask; the input's trailing dims must
            broadcast against ``(size, size)``.
        band_start: Inclusive lower bound on ``i + j``.
        band_end: Inclusive upper bound on ``i + j``.
        use_learnable: If True, add a trainable ``(size, size)`` offset,
            squashed into [-1, 1] by ``norm_sigma``, to the fixed mask.
        norm: If True, divide the filtered output by the number of in-band
            mask entries (``ft_num``).
    """

    def __init__(self, size, band_start, band_end, use_learnable=False, norm=False):
        super(Filter, self).__init__()
        self.use_learnable = use_learnable
        # Fixed band mask, registered as a Parameter with requires_grad=False.
        self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
        if self.use_learnable:
            self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
            # Re-initialise in place from N(0, 0.1^2), overwriting the randn values.
            self.learnable.data.normal_(0., 0.1)
        self.norm = norm
        if norm:
            # Count of in-band entries.  It is 0 when no index pair (i, j)
            # has i + j inside [band_start, band_end]; forward() guards that.
            self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the (optionally learnable) mask.

        With ``norm=True`` the product is divided by ``ft_num`` clamped to at
        least 1, so an empty band does not divide by zero (which would turn
        the whole output into NaN).  For non-empty bands the divisor is
        exactly ``ft_num``.
        """
        if self.use_learnable:
            filt = self.base + norm_sigma(self.learnable)
        else:
            filt = self.base
        if self.norm:
            y = x * filt / torch.clamp(self.ft_num, min=1.)
        else:
            y = x * filt
        return y
class DCT_base_Rec_Module(nn.Module):
    """Select extreme-graded image patches using a blockwise 2-D DCT.

    The input is cut into ``window_size`` x ``window_size`` sliding patches
    (step ``stride``).  Each patch/channel is transformed with the DCT-II
    matrix ``D`` from ``DCT_mat`` (``D @ P @ D^T``).  It is then filtered by
    every selected level filter and transformed back (``D^T @ Y @ D``).

    Patches are ranked by a grade computed on the DCT coefficients.  There are
    ``grade_N`` anti-diagonal bands, with ``i + j`` in equal closed slices of
    ``[0, 2 * window_size]``.  For each band:

    * the ``log(|coef| + 1)`` values are summed over channels and band entries;
    * the sum is divided by the band's entry count;
    * the result is weighted by ``2 ** band_index``.

    Args:
        window_size: Side length of each square patch.
        stride: Step between neighbouring patches, passed to ``nn.Unfold``.
        output: Must be a multiple of ``window_size``.  It is only used to
            derive ``self.N = (output // window_size) ** 2``.  That value caps
            how many ranked patches are eligible for the second pick.
        grade_N: Number of frequency bands used for grading.
        level_fliter: Non-empty sequence of indices into the level-filter
            list.  That list holds a single pass-all filter, so only index
            ``0`` is valid.

    ``forward`` takes ``x`` of shape ``[C, H, W]`` and returns
    ``(x_minmin, x_maxmax, x_minmin1, x_maxmax1)``.  These are the
    reconstructed patches with the lowest, highest, second-lowest and
    second-highest grade.  Each has shape
    ``[len(level_fliter) * C, window_size, window_size]``.
    """

    def __init__(self, window_size=32, stride=16, output=256, grade_N=6, level_fliter=(0,)):
        super().__init__()
        assert output % window_size == 0
        assert len(level_fliter) > 0
        self.window_size = window_size
        self.grade_N = grade_N
        self.level_N = len(level_fliter)
        self.N = (output // window_size) * (output // window_size)
        dct = DCT_mat(window_size)
        # Frozen DCT matrix and its transpose, built as two independent tensors.
        self._DCT_patch = nn.Parameter(torch.tensor(dct).float(), requires_grad=False)
        self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(dct).float(), 0, 1), requires_grad=False)
        # [1, C, H, W] -> [1, C * ws * ws, L]: one column per sliding patch.
        self.unfold = nn.Unfold(
            kernel_size=(window_size, window_size), stride=stride
        )
        # Lays out a single [C' * ws * ws, 1] column as a [C', ws, ws] patch.
        self.fold0 = nn.Fold(
            output_size=(window_size, window_size),
            kernel_size=(window_size, window_size),
            stride=window_size,
        )
        # i + j never exceeds 2 * (window_size - 1), so this band keeps every entry.
        level_f = [
            Filter(window_size, 0, window_size * 2)
        ]
        self.level_filters = nn.ModuleList([level_f[i] for i in level_fliter])
        self.grade_filters = nn.ModuleList([
            Filter(window_size, window_size * 2. / grade_N * i, window_size * 2. / grade_N * (i + 1), norm=True)
            for i in range(grade_N)
        ])

    def forward(self, x):
        """Return the four extreme-graded reconstructed patches of ``x``.

        Args:
            x: Tensor of shape ``[C, H, W]``.  It is matrix-multiplied with the
                float32 DCT matrices, so its dtype and device must match them.

        Returns:
            Tuple ``(x_minmin, x_maxmax, x_minmin1, x_maxmax1)``, each of shape
            ``[level_N * C, window_size, window_size]``.  When fewer than two
            patches are eligible (a single sliding patch, or ``self.N == 1``),
            the second-lowest/second-highest picks repeat the lowest/highest.

        Raises:
            ValueError: if no patch can be ranked (``self.N == 0`` or the input
                yields no sliding patch).
        """
        window_size = self.window_size
        C, _, _ = x.shape
        patches = self.unfold(x.unsqueeze(0)).squeeze(0)  # [C * ws * ws, L]
        _, L = patches.shape
        patches = patches.transpose(0, 1).reshape(L, C, window_size, window_size)
        # Separable 2-D DCT of every patch and channel.
        x_dct = self._DCT_patch @ patches @ self._DCT_patch_T

        level_x_unfold = self._reconstruct_levels(x_dct)  # [L, level_N * C, ws, ws]
        grade = self._grade_patches(x_dct)  # [L]

        _, order = torch.sort(grade)  # patch indices by ascending grade
        # Only the min(N, L) lowest / highest ranked patches are eligible;
        # with a single eligible patch the "second" pick falls back to the first.
        available = min(self.N, L)
        if available == 0:
            raise ValueError(f"DCT_base_Rec_Module: no patches to rank (N={self.N}, L={L})")
        second = 1 if available > 1 else 0
        picks = (order[0], order[-1], order[second], order[-1 - second])
        return tuple(self._patch_to_image(level_x_unfold, index) for index in picks)

    def _reconstruct_levels(self, x_dct):
        """Filter ``x_dct`` with each level filter, invert the DCT, and stack along dim 1."""
        reconstructions = [
            self._DCT_patch_T @ level_filter(x_dct) @ self._DCT_patch
            for level_filter in self.level_filters
        ]
        return torch.cat(reconstructions, dim=1)

    def _grade_patches(self, x_dct):
        """Score each patch: sum over bands b of 2**b * (band-normalised sum of log(|coef| + 1))."""
        # Loop-invariant: the log-magnitude is the same for every band.
        log_magnitude = torch.log(torch.abs(x_dct) + 1)
        grade = torch.zeros(x_dct.shape[0], device=x_dct.device)
        weight = 1
        for band_filter in self.grade_filters:
            grade += weight * torch.sum(band_filter(log_magnitude), dim=[1, 2, 3])
            weight *= 2
        return grade

    def _patch_to_image(self, level_x_unfold, index):
        """Select patch ``index`` (0-d index tensor) and fold it to ``[level_N * C, ws, ws]``."""
        patch = torch.index_select(level_x_unfold, 0, index)  # [1, level_N * C, ws, ws]
        column = patch.reshape(1, -1).transpose(0, 1)  # [level_N * C * ws * ws, 1]
        return self.fold0(column)
|