# hugaagg's picture
# Upload folder using huggingface_hub
# 2ecc7ab verified
## Restormer: Efficient Transformer for High-Resolution Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
## https://arxiv.org/abs/2111.09881
##--------------------------------------------------------------
##------- Demo file to test Restormer on your own images---------
## Example usage on a directory containing several images: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/' --result_dir './demo/restored/'
## Example usage on an image directly: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/'
## Example usage with tile option on a large image: python demo.py --task Single_Image_Defocus_Deblurring --input_dir './demo/degraded/portrait.jpg' --result_dir './demo/restored/' --tile 720 --tile_overlap 32
##--------------------------------------------------------------
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import os
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
from pdb import set_trace as stx
import numpy as np
# Command-line interface: input location, output location, which pretrained
# Restormer task to run, and optional tile-by-tile inference settings.
parser = argparse.ArgumentParser(description='Test Restormer on your own images')
parser.add_argument(
    '--input_dir',
    type=str,
    default='./demo/degraded/',
    help='Directory of input images or path of single image',
)
parser.add_argument(
    '--result_dir',
    type=str,
    default='./demo/restored/',
    help='Directory for restored results',
)
parser.add_argument(
    '--task',
    type=str,
    required=True,
    choices=[
        'Motion_Deblurring',
        'Single_Image_Defocus_Deblurring',
        'Deraining',
        'Real_Denoising',
        'Gaussian_Gray_Denoising',
        'Gaussian_Color_Denoising',
    ],
    help='Task to run',
)
parser.add_argument(
    '--tile',
    type=int,
    default=None,
    help='Tile size (e.g 720). None means testing on the original resolution image',
)
parser.add_argument(
    '--tile_overlap',
    type=int,
    default=32,
    help='Overlapping of different tiles',
)
args = parser.parse_args()
def load_img(filepath):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        filepath: Path of the image to read.

    Returns:
        The decoded image as a NumPy array with channels converted BGR -> RGB.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image: {filepath}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def save_img(filepath, img):
    """Write an RGB image array to disk via OpenCV (converting RGB -> BGR first).

    Args:
        filepath: Destination path; OpenCV picks the format from its extension.
        img: Image array in RGB channel order.

    Raises:
        OSError: If cv2.imwrite reports that the image could not be written.
    """
    # cv2.imwrite returns False on failure instead of raising; don't drop it silently.
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write image: {filepath}")
def load_gray_img(filepath):
    """Read an image as single-channel grayscale with a trailing channel axis.

    Args:
        filepath: Path of the image to read.

    Returns:
        A NumPy array of shape (H, W, 1).

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image: {filepath}")
    # Add the channel axis so the result can be permuted to CHW like RGB inputs.
    return np.expand_dims(img, axis=2)
def save_gray_img(filepath, img):
    """Write a grayscale image array to disk via OpenCV without color conversion.

    Args:
        filepath: Destination path; OpenCV picks the format from its extension.
        img: Grayscale image array.

    Raises:
        OSError: If cv2.imwrite reports that the image could not be written.
    """
    # cv2.imwrite returns False on failure instead of raising; don't drop it silently.
    if not cv2.imwrite(filepath, img):
        raise OSError(f"Could not write image: {filepath}")
# task -> (subdirectory containing pretrained_models/, checkpoint file name,
#          architecture-parameter overrides required by that checkpoint)
_TASK_CONFIGS = {
    'Motion_Deblurring': ('Motion_Deblurring', 'motion_deblurring.pth', {}),
    'Single_Image_Defocus_Deblurring': ('Defocus_Deblurring', 'single_image_defocus_deblurring.pth', {}),
    'Deraining': ('Deraining', 'deraining.pth', {}),
    'Real_Denoising': ('Denoising', 'real_denoising.pth', {'LayerNorm_type': 'BiasFree'}),
    'Gaussian_Color_Denoising': ('Denoising', 'gaussian_color_denoising_blind.pth',
                                 {'LayerNorm_type': 'BiasFree'}),
    'Gaussian_Gray_Denoising': ('Denoising', 'gaussian_gray_denoising_blind.pth',
                                {'inp_channels': 1, 'out_channels': 1, 'LayerNorm_type': 'BiasFree'}),
}


def get_weights_and_parameters(task, parameters):
    """Return the checkpoint path and model parameters for a restoration task.

    Args:
        task: One of the keys of ``_TASK_CONFIGS`` (the ``--task`` choices).
        parameters: Base keyword arguments for the Restormer constructor. Not
            modified; a new dict with the task-specific overrides is returned.

    Returns:
        A ``(weights, parameters)`` tuple: the relative path of the pretrained
        checkpoint and the merged parameter dict.

    Raises:
        ValueError: If ``task`` is not a known task.
    """
    try:
        subdir, filename, overrides = _TASK_CONFIGS[task]
    except KeyError:
        raise ValueError(f'Unknown task: {task!r}') from None
    weights = os.path.join(subdir, 'pretrained_models', filename)
    # Merge into a fresh dict so the caller's parameters are left untouched.
    return weights, {**parameters, **overrides}
# Resolve the task, the output directory, and the list of input images.
task = args.task
inp_dir = args.input_dir
out_dir = os.path.join(args.result_dir, task)
os.makedirs(out_dir, exist_ok=True)

extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']
# Match on ".ext" so a directory whose name merely ends in e.g. "png" is not
# mistaken for a single image file.
if inp_dir.endswith(tuple('.' + ext for ext in extensions)):
    files = [inp_dir]
else:
    # Collect into a set: where glob matching is case-insensitive (Windows),
    # '*.jpg' and '*.JPG' match the same files and would otherwise be processed twice.
    found = set()
    for ext in extensions:
        found.update(glob(os.path.join(inp_dir, '*.' + ext)))
    files = natsorted(found)

if not files:
    raise FileNotFoundError(f'No files found at {inp_dir}')
# Build the Restormer model for the selected task and load its pretrained weights.
parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8],
              'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
              'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}
weights, parameters = get_weights_and_parameters(task, parameters)

# The architecture is loaded from its source file rather than imported as a package.
load_arch = run_path(os.path.join('basicsr', 'models', 'archs', 'restormer_arch.py'))
model = load_arch['Restormer'](**parameters)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# map_location lets a checkpoint saved from a GPU load on a CPU-only machine,
# which is exactly the fallback `device` selects above.
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint['params'])
model.eval()

# Input height and width are padded up to a multiple of this before inference.
img_multiple_of = 8

print(f"\n ==> Running {task} with weights {weights}\n ")
def _pad_to_multiple(tensor, multiple):
    """Reflect-pad the bottom/right of an NCHW tensor so H and W are multiples of `multiple`."""
    height, width = tensor.shape[2], tensor.shape[3]
    padh = (-height) % multiple  # 0 when already a multiple
    padw = (-width) % multiple
    return F.pad(tensor, (0, padw, 0, padh), 'reflect')


def _tiled_forward(net, input_, tile_size, tile_overlap):
    """Run `net` over overlapping square tiles of `input_` and average the overlaps.

    Args:
        net: Model mapping an (N, C, t, t) patch to an output of the same shape.
        input_: Padded NCHW input tensor.
        tile_size: Requested tile size; clipped to the image height/width.
        tile_overlap: Pixels shared between neighbouring tiles.

    Raises:
        ValueError: If the effective tile size is not a multiple of 8.
    """
    b, c, h, w = input_.shape
    tile = min(tile_size, h, w)
    if tile % 8 != 0:
        raise ValueError("tile size should be multiple of 8")
    stride = tile - tile_overlap
    if stride <= 0:
        # Overlap >= tile (e.g. tile clipped to a small image): a non-positive
        # stride would skip or crash the sweep and leave pixels uncovered (0/0 = NaN).
        # Fall back to non-overlapping tiles.
        stride = tile
    # The final index is appended so the last tile is flush with the border.
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

    # Assumes the model's output has as many channels as its input (true for the
    # parameter sets used here: 3->3 and 1->1).
    accum = input_.new_zeros((b, c, h, w))
    counts = torch.zeros_like(accum)
    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            window = (..., slice(h_idx, h_idx + tile), slice(w_idx, w_idx + tile))
            accum[window].add_(net(input_[window]))
            counts[window].add_(1)
    return accum.div_(counts)


with torch.no_grad():
    for file_ in tqdm(files):
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        is_gray = task == 'Gaussian_Gray_Denoising'
        img = load_gray_img(file_) if is_gray else load_img(file_)
        # HWC uint8 -> 1xCxHxW float in [0, 1]
        input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        # Remember the original size so the padding can be cropped off afterwards.
        height, width = input_.shape[2], input_.shape[3]
        input_ = _pad_to_multiple(input_, img_multiple_of)

        if args.tile is None:
            # Testing on the original resolution image
            restored = model(input_)
        else:
            restored = _tiled_forward(model, input_, args.tile, args.tile_overlap)
        restored = torch.clamp(restored, 0, 1)

        # Unpad the output and convert back to an HWC uint8 array.
        restored = restored[:, :, :height, :width]
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.split(file_)[-1])[0]
        out_path = os.path.join(out_dir, f + '.png')
        if is_gray:
            save_gray_img(out_path, restored)
        else:
            save_img(out_path, restored)

print(f"\nRestored images are saved at {out_dir}")