| import os |
| import sys |
| from basicsr.models.losses import SWTLoss, SWTLossRGB |
|
|
# Ask CUDA to enumerate devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# The config is read here, ahead of the `import torch` further down, and the
# GPU ids it lists are exported through CUDA_VISIBLE_DEVICES.
from config import Config

opt = Config('training.yml')
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu_id) for gpu_id in opt.GPU)

# Self-launch: when several GPUs are configured and neither RANK nor
# LOCAL_RANK is present in the environment, re-run this very script through
# torch.distributed.run with one process per configured GPU, then exit with
# the launcher's return code.
if not ("RANK" in os.environ or "LOCAL_RANK" in os.environ) and len(opt.GPU) > 1:
    import subprocess

    env = os.environ.copy()
    cmd = [
        sys.executable,
        "-m",
        "torch.distributed.run",
        "--nproc_per_node",
        str(len(opt.GPU)),
        sys.argv[0],
        *sys.argv[1:],
    ]
    launcher = subprocess.run(cmd, env=env)
    sys.exit(launcher.returncode)
|
|
| import torch |
| torch.backends.cudnn.benchmark = True |
| import utils as utils |
| from models.encoder2 import Convres |
| from restormer import ChannelShuffleWithGBPDeep |
| from torchvision.transforms import transforms |
| from PIL import Image |
| from skimage.metrics import peak_signal_noise_ratio |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| import random |
| import time |
| import numpy as np |
| from model.common import VGGLoss |
| from data_RGB import get_training_data, get_validation_data |
| try: |
| from warmup_scheduler import GradualWarmupScheduler |
| except ImportError: |
| GradualWarmupScheduler = None |
| from tqdm import tqdm |
|
|
| |
# Distributed mode is selected by the presence of RANK or LOCAL_RANK in the
# environment.
use_ddp = any(var in os.environ for var in ("RANK", "LOCAL_RANK"))

# Single-process defaults; replaced below when running distributed.
local_rank, world_size, rank = 0, 1, 0

if use_ddp:
    torch.distributed.init_process_group(backend="nccl")
    # Bind this process to the GPU whose index is given by LOCAL_RANK.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
|
|
| |
# Test image folders. The sorted input file names are kept in img_list and
# their count in num_img.
img_path = './dataset/test/input'
targeet_path = './dataset/test/target'
img_list = sorted(os.listdir(img_path))
num_img = len(img_list)
gpus = ','.join(str(gpu_id) for gpu_id in opt.GPU)

# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with one fixed value.
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
contrast_loss = nn.CrossEntropyLoss().cuda()

start_epoch = 1
mode = opt.MODEL.MODE
session = opt.MODEL.SESSION

# Output layout: <SAVE_DIR>/<mode>/results/<session> and
#                <SAVE_DIR>/<mode>/models/<session>
result_dir, model_dir = (
    os.path.join(opt.TRAINING.SAVE_DIR, mode, kind, session)
    for kind in ('results', 'models')
)
utils.mkdir(result_dir)
utils.mkdir(model_dir)

train_dir, val_dir = opt.TRAINING.TRAIN_DIR, opt.TRAINING.VAL_DIR

# Auxiliary loss modules, moved to the current CUDA device.
loss_vgg = VGGLoss().cuda()
swt_loss = SWTLoss(
    wavelet='sym19',
    mode='periodic',
    loss_weight_ll=0.1,
    loss_weight_lh=0.01,
    loss_weight_hl=0.01,
    loss_weight_hh=0.05,
).cuda()
|
|
| |
# Build the network and wrap it for whichever multi-GPU mode is active.
model_G1 = ChannelShuffleWithGBPDeep()
if not use_ddp:
    model_G1 = model_G1.cuda()
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) > 1:
        # Several visible GPUs without DDP: fall back to DataParallel and
        # print a notice pointing the user at the DDP launch commands.
        model_G1 = nn.DataParallel(model_G1, device_ids=device_ids)
        print(
            "\n" + "!" * 60,
            " 当前是 DataParallel,{} 卡时通常会比 2 卡更慢。".format(len(device_ids)),
            " 要让 4 卡比 2 卡快,请用 DDP 启动: bash run_ddp.sh",
            " 或: CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 main.py",
            "!" * 60 + "\n",
            sep="\n",
        )
else:
    # One process per GPU: place the model on this process's device and wrap
    # it in DistributedDataParallel.
    model_G1 = model_G1.cuda(local_rank)
    model_G1 = nn.parallel.DistributedDataParallel(
        model_G1,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )
    if rank == 0:
        per_gpu_batch = opt.OPTIM.BATCH_SIZE
        total_batch = opt.OPTIM.BATCH_SIZE * world_size
        print("\n==> DDP: world_size={}, 每卡 batch={}(总 batch={})\n".format(
            world_size, per_gpu_batch, total_batch))
|
|
# Adam optimiser with a cosine-annealed learning rate that decays from
# LR_INITIAL to LR_MIN over NUM_EPOCHS scheduler steps.
new_lr = opt.OPTIM.LR_INITIAL
optimizer_G1 = optim.Adam(
    params=model_G1.parameters(),
    lr=new_lr,
    betas=(0.9, 0.999),
    eps=1e-8,
)
scheduler_G1 = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer_G1,
    T_max=opt.OPTIM.NUM_EPOCHS,
    eta_min=opt.OPTIM.LR_MIN,
)

