# hugaagg's picture
# Upload folder using huggingface_hub
# 2ecc7ab verified
import argparse
import subprocess
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
import os
import torch.nn as nn
from utils.dataset_utils import DenoiseTestDataset, DerainDehazeDataset
from utils.val_utils import AverageMeter, compute_psnr_ssim
from utils.image_io import save_image_tensor
from net.model import PromptIR
import lightning.pytorch as pl
import torch.nn.functional as F
class PromptIRModel(pl.LightningModule):
    """Lightning wrapper around the PromptIR network (decoder enabled) with an L1 loss.

    In this script the module is only restored from a checkpoint and run for
    inference; the training hooks are kept so the class matches the one that
    produced the checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.net = PromptIR(decoder=True)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        """Run the degraded input batch ``x`` through PromptIR and return the result."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the L1 loss between the restored and clean patches.

        ``batch`` is unpacked as ``([clean_name, de_id], degrad_patch, clean_patch)``;
        only the two patches are used. This step is independent of ``forward``.
        """
        ([clean_name, de_id], degrad_patch, clean_patch) = batch
        restored = self.net(degrad_patch)
        loss = self.loss_fn(restored, clean_patch)
        # Logged to the configured Lightning logger (TensorBoard by default, if installed).
        self.log("train_loss", loss)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        """Step ``scheduler`` with the current epoch index; ``metric`` is ignored."""
        # The scheduler is stepped with an explicit epoch value rather than
        # relying on its internal counter. (A previously computed, unused
        # ``scheduler.get_lr()`` was dropped; torch warns when it is called.)
        scheduler.step(self.current_epoch)

    def configure_optimizers(self):
        """Return AdamW (lr=2e-4) and a LinearWarmupCosineAnnealingLR schedule.

        The schedule is built with ``warmup_epochs=15`` and ``max_epochs=150``.
        """
        # Imported locally because neither name was imported in this file,
        # so calling this hook raised NameError. They are only needed for training.
        import torch.optim as optim
        # NOTE(review): assumes the scheduler lives in utils.schedulers, as in the
        # PromptIR training script -- confirm against the repository layout.
        from utils.schedulers import LinearWarmupCosineAnnealingLR

        optimizer = optim.AdamW(self.parameters(), lr=2e-4)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_epochs=15, max_epochs=150)
        return [optimizer], [scheduler]
def test_Denoise(net, dataset, sigma=15, output_path=None):
    """Denoise every image in ``dataset`` at noise level ``sigma`` and print PSNR/SSIM.

    Each restored image is written via ``save_image_tensor`` to
    ``<output_path>/denoise/<sigma>/<clean_name>.png``.

    Args:
        net: Restoration model; it is called on each input batch after the
            batch is moved to CUDA.
        dataset: Test dataset exposing ``set_sigma``. Batches are unpacked as
            ``([clean_name], degrad_patch, clean_patch)``.
        sigma: Noise level passed to ``dataset.set_sigma``.
        output_path: Root output directory. When ``None`` (the default), the
            module-level ``testopt.output_path`` parsed in ``__main__`` is used.
    """
    if output_path is None:
        output_path = testopt.output_path
    # os.path.join keeps the layout correct even when output_path has no
    # trailing slash. os.makedirs replaces shelling out to ``mkdir -p``.
    out_dir = os.path.join(output_path, 'denoise', str(sigma))
    os.makedirs(out_dir, exist_ok=True)

    dataset.set_sigma(sigma)
    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)

    psnr = AverageMeter()
    ssim = AverageMeter()

    with torch.no_grad():
        for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader):
            degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()

            restored = net(degrad_patch)
            temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch)

            psnr.update(temp_psnr, N)
            ssim.update(temp_ssim, N)
            # batch_size=1, so the name list holds exactly one entry.
            save_image_tensor(restored, os.path.join(out_dir, clean_name[0] + '.png'))

    print("Denoise sigma=%d: psnr: %.2f, ssim: %.4f" % (sigma, psnr.avg, ssim.avg))
def test_Derain_Dehaze(net, dataset, task="derain", output_path=None):
    """Restore every image of ``task`` in ``dataset`` and print PSNR/SSIM.

    Each restored image is written via ``save_image_tensor`` to
    ``<output_path>/<task>/<degraded_name>.png``.

    Args:
        net: Restoration model; it is called on each input batch after the
            batch is moved to CUDA.
        dataset: Test dataset exposing ``set_dataset``. Batches are unpacked as
            ``([degraded_name], degrad_patch, clean_patch)``.
        task: Key passed to ``dataset.set_dataset``; also used as the output
            sub-directory name.
        output_path: Root output directory. When ``None`` (the default), the
            module-level ``testopt.output_path`` parsed in ``__main__`` is used.
    """
    if output_path is None:
        output_path = testopt.output_path
    # os.path.join keeps the layout correct even when output_path has no
    # trailing slash. os.makedirs replaces shelling out to ``mkdir -p``.
    out_dir = os.path.join(output_path, task)
    os.makedirs(out_dir, exist_ok=True)

    dataset.set_dataset(task)
    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)

    psnr = AverageMeter()
    ssim = AverageMeter()

    with torch.no_grad():
        for ([degraded_name], degrad_patch, clean_patch) in tqdm(testloader):
            degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()

            restored = net(degrad_patch)
            temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch)

            psnr.update(temp_psnr, N)
            ssim.update(temp_ssim, N)
            # batch_size=1, so the name list holds exactly one entry.
            save_image_tensor(restored, os.path.join(out_dir, degraded_name[0] + '.png'))

    print("PSNR: %.2f, SSIM: %.4f" % (psnr.avg, ssim.avg))
def _build_denoise_tests(opt, splits):
    """Create one ``DenoiseTestDataset`` per split directory under ``opt.denoise_path``.

    ``opt.denoise_path`` is rewritten to each split before construction
    (presumably the dataset reads its location from ``opt``). It is left
    pointing at the last split, as before.
    """
    base_path = opt.denoise_path
    tests = []
    for split in splits:
        opt.denoise_path = os.path.join(base_path, split)
        tests.append(DenoiseTestDataset(opt))
    return tests


def _build_derain_set(opt, base_path, split):
    """Point ``opt.derain_path`` at ``base_path/split`` and build a ``DerainDehazeDataset``."""
    opt.derain_path = os.path.join(base_path, split)
    return DerainDehazeDataset(opt, addnoise=False, sigma=15)


def _run_denoise_tests(net, denoise_tests, denoise_splits, sigmas=(15, 25, 50)):
    """Run ``test_Denoise`` on every denoising test set at each noise level in ``sigmas``."""
    for testset, name in zip(denoise_tests, denoise_splits):
        for sigma in sigmas:
            print('Start {} testing Sigma={}...'.format(name, sigma))
            test_Denoise(net, testset, sigma=sigma)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input Parameters
    parser.add_argument('--cuda', type=int, default=0)
    # choices: any other value previously fell through every branch and did nothing.
    parser.add_argument('--mode', type=int, default=0, choices=[0, 1, 2, 3],
                        help='0 for denoise, 1 for derain, 2 for dehaze, 3 for all-in-one')

    parser.add_argument('--denoise_path', type=str, default="test/denoise/", help='save path of test noisy images')
    parser.add_argument('--derain_path', type=str, default="test/derain/", help='save path of test raining images')
    parser.add_argument('--dehaze_path', type=str, default="test/dehaze/", help='save path of test hazy images')
    parser.add_argument('--output_path', type=str, default="output/", help='output save path')
    parser.add_argument('--ckpt_name', type=str, default="model.ckpt", help='checkpoint save path')
    # Assigned at module level on purpose: test_Denoise / test_Derain_Dehaze
    # read the global ``testopt``.
    testopt = parser.parse_args()

    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(testopt.cuda)

    ckpt_path = os.path.join("ckpt", testopt.ckpt_name)

    denoise_splits = ["bsd68/"]
    derain_splits = ["Rain100L/"]

    # Build the denoising sets only for modes that use them, so derain/dehaze-only
    # runs skip constructing them.
    denoise_tests = []
    if testopt.mode in (0, 3):
        denoise_tests = _build_denoise_tests(testopt, denoise_splits)

    print("CKPT name : {}".format(ckpt_path))

    # load_from_checkpoint is a classmethod and returns a new model. Calling it
    # on a fresh instance built an extra, discarded PromptIR, and Lightning 2.x
    # rejects instance calls.
    net = PromptIRModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()

    if testopt.mode == 0:
        _run_denoise_tests(net, denoise_tests, denoise_splits)
    elif testopt.mode == 1:
        print('Start testing rain streak removal...')
        derain_base_path = testopt.derain_path
        for name in derain_splits:
            print('Start testing {} rain streak removal...'.format(name))
            # Previously passed the undefined name ``opt`` (NameError).
            derain_set = _build_derain_set(testopt, derain_base_path, name)
            test_Derain_Dehaze(net, derain_set, task="derain")
    elif testopt.mode == 2:
        print('Start testing SOTS...')
        derain_set = _build_derain_set(testopt, testopt.derain_path, derain_splits[0])
        # Same task key as the all-in-one mode's SOTS run. The former
        # "SOTS_outdoor" key did not match it.
        # NOTE(review): confirm DerainDehazeDataset.set_dataset has no "SOTS_outdoor" key.
        test_Derain_Dehaze(net, derain_set, task="dehaze")
    elif testopt.mode == 3:
        _run_denoise_tests(net, denoise_tests, denoise_splits)

        derain_base_path = testopt.derain_path
        print(derain_splits)
        for name in derain_splits:
            print('Start testing {} rain streak removal...'.format(name))
            derain_set = _build_derain_set(testopt, derain_base_path, name)
            test_Derain_Dehaze(net, derain_set, task="derain")

        print('Start testing SOTS...')
        # Reuses the last derain set; test_Derain_Dehaze calls set_dataset("dehaze") on it.
        test_Derain_Dehaze(net, derain_set, task="dehaze")