| | |
| | import sys |
| | import os |
| | import torch |
| |
|
| |
|
| | |
# Put the current working directory on sys.path so the project-local imports
# that follow (architecture.rrdb, train_code.train_master) can be resolved.
# NOTE(review): this only works when the process is launched from the
# repository root.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
| | from architecture.rrdb import RRDBNet |
| | from train_code.train_master import train_master |
| |
|
| |
|
| |
|
| | |
# Module-level gradient scaler for mixed-precision (AMP) training; `enabled`
# is spelled out but equals GradScaler's default.
# NOTE(review): `scaler` is not referenced elsewhere in this file -- confirm
# whether another module imports it or whether it is dead code.
scaler = torch.cuda.amp.GradScaler(enabled=True)
| |
|
| |
|
class train_esrnet(train_master):
    """ESRNet trainer built on ``train_master``.

    Supplies an RRDBNet generator and a pixel loss. It also logs that loss
    to TensorBoard. The training loop itself is delegated to
    ``train_master.master_run`` (see ``run``).
    """

    def __init__(self, options, args) -> None:
        # "esrnet" is handed to train_master as its third positional argument
        # (presumably the model's identifying name -- confirm in train_master).
        model_tag = "esrnet"
        super().__init__(options, args, model_tag)

    def loss_init(self):
        """Set up the loss criteria for this trainer (pixel loss only).

        ``pixel_loss_load`` is not defined in this class. It is presumably
        inherited from ``train_master`` and creates ``self.cri_pix``, which
        ``calculate_loss`` calls -- confirm there.
        """
        self.pixel_loss_load()

    def call_model(self):
        """Build the RRDBNet generator, move it to CUDA, and set train mode.

        Reads ``self.options['scale']`` (forwarded as ``scale``) and
        ``self.options['ESR_blocks_num']`` (forwarded as ``num_block``).
        """
        opts = self.options
        upscale = opts['scale']
        block_count = opts['ESR_blocks_num']
        # The two leading 3s are RRDBNet's first positional arguments
        # (presumably input/output channel counts for RGB -- confirm in
        # architecture.rrdb).
        net = RRDBNet(3, 3, scale=upscale, num_block=block_count)
        self.generator = net.cuda()
        self.generator.train()

    def run(self):
        """Start training by delegating to ``master_run``."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Accumulate the pixel loss between ``gen_hr`` and ``imgs_hr``.

        Args:
            gen_hr: presumably the generator output for the current batch.
            imgs_hr: presumably the ground-truth target for the current batch.

        The loss is passed to ``self.cri_pix`` together with
        ``self.batch_idx``. Its value is stored in
        ``weight_store["pixel_loss"]`` for ``tensorboard_report`` and is
        added to ``self.generator_loss``.
        """
        pixel_term = self.cri_pix(gen_hr, imgs_hr, self.batch_idx)
        # NOTE(review): the stored value is not detached, so it keeps its
        # autograd graph referenced until the next overwrite.
        self.weight_store["pixel_loss"] = pixel_term
        # Augmented assignment kept deliberately: if generator_loss is a
        # tensor this is an in-place add, unlike ``a = a + b``.
        self.generator_loss += pixel_term

    def tensorboard_report(self, iteration):
        """Log the latest pixel loss via ``self.writer.add_scalar``.

        Args:
            iteration: step value recorded alongside the scalar.

        Expects ``calculate_loss`` to have populated
        ``weight_store["pixel_loss"]`` first.
        """
        latest_pixel_loss = self.weight_store["pixel_loss"]
        self.writer.add_scalar(
            'Loss/train-Pixel_Loss-Iteration',
            latest_pixel_loss,
            iteration,
        )
| |
|