llir / win_util.py
linxin02's picture
Upload portable Low_light_rainy_new code export
4336727 verified
import torch
import torch.nn.functional as F
import torch.nn as nn
def check_image_size(x, padder_size, mode='reflect'):
    """Pad a 4-D tensor so its last two dimensions are multiples of ``padder_size``.

    Args:
        x: tensor with exactly four dimensions; the last two are treated as
            height and width.
        padder_size: an int (used for both axes) or a pair ``(size_h, size_w)``.
        mode: padding mode forwarded unchanged to ``F.pad``.

    Returns:
        The padded tensor. Padding is only ever added at the end of the height
        and width axes (bottom and right); the start of each axis is untouched.
    """
    height, width = x.size()[2:] if False else x.size()[-2:]
    # Enforce the same 4-D requirement as a four-way unpack of the size.
    _, _, height, width = x.size()

    if isinstance(padder_size, int):
        size_h = size_w = padder_size
    else:
        size_h, size_w = padder_size

    # ``-n % m`` is the distance from n up to the next multiple of m (0 if aligned).
    extra_h = -height % size_h
    extra_w = -width % size_w

    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    return F.pad(x, (0, extra_w, 0, extra_h), mode=mode)
def window_partitions(x, window_size):
    """Cut a feature map into non-overlapping windows.

    Args:
        x: tensor of shape (B, C, H, W). H and W are expected to be exact
            multiples of the window height and width; otherwise the internal
            ``view`` fails.
        window_size: an int (square window) or an indexable pair
            ``[win_h, win_w]``.

    Returns:
        Tensor of shape (num_windows * B, C, win_h, win_w). Windows are ordered
        by batch index, then window row, then window column.
    """
    if isinstance(window_size, int):
        win_h = win_w = window_size
    else:
        win_h, win_w = window_size[0], window_size[1]

    batch, channels, height, width = x.shape
    rows, cols = height // win_h, width // win_w

    # Split H into (rows, win_h) and W into (cols, win_w).
    grid = x.view(batch, channels, rows, win_h, cols, win_w)
    # (B, C, rows, win_h, cols, win_w) -> (B, rows, cols, C, win_h, win_w)
    grid = grid.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Fold batch and window-grid axes into a single leading axis.
    return grid.view(-1, channels, win_h, win_w)
def window_reverses(windows, window_size, H, W):
    """Stitch non-overlapping windows back into a feature map.

    This undoes the layout in which windows are ordered by batch index, then
    window row, then window column.

    Args:
        windows: tensor of shape (num_windows * B, C, win_h, win_w).
        window_size: an int (square window) or an indexable pair
            ``[win_h, win_w]``.
        H (int): height of the reconstructed map.
        W (int): width of the reconstructed map.

    Returns:
        Tensor of shape (B, C, H, W); B is inferred from the number of windows.
    """
    if isinstance(window_size, int):
        win_h = win_w = window_size
    else:
        win_h, win_w = window_size[0], window_size[1]

    channels = windows.shape[1]
    rows, cols = H // win_h, W // win_w

    # Unfold the leading axis into (B, rows, cols); B is inferred via -1.
    grid = windows.view(-1, rows, cols, channels, win_h, win_w)
    # (B, rows, cols, C, win_h, win_w) -> (B, C, rows, win_h, cols, win_w)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(-1, channels, H, W)