if opt.TRAINING.RESUME:
    # NOTE(review): the checkpoint location is a hard-coded absolute path, and
    # only model weights are reloaded here; the epoch counter restarts at 1
    # and optimizer/scheduler state is not restored. Confirm this is intended.
    path_chk_rest = (
        '/media/home/songmeixi_insta360.com/Low_light_rainy_new/'
        'checkpoint_new/Deraining/models/MPRNet/model_200.pth'
    )
    utils.load_checkpointG1(model_G1, path_chk_rest, strict=False)

    start_epoch = 1
    print("start_epoch=", start_epoch)

    if rank == 0:
        resumed_lr = scheduler_G1.get_last_lr()[0]
        print("==> Resuming, current lr: {:.2e}".format(resumed_lr))
|
|
| |
| |
| |
# Pixel-wise L1 reconstruction loss.
ide_loss = nn.L1Loss().cuda()

# Training data: in DDP mode each process reads its own shard through a
# DistributedSampler; otherwise a single shuffling loader is used.
train_dataset = get_training_data(train_dir, {'patch_size': opt.TRAINING.TRAIN_PS})
if not use_ddp:
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=opt.OPTIM.BATCH_SIZE,
        shuffle=True,
        num_workers=16,
        drop_last=False,
        pin_memory=True,
    )
else:
    # 16 workers split across processes, but never fewer than 8 per process.
    n_workers = max(8, 16 // world_size)
    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=opt.OPTIM.BATCH_SIZE,
        shuffle=False,  # ordering is delegated to the sampler
        sampler=train_sampler,
        num_workers=n_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=n_workers > 0,
        prefetch_factor=8 if n_workers > 0 else None,
    )

# Validation loader with a fixed batch size of 16 and no shuffling.
val_dataset = get_validation_data(val_dir, {'patch_size': opt.TRAINING.VAL_PS})
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=8,
    drop_last=False,
    pin_memory=True,
)
|
|
if rank == 0:
    print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS + 1))
    print('===> Loading datasets')

best_psnr = 0
best_epoch = 0
# Built once and reused by every per-epoch evaluation pass.
transform = transforms.ToTensor()

# Main loop. Each epoch: train on train_loader, step the LR scheduler, then
# (rank 0 only) log, write a checkpoint and measure mean PSNR on the test
# images listed in img_list. All ranks meet at a barrier at the end.
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
    if use_ddp:
        # Re-seed the distributed sampler so each epoch uses a new shuffle.
        train_sampler.set_epoch(epoch)
    epoch_start_time = time.time()
    epoch_loss = 0
    model_G1.train()

    for data in tqdm(train_loader, disable=(rank != 0)):
        optimizer_G1.zero_grad(set_to_none=True)
        # Each batch is indexed as (target, input).
        target = data[0].cuda(local_rank, non_blocking=True)
        input_ = data[1].cuda(local_rank, non_blocking=True)

        output = model_G1(input_)

        vgg = loss_vgg(output, target)
        ide = ide_loss(output, target)
        swt = swt_loss(output, target)

        # Total loss: L1 + 0.05 * VGG loss + 0.15 * SWT loss.
        loss = ide + 0.05 * vgg + 0.15 * swt

        loss.backward()
        optimizer_G1.step()
        # epoch_loss is the sum (not the mean) of the per-batch losses.
        epoch_loss += loss.item()

    scheduler_G1.step()
    if rank == 0:
        cur_lr = scheduler_G1.get_last_lr()[0]
        print("------------------------------------------------------------------")
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLR: {:.2e}".format(
            epoch, time.time() - epoch_start_time, epoch_loss, cur_lr))
        print("------------------------------------------------------------------")

    if epoch % 1 == 0 and rank == 0:
        # NOTE(review): in the nn.DataParallel case the wrapper's state_dict
        # is saved, so its keys carry a "module." prefix, unlike the DDP
        # case. Confirm the checkpoint loader accepts both layouts.
        state_to_save = model_G1.module.state_dict() if use_ddp else model_G1.state_dict()
        torch.save({'epoch': epoch,
                    'state_dict_G1': state_to_save,
                    'optimizer_G1': optimizer_G1.state_dict(),
                    'scheduler_G1': scheduler_G1.state_dict(),
                    }, os.path.join(model_dir, 'model_{}.pth'.format(epoch)))
        print("laileao")

        model_G1.eval()
        # Only rank 0 evaluates while the other ranks wait at the barrier
        # below. A forward pass through the DDP wrapper may issue collective
        # operations (e.g. buffer broadcast) that those ranks never join, so
        # the underlying module is called directly instead.
        eval_net = model_G1.module if use_ddp else model_G1

        if num_img == 0:
            # An empty test folder would otherwise raise ZeroDivisionError on
            # rank 0 and leave the remaining ranks stuck at the barrier.
            print("PSNR = n/a (no images found in {})".format(img_path))
        else:
            PSNR = 0
            with torch.no_grad():
                for img in img_list:
                    # Input and target share the same file name.
                    with Image.open(os.path.join(img_path, img)) as src:
                        image = src.convert('RGB')
                    with Image.open(os.path.join(targeet_path, img)) as src:
                        target = src.convert('RGB')
                    image = transform(image).unsqueeze(0).cuda(local_rank)
                    target = transform(target).unsqueeze(0).cuda(local_rank)

                    pre = eval_net(image)

                    p_numpy = pre.squeeze(0).cpu().numpy()
                    label_numpy = target.squeeze(0).cpu().numpy()
                    psnr = peak_signal_noise_ratio(label_numpy, p_numpy, data_range=1)
                    PSNR += psnr
            PSNR = PSNR / num_img
            print("PSNR =", PSNR)

    if use_ddp:
        # Keep all ranks in step while rank 0 saves and evaluates.
        torch.distributed.barrier()