|
|
import cv2
|
|
|
import math
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from os import path as osp
|
|
|
from PIL import Image, ImageDraw
|
|
|
from torch.nn import functional as F
|
|
|
|
|
|
from basicsr.data.transforms import mod_crop
|
|
|
from basicsr.utils import img2tensor, scandir
|
|
|
|
|
|
|
|
|
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            When a folder is given, its entries are collected with
            ``scandir(path, full_path=True)`` and sorted.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read one of the images.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))

    imgs = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None instead of raising;
        # surface a clear error rather than an AttributeError on `.astype`.
        if img is None:
            raise FileNotFoundError(f'Failed to read image: {img_path}')
        imgs.append(img.astype(np.float32) / 255.)

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    # cv2 loads BGR; img2tensor converts to RGB (c, h, w) tensors.
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    return torch.stack(imgs, dim=0)
|
|
|
|
|
|
|
|
|
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Build the list of frame indices for a window of `num_frames` frames.

    The window is centred on `crt_idx`. Positions that fall outside
    ``[0, max_frame_num - 1]`` are mapped back into range according to
    `padding`.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Total number of frames in the sequence
            (a count, so valid indices are 0 .. max_frame_num - 1).
        num_frames (int): Number of frames to read; must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'

            Examples: crt_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # convert the count into the last 0-based index
    half = num_frames // 2
    window_start = crt_idx - half
    window_end = crt_idx + half

    def _below(idx):
        # Map an index before the first frame (idx < 0) back into range.
        if padding == 'replicate':
            return 0
        if padding == 'reflection':
            return -idx
        if padding == 'reflection_circle':
            return window_end - idx
        return num_frames + idx  # 'circle'

    def _above(idx):
        # Map an index past the last frame (idx > last_idx) back into range.
        if padding == 'replicate':
            return last_idx
        if padding == 'reflection':
            return 2 * last_idx - idx
        if padding == 'reflection_circle':
            return window_start - (idx - last_idx)
        return idx - num_frames  # 'circle'

    def _resolve(idx):
        if idx < 0:
            return _below(idx)
        if idx > last_idx:
            return _above(idx)
        return idx

    return [_resolve(idx) for idx in range(window_start, window_end + 1)]
|
|
|
|
|
|
|
|
|
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    ::

        lq.lmdb
        ├── data.mdb
        ├── lock.mdb
        └── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1) image name (with extension),
    2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key; only the final
    extension is removed, so `a.b.png` maps to the key `a.b`. Blank lines
    in meta_info.txt are ignored.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key, sorted by key, mapping
            ``f'{input_key}_path'`` and ``f'{gt_key}_path'`` to that key.

    Raises:
        ValueError: If either folder does not end with `.lmdb`, or the two
            meta_info.txt files list different keys.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_lmdb_keys(folder):
        # The key is the first whitespace-separated token with its extension
        # stripped. splitext (rather than split('.')) keeps dotted names intact.
        lmdb_keys = []
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            for line in fin:
                tokens = line.split()
                if tokens:  # skip blank lines such as a trailing empty line
                    lmdb_keys.append(osp.splitext(tokens[0])[0])
        return lmdb_keys

    input_lmdb_keys = _read_lmdb_keys(input_folder)
    gt_lmdb_keys = _read_lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [{
        f'{input_key}_path': lmdb_key,
        f'{gt_key}_path': lmdb_key
    } for lmdb_key in sorted(input_lmdb_keys)]
|
|
|
|
|
|
|
|
|
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and,
    optionally, the image shape (usually for gt), separated by white space.
    Only the first token of each line is used; blank lines are ignored.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per listed image, in file order, mapping
            ``f'{input_key}_path'`` and ``f'{gt_key}_path'`` to file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    gt_names = []
    with open(meta_info_file, 'r') as fin:
        for line in fin:
            # split() drops the trailing newline, which split(' ') kept for
            # name-only lines; empty lines produce no tokens and are skipped.
            tokens = line.split()
            if tokens:
                gt_names.append(tokens[0])

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name)
        })
    return paths
|
|
|
|
|
|
|
|
|
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    For every entry returned by ``scandir(gt_folder)``, the matching input
    file name is built from ``filename_tmpl`` and the gt basename, keeping the
    gt file's extension.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt file, mapping ``f'{input_key}_path'`` and
            ``f'{gt_key}_path'`` to file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # Set membership keeps the per-file existence check O(1) instead of a list scan.
    input_path_set = set(input_paths)

    paths = []
    for gt_rel_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_rel_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_path_set, (f'{input_name} is not in ' f'{input_key}_paths.')
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_rel_path)
        })
    return paths
|
|
|
|
|
|
|
|
|
def paths_from_folder(folder):
    """List the entries of a folder as paths joined onto the folder.

    Args:
        folder (str): Folder path; its entries are obtained via ``scandir``.

    Returns:
        list[str]: ``osp.join(folder, entry)`` for each entry yielded by
            ``scandir(folder)``, in the same order.
    """
    return [osp.join(folder, entry) for entry in scandir(folder)]