def window_partitionx(x, window_size):
    """Partition a map of arbitrary size into square windows.

    The largest top-left region whose sides are multiples of ``window_size`` is
    tiled normally. If the width and/or height leaves a remainder, extra
    windows are taken flush against the right and/or bottom edge (overlapping
    the main region) so that every pixel is covered.

    Args:
        x: tensor of shape (B, C, H, W).
        window_size (int): side length of the square window.

    Returns:
        A pair ``(windows, batch_list)``. ``windows`` concatenates, along
        dim 0, the main windows, then the right-edge windows (if the width is
        ragged), then the bottom-edge windows (if the height is ragged), then
        the bottom-right corner window (if both are ragged). ``batch_list``
        holds the cumulative end offset of each group along dim 0, so it has
        1, 2 or 4 entries.
    """
    _, _, full_h, full_w = x.shape
    fit_h = window_size * (full_h // window_size)
    fit_w = window_size * (full_w // window_size)
    ragged_h = fit_h != full_h
    ragged_w = fit_w != full_w

    # Main region first, then the edge strips in right / bottom / corner order.
    pieces = [window_partitions(x[:, :, :fit_h, :fit_w], window_size)]
    if ragged_w:
        pieces.append(window_partitions(x[:, :, :fit_h, -window_size:], window_size))
    if ragged_h:
        pieces.append(window_partitions(x[:, :, -window_size:, :fit_w], window_size))
    if ragged_h and ragged_w:
        # The corner is already exactly one window per batch element.
        pieces.append(x[:, :, -window_size:, -window_size:])

    # Cumulative end offsets of each group along dim 0.
    offsets = []
    running = 0
    for piece in pieces:
        running += piece.shape[0]
        offsets.append(running)

    if len(pieces) == 1:
        return pieces[0], offsets
    return torch.cat(pieces, dim=0), offsets
def window_reversex(windows, window_size, H, W, batch_list):
    """Rebuild a (B, C, H, W) map from windows laid out as by ``window_partitionx``.

    Args:
        windows: tensor of shape (N, C, window_size, window_size) holding the
            main windows followed by the optional right-edge, bottom-edge and
            corner windows.
        window_size (int): side length of the square window.
        H (int): height of the reconstructed map.
        W (int): width of the reconstructed map.
        batch_list: cumulative end offsets (along dim 0 of ``windows``) of each
            window group; 1, 2 or 4 entries.

    Returns:
        Tensor of shape (B, C, H, W) with the same dtype and device as
        ``windows``. Only the part of each edge window that lies outside the
        main region is copied into the result.
    """
    # Size of the top-left region that is tiled exactly by whole windows.
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
    B, C, _, _ = x_main.shape

    # Allocate the canvas with the dtype/device of ``windows`` so half, double
    # and complex inputs are not silently converted to float32 / complex64.
    res = windows.new_zeros((B, C, H, W))
    res[:, :, :h, :w] = x_main
    if h == H and w == W:
        return res

    if h != H and w != W and len(batch_list) == 4:
        # Bottom-right corner window: keep only its last (H - h) rows and
        # (W - w) columns (negative start indices count from the end).
        x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
        res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
        # Right-edge strip: keep only the last (W - w) columns.
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
        # Bottom-edge strip: keep only the last (H - h) rows.
        x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
        return res

    if w != W and len(batch_list) == 2:
        # Only the width is ragged: the second group is the right-edge strip.
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
    if h != H and len(batch_list) == 2:
        # Only the height is ragged: the second group is the bottom-edge strip.
        x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
    return res
def window_partitionxy(x, window_size, start=(0, 0)):
    """Partition a map into windows whose grid is shifted by ``start``.

    The region ``x[:, :, s_h:, s_w:]`` is partitioned with
    ``window_partitionx``. When the shift is non-zero, three extra groups are
    appended to cover the skipped top/left border: a left strip (first
    ``window_size`` columns), a top strip (first ``window_size`` rows) and the
    top-left corner window.

    Args:
        x: tensor of shape (B, C, H, W).
        window_size (int): side length of the square window.
        start: pair ``(s_h, s_w)`` of shift offsets, each in
            ``[0, window_size)``. Either both are zero or both are non-zero.

    Returns:
        A pair ``(windows, batch_list)`` where ``batch_list`` holds the
        cumulative end offsets of each window group along dim 0. For a non-zero
        shift its last three entries belong to the left strip, the top strip
        and the corner, in that order.

    Raises:
        ValueError: if exactly one of the two offsets is zero. That layout is
            not implemented; previously the function silently returned None.
    """
    s_h, s_w = start
    assert 0 <= s_h < window_size and 0 <= s_w < window_size
    if (s_h == 0) != (s_w == 0):
        raise ValueError(
            "window_partitionxy requires both start offsets to be zero or both "
            "to be non-zero, got start=({}, {})".format(s_h, s_w)
        )

    _, _, H, W = x.shape
    # Largest extents that are tiled exactly by whole windows.
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    x_main, b_main = window_partitionx(x[:, :, s_h:, s_w:], window_size)
    if s_h == 0 and s_w == 0:
        return x_main, b_main

    # Left strip: first window_size columns, bottom-aligned over h rows.
    x_l = window_partitions(x[:, :, -h:, :window_size], window_size)
    b_l = x_l.shape[0] + b_main[-1]
    b_main.append(b_l)
    # Top strip: first window_size rows, right-aligned over w columns.
    x_u = window_partitions(x[:, :, :window_size, -w:], window_size)
    b_u = x_u.shape[0] + b_l
    b_main.append(b_u)
    # Top-left corner: already exactly one window per batch element.
    x_uu = x[:, :, :window_size, :window_size]
    b_uu = x_uu.shape[0] + b_u
    b_main.append(b_uu)
    return torch.cat([x_main, x_l, x_u, x_uu], dim=0), b_main
def window_reversexy(windows, window_size, H, W, batch_list, start=(0, 0)):
    """Rebuild a (B, C, H, W) map from windows laid out as by ``window_partitionxy``.

    Args:
        windows: tensor of shape (N, C, window_size, window_size).
        window_size (int): side length of the square window.
        H (int): height of the reconstructed map.
        W (int): width of the reconstructed map.
        batch_list: cumulative end offsets of each window group along dim 0.
            For a non-zero shift the last three entries belong to the left
            strip, the top strip and the top-left corner, in that order.
        start: pair ``(s_h, s_w)`` of shift offsets, each in
            ``[0, window_size)``. Either both are zero or both are non-zero.

    Returns:
        Tensor of shape (B, C, H, W). For a non-zero shift it has the same
        dtype and device as ``windows``.

    Raises:
        ValueError: if exactly one of the two offsets is zero; the indexing of
            ``batch_list`` below assumes all three border groups are present.
    """
    s_h, s_w = start
    assert 0 <= s_h < window_size and 0 <= s_w < window_size
    if (s_h == 0) != (s_w == 0):
        raise ValueError(
            "window_reversexy requires both start offsets to be zero or both "
            "to be non-zero, got start=({}, {})".format(s_h, s_w)
        )
    if s_h == 0 and s_w == 0:
        return window_reversex(windows, window_size, H, W, batch_list)

    # Largest extents that are tiled exactly by whole windows.
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    # Everything before the last three groups covers the shifted region.
    x_main = window_reversex(windows[:batch_list[-4], ...], window_size, H - s_h, W - s_w, batch_list[:-3])
    B, C, _, _ = x_main.shape

    # Allocate the canvas with the dtype/device of ``windows`` so half, double
    # and complex inputs are preserved instead of being cast to float32.
    res = windows.new_zeros((B, C, H, W))

    # Write order matters: later writes overwrite earlier ones where they
    # overlap, and the shifted main region is written last.
    x_uu = window_reverses(windows[batch_list[-2]:, ...], window_size, window_size, window_size)
    res[:, :, :window_size, :window_size] = x_uu[:, :, :, :]
    x_l = window_reverses(windows[batch_list[-4]:batch_list[-3], ...], window_size, h, window_size)
    res[:, :, -h:, :window_size] = x_l
    x_u = window_reverses(windows[batch_list[-3]:batch_list[-2], ...], window_size, window_size, w)
    res[:, :, :window_size, -w:] = x_u[:, :, :, :]
    res[:, :, s_h:, s_w:] = x_main
    return res
class WindowPartition(nn.Module):
    """Module wrapper that splits a feature map into square windows.

    Args:
        window_size: side length of the window, or None to disable partitioning.
        shift_size: offset applied to both axes of the window grid; a falsy
            value selects the unshifted partition.
    """

    def __init__(self, window_size=8, shift_size=0):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

    def forward(self, x):
        """Partition ``x`` and return ``(windows, batch_list)``.

        The input is returned unchanged together with an empty list when
        partitioning is disabled or when either of the last two dimensions of
        ``x`` is not strictly larger than the window size.
        """
        size = self.window_size
        height, width = x.shape[-2:]

        # Pass-through: nothing to partition.
        if size is None or height <= size or width <= size:
            return x, []

        if self.shift_size:
            shift = [self.shift_size, self.shift_size]
            parts, offsets = window_partitionxy(x, size, shift)
        else:
            parts, offsets = window_partitionx(x, size)
        return parts, offsets
class WindowReverse(nn.Module):
    """Module wrapper that reassembles a feature map from square windows.

    Args:
        window_size: side length of the window, or None to disable the reverse.
        shift_size: offset applied to both axes of the window grid; a falsy
            value selects the unshifted layout.
    """

    def __init__(self, window_size=8, shift_size=0):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

    def forward(self, x, H, W, batch_list):
        """Merge windows ``x`` back into a map of height ``H`` and width ``W``.

        ``x`` is returned unchanged when ``batch_list`` is empty, when the
        window size is None, or when ``H`` or ``W`` is not strictly larger than
        the window size.
        """
        size = self.window_size

        # Pass-through: the input was never partitioned.
        if len(batch_list) == 0 or size is None or H <= size or W <= size:
            return x

        if self.shift_size:
            shift = [self.shift_size, self.shift_size]
            return window_reversexy(x, size, H, W, batch_list, shift)
        return window_reversex(x, size, H, W, batch_list)