|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torchvision
|
|
|
import torch.backends.cudnn as cudnn
|
|
|
import torch.optim
|
|
|
import os
|
|
|
import sys
|
|
|
import argparse
|
|
|
import time
|
|
|
import dataloader
|
|
|
import model
|
|
|
import Myloss
|
|
|
import numpy as np
|
|
|
from torchvision import transforms
|
|
|
|
|
|
|
|
|
def weights_init(m):
    """DCGAN-style layer initializer, meant for ``net.apply(weights_init)``.

    Conv layers get weights drawn from N(0, 0.02); BatchNorm layers get
    weights drawn from N(1, 0.02) and zero bias. All other layer types are
    left untouched.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(config):
    """Train the Zero-DCE enhancement network on low-light images.

    Builds the network on GPU 0, optionally loads pretrained weights, then
    minimizes a weighted sum of the four Zero-DCE losses (illumination-map
    TV, spatial consistency, color constancy, exposure) with Adam, logging
    the loss and saving checkpoints at configured intervals.

    Args:
        config: argparse.Namespace providing lowlight_images_path, lr,
            weight_decay, grad_clip_norm, num_epochs, train_batch_size,
            num_workers, display_iter, snapshot_iter, snapshots_folder,
            load_pretrain and pretrain_dir.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    DCE_net = model.enhance_net_nopool().cuda()
    DCE_net.apply(weights_init)
    if config.load_pretrain:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir))

    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    # Zero-DCE loss terms; the scalar weights in the loop below balance them.
    L_color = Myloss.L_color()
    L_spa = Myloss.L_spa()
    L_exp = Myloss.L_exp(16, 0.6)  # 16x16 patches, exposure target 0.6
    L_TV = Myloss.L_TV()

    optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    DCE_net.train()

    for epoch in range(config.num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):

            img_lowlight = img_lowlight.cuda()

            # Network returns an intermediate image, the final enhanced
            # image, and the curve-parameter map A (regularized by L_TV).
            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            Loss_TV = 200 * L_TV(A)
            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))
            loss_col = 5 * torch.mean(L_color(enhanced_image))
            loss_exp = 10 * torch.mean(L_exp(enhanced_image))

            loss = Loss_TV + loss_spa + loss_col + loss_exp

            optimizer.zero_grad()
            loss.backward()
            # FIX: clip_grad_norm (no trailing underscore) is deprecated and
            # removed in modern PyTorch; clip_grad_norm_ is the in-place API.
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),
                                           config.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
            if ((iteration + 1) % config.snapshot_iter) == 0:
                # os.path.join handles a snapshots_folder with or without a
                # trailing slash (the original string concat required one).
                torch.save(DCE_net.state_dict(),
                           os.path.join(config.snapshots_folder,
                                        "Epoch" + str(epoch) + '.pth'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Data and optimization hyper-parameters.
    parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--grad_clip_norm', type=float, default=0.1)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--val_batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)

    # Logging / checkpointing cadence (in iterations).
    parser.add_argument('--display_iter', type=int, default=10)
    parser.add_argument('--snapshot_iter', type=int, default=10)
    parser.add_argument('--snapshots_folder', type=str, default="snapshots/")

    # FIX: type=bool is broken with argparse — bool("False") is True, so any
    # explicit value on the command line enabled pretraining. store_true
    # keeps the False default and makes the flag behave as intended.
    parser.add_argument('--load_pretrain', action='store_true', default=False)
    parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth")

    config = parser.parse_args()

    # makedirs(exist_ok=True) creates missing parent directories too and
    # avoids the check-then-create race of exists()+mkdir().
    os.makedirs(config.snapshots_folder, exist_ok=True)

    train(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|