|
|
|
|
|
|
|
|
|
def paths_from_lmdb(folder):
    """Generate lmdb keys from the `meta_info.txt` of an lmdb folder.

    Each meta_info line looks like `baboon.png (120,125,3) 1`; the key is the
    first whitespace-separated token with its final extension removed
    (`a.b.png` -> `a.b`). Blank lines are ignored.

    Args:
        folder (str): Folder path. Must end with `.lmdb`.

    Returns:
        list[str]: The lmdb keys, in file order.

    Raises:
        ValueError: If `folder` does not end with `.lmdb`.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    paths = []
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        for line in fin:
            tokens = line.split()
            if tokens:
                # splitext only strips the extension, unlike split('.'), which
                # truncated names containing dots.
                paths.append(osp.splitext(tokens[0])[0])
    return paths
|
|
|
|
|
|
|
|
|
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    The kernel is obtained by Gaussian-filtering a unit impulse placed at the
    center of a ``kernel_size`` x ``kernel_size`` zero array.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel, shape (kernel_size, kernel_size).
    """
    # Import from scipy.ndimage directly: the `scipy.ndimage.filters`
    # namespace is deprecated. Kept local so scipy is only needed on use.
    from scipy.ndimage import gaussian_filter

    kernel = np.zeros((kernel_size, kernel_size))
    # Unit impulse at the center; blurring it yields the sampled Gaussian.
    kernel[kernel_size // 2, kernel_size // 2] = 1
    return gaussian_filter(kernel, sigma)
|
|
|
|
|
|
|
|
|
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsample with a Gaussian kernel, as in the DUF official code.

    Every (h, w) plane is reflect-padded by ``kernel_size // 2 + 2 * scale``.
    It is then convolved with a ``kernel_size`` x ``kernel_size`` Gaussian
    (sigma ``0.4 * scale``) at stride ``scale``, and finally cropped by 2
    pixels on each side.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
            A 4D tensor of shape (t, c, h, w) is also accepted; the result is
            then returned without the leading batch dimension.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames with the same number of dimensions as
            ``x``. For odd ``kernel_size`` and h, w divisible by ``scale``,
            the spatial size is (h // scale, w // scale).
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = x.ndim == 4
    if squeeze_flag:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # reshape instead of view so non-contiguous inputs (e.g. permuted tensors) also work.
    x = x.reshape(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    x = F.pad(x, (pad, pad, pad, pad), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    # (1, 1, k, k) weight on x's dtype/device; every plane is filtered independently.
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.reshape(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
|
|
|
|
|
|
|
|
|
def brush_stroke_mask(img, color=(255, 255, 255)):
    """Draw random free-form brush strokes onto a PIL image.

    One to three strokes are drawn. Each stroke is a polyline of random
    vertices traced with a random width in [30, 70), plus filled circles of
    that width at every vertex.

    Note:
        Drawing happens on ``img`` in place; the returned object is ``img``
        itself, not a copy.

    Args:
        img (PIL.Image.Image): Image to draw the strokes on.
        color (tuple): Fill color of the strokes. Default: (255, 255, 255).

    Returns:
        PIL.Image.Image: ``img`` with the strokes drawn on it.
    """
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 12
    min_width = 30
    max_width = 70

    def generate_mask(H, W, img=None):
        # Draw 1-3 strokes on `img`, or on a new black RGB image of size W x H.
        average_radius = math.sqrt(H * H + W * W) / 8
        mask = img if img is not None else Image.new('RGB', (W, H), 0)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2 * math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            # PIL's Image.size is (width, height); x coordinates are bounded
            # by the width and y coordinates by the height.
            w, h = mask.size
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius // 2),
                    0, 2 * average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            # Filled circles of the brush width at each vertex smooth the joints.
            for v in vertex:
                draw.ellipse((v[0] - width // 2,
                              v[1] - width // 2,
                              v[0] + width // 2,
                              v[1] + width // 2),
                             fill=color)

        return mask

    width, height = img.size
    return generate_mask(height, width, img)
|
|
|
|
|
|
|
|
|
def random_ff_mask(shape, max_angle=10, max_len=100, max_width=70, times=10):
    """Generate a random free-form mask made of thick polyline strokes.

    Args:
        shape (tuple[int]): Mask shape; only ``shape[0]`` (height) and
            ``shape[1]`` (width) are used.
        max_angle (int): Each segment angle is ``0.01 + randint(max_angle)``
            radians (mirrored as ``2*pi - angle`` on even strokes).
            Default: 10.
        max_len (int): Segment length is ``10 + randint(max_len - 20, max_len)``.
            Default: 100.
        max_width (int): Brush width is ``5 + randint(max_width - 30, max_width)``.
            Default: 70.
        times (int): The number of strokes is ``randint(times - 5, times)``;
            a non-positive draw produces no strokes. Default: 10.

    Returns:
        np.ndarray: float32 array of shape (height, width), 1.0 where a stroke
            was drawn and 0.0 elsewhere.

    Link:
        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
    """
    height = shape[0]
    width = shape[1]
    mask = np.zeros((height, width), np.float32)
    num_strokes = np.random.randint(times - 5, times)
    for stroke_idx in range(num_strokes):
        start_x = np.random.randint(width)   # column
        start_y = np.random.randint(height)  # row
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if stroke_idx % 2 == 0:
                angle = 2 * np.pi - angle
            length = 10 + np.random.randint(max_len - 20, max_len)
            brush_w = 5 + np.random.randint(max_width - 30, max_width)
            end_x = (start_x + length * np.sin(angle)).astype(np.int32)
            end_y = (start_y + length * np.cos(angle)).astype(np.int32)
            # OpenCV points are (x, y) = (column, row). Passing (row, column)
            # transposes the strokes and mis-clips them on non-square masks.
            cv2.line(mask, (int(start_x), int(start_y)), (int(end_x), int(end_y)), 1.0, brush_w)
            start_x, start_y = end_x, end_y
    return mask