|
|
import os |
|
|
import math |
|
|
from decimal import Decimal |
|
|
|
|
|
import utility |
|
|
import torch.nn as nn |
|
|
import torch |
|
|
import torch.nn.utils as utils |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# NOTE(review): pins this process to GPU ids "0,1" as an import-time side
# effect; presumably this must run before CUDA is first initialised to have
# any effect — confirm it is still wanted here rather than in a launcher.
os.environ.update(CUDA_VISIBLE_DEVICES="0,1")
|
|
|
|
|
class Trainer:
    """Drive training and evaluation of ``my_model`` with ``my_loss``.

    Holds the train/test loaders, the model, the loss object, an optimizer
    built by ``utility.make_optimizer`` and the checkpoint/logging helper
    ``ckp``. An outer loop is expected to call ``train()`` / ``test()`` until
    ``terminate()`` returns True.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Store collaborators and build the optimizer.

        Args:
            args: parsed options; this constructor reads ``scale``, ``lr`` and
                ``load`` (the other methods read further fields).
            loader: object exposing ``loader_train`` and ``loader_test``.
            my_model: network to train and evaluate.
            my_loss: loss object; called as ``my_loss(sr, hr)`` and also used
                for logging (``step``/``start_log``/``end_log``/``log``).
            ckp: checkpoint and logging helper.
        """
        self.args = args
        self.scale = args.scale
        self.lr = args.lr  # base learning rate; train() decays it per epoch

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        # When resuming a run, reload optimizer state from ckp.dir, using the
        # number of entries already in ckp.log as the epoch.
        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        # Placeholder until train() stores the last entry of the loss log.
        self.error_last = 1e8

    def train(self):
        """Run one training epoch over ``self.loader_train``.

        Each epoch the learning rate is reset to ``args.lr`` halved once per
        100 epochs and written into every optimizer param group. Progress is
        logged every ``args.print_every`` batches. At the end the loss log is
        closed, ``error_last`` is updated and ``optimizer.schedule()`` is
        called.
        """
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1

        # Step decay: halve the base rate every 100 epochs. This overwrites
        # whatever lr is currently stored in the param groups.
        learning_rate = self.lr * (0.5 ** (epoch // 100))
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = learning_rate

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(learning_rate))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()

        # Training always uses the first scale (index 0).
        self.loader_train.dataset.set_scale(0)
        for batch, (lr, edge, hr, _) in enumerate(self.loader_train):
            # NOTE(review): lr / edge / hr are presumably the low-resolution
            # input, an edge map and the high-resolution target — confirm
            # against the dataset implementation.
            lr, edge, hr = self.prepare(lr, edge, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            # Input and edge map are concatenated along dim 1 (presumably the
            # channel axis); the second argument is the scale index.
            sr = self.model(torch.cat((lr, edge), dim=1), 0)
            loss = self.loss(sr, hr)
            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(self.model.parameters(), self.args.gclip)
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
        """Evaluate on every test loader at every scale in ``self.scale``.

        A new row is appended to ``ckp.log``. For each (dataset, scale) pair,
        the per-batch ``utility.calc_psnr`` values are summed into that row and
        divided by the number of batches. The result is logged together with
        the best value and epoch recorded so far. Outputs (plus ground truth
        when ``args.save_gt``) are saved when ``args.save_results`` is set. A
        checkpoint is written unless ``args.test_only`` is set.

        Runs inside ``torch.no_grad()``, so the caller's grad mode is restored
        even if evaluation raises. Previously an exception left autograd
        globally disabled.
        """
        with torch.no_grad():
            epoch = self.optimizer.get_last_epoch()
            self.ckp.write_log('\nEvaluation:')
            self.ckp.add_log(
                torch.zeros(1, len(self.loader_test), len(self.scale))
            )
            self.model.eval()

            timer_test = utility.timer()
            if self.args.save_results:
                self.ckp.begin_background()

            for idx_data, d in enumerate(self.loader_test):
                for idx_scale, scale in enumerate(self.scale):
                    d.dataset.set_scale(idx_scale)
                    for lr, edge, hr, filename in tqdm(d, ncols=80):
                        lr, edge, hr = self.prepare(lr, edge, hr)
                        sr = self.model(torch.cat((lr, edge), dim=1), idx_scale)
                        sr = utility.quantize(sr, self.args.rgb_range)

                        psnr = utility.calc_psnr(
                            sr, hr, scale, self.args.rgb_range, dataset=d
                        )
                        self.ckp.log[-1, idx_data, idx_scale] += psnr

                        save_list = [sr]
                        if self.args.save_gt:
                            save_list.append(hr)

                        if self.args.save_results:
                            # Only filename[0] is used; presumably the test
                            # loader yields one image per batch — confirm.
                            self.ckp.save_results(d, filename[0], save_list, scale)

                    # Turn the summed PSNR into a mean over batches.
                    self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                    # Max over the epoch axis. Assuming ckp.log is a torch
                    # tensor, best[0] holds the values and best[1] holds the
                    # 0-based row indices, hence the "+ 1" when reported.
                    best = self.ckp.log.max(0)
                    self.ckp.write_log(
                        '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                            d.dataset.name,
                            scale,
                            self.ckp.log[-1, idx_data, idx_scale],
                            best[0][idx_data, idx_scale],
                            best[1][idx_data, idx_scale] + 1
                        )
                    )

            self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
            self.ckp.write_log('Saving...')

            if self.args.save_results:
                self.ckp.end_background()

            if not self.args.test_only:
                # "Best" is judged on the first dataset at the first scale only.
                self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

            self.ckp.write_log(
                'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
            )

    def prepare(self, *args):
        """Return ``args`` as a list of tensors moved to the target device.

        The device is CPU when ``self.args.cpu`` is truthy, otherwise CUDA.
        Each tensor is cast to half precision first when
        ``self.args.precision == 'half'``. Order is preserved.
        """
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        use_half = self.args.precision == 'half'

        def _prepare(tensor):
            if use_half:
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        """Return True when the outer train/test loop should stop.

        In ``test_only`` mode, this runs one evaluation and then returns True.
        Otherwise it returns whether ``optimizer.get_last_epoch() + 1`` has
        reached ``args.epochs``.
        """
        if self.args.test_only:
            self.test()
            return True

        # NOTE(review): if get_last_epoch() starts at 0 and each train() call
        # advances it by one, this stops after args.epochs - 1 training
        # epochs. Confirm whether that off-by-one is intended.
        epoch = self.optimizer.get_last_epoch() + 1
        return epoch >= self.args.epochs