| | """Model class template |
| | |
| | This module provides a template for users to implement custom models. |
| | You can specify '--model template' to use this model. |
| | The class name should be consistent with both the filename and its model option. |
| | The filename should be <model>_model.py |
| | The class name should be <Model>Model (e.g. TemplateModel in template_model.py) |
| | It implements a simple image-to-image translation baseline based on regression loss. |
| | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: |
| | min_<netG> ||netG(data_A) - data_B||_1 |
| | You need to implement the following functions: |
| | <modify_commandline_options>: Add model-specific options and rewrite default values for existing options. |
| | <__init__>: Initialize this model class. |
| | <set_input>: Unpack input data and perform data pre-processing. |
| | <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>. |
| | <optimize_parameters>: Update network weights; it will be called in every training iteration. |
| | """ |
| | import numpy as np |
| | import torch |
| | from .base_model import BaseModel |
| | from . import networks |
| |
|
| |
|
| | class TemplateModel(BaseModel): |
| | @staticmethod |
| | def modify_commandline_options(parser, is_train=True): |
| | """Add new model-specific options and rewrite default values for existing options. |
| | |
| | Parameters: |
| | parser -- the option parser |
| | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. |
| | |
| | Returns: |
| | the modified parser. |
| | """ |
| | parser.set_defaults(dataset_mode='aligned') |
| | if is_train: |
| | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') |
| |
|
| | return parser |
| |
|
| | def __init__(self, opt): |
| | """Initialize this model class. |
| | |
| | Parameters: |
| | opt -- training/test options |
| | |
| | A few things can be done here. |
| | - (required) call the initialization function of BaseModel |
| | - define loss function, visualization images, model names, and optimizers |
| | """ |
| | BaseModel.__init__(self, opt) |
| | |
| | self.loss_names = ['loss_G'] |
| | |
| | self.visual_names = ['data_A', 'data_B', 'output'] |
| | |
| | |
| | self.model_names = ['G'] |
| | |
| | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) |
| | if self.isTrain: |
| | |
| | |
| | self.criterionLoss = torch.nn.L1Loss() |
| | |
| | |
| | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) |
| | self.optimizers = [self.optimizer] |
| |
|
| | |
| |
|
| | def set_input(self, input): |
| | """Unpack input data from the dataloader and perform necessary pre-processing steps. |
| | |
| | Parameters: |
| | input: a dictionary that contains the data itself and its metadata information. |
| | """ |
| | AtoB = self.opt.direction == 'AtoB' |
| | self.data_A = input['A' if AtoB else 'B'].to(self.device) |
| | self.data_B = input['B' if AtoB else 'A'].to(self.device) |
| | self.image_paths = input['A_paths' if AtoB else 'B_paths'] |
| |
|
| | def forward(self): |
| | """Run forward pass. This will be called by both functions <optimize_parameters> and <test>.""" |
| | self.output = self.netG(self.data_A) |
| |
|
| | def backward(self): |
| | """Calculate losses, gradients, and update network weights; called in every training iteration""" |
| | |
| | |
| | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression |
| | self.loss_G.backward() |
| |
|
| | def optimize_parameters(self): |
| | """Update network weights; it will be called in every training iteration.""" |
| | self.forward() |
| | self.optimizer.zero_grad() |
| | self.backward() |
| | self.optimizer.step() |
| |
|