"""torchtask model wrapper around the network defined in ``module.harmonizer``."""

import torch
import torch.nn.functional as F

import torchtask

from module import harmonizer as _harmonizer


def add_parser_arguments(parser):
    """Register this model's command-line options on ``parser``.

    Delegates entirely to ``torchtask.model_template.add_parser_arguments``;
    no model-specific options are added here.
    """
    torchtask.model_template.add_parser_arguments(parser)


def harmonizer():
    """Return the ``Harmonizer`` task-model class (the class, not an instance)."""
    return Harmonizer


class Harmonizer(torchtask.model_template.TaskModel):
    """Task model that wraps ``module.harmonizer.Harmonizer``."""

    def __init__(self, args):
        """Build the wrapped network and its optimizer parameter groups.

        Args:
            args: Parsed arguments, forwarded to the ``TaskModel`` base class.
                ``self.args.lr`` is read as the learning rate of every group.
        """
        super(Harmonizer, self).__init__(args)

        self.model = _harmonizer.Harmonizer()

        # One group per sub-network, all sharing the same learning rate.
        # The trainable parameters are materialized as lists rather than left
        # as ``filter`` objects: a ``filter`` is a one-shot iterator, so a
        # second reader of ``param_groups`` would silently see empty groups.
        # NOTE: this is a snapshot of ``requires_grad`` at construction time.
        submodules = (
            self.model.backbone,
            self.model.regressor,
            self.model.performer,
        )
        self.param_groups = [
            {
                'params': [p for p in submodule.parameters() if p.requires_grad],
                'lr': self.args.lr,
            }
            for submodule in submodules
        ]

    def forward(self, inp):
        """Run the wrapped network on an ``(x, mask)`` pair.

        Args:
            inp: A two-element iterable that unpacks into ``x`` and ``mask``.

        Returns:
            A ``(resulter, debugger)`` tuple of dicts. ``resulter['outputs']``
            holds the value returned by the wrapped network; ``debugger`` is
            always empty.
        """
        resulter, debugger = {}, {}

        x, mask = inp
        pred = self.model(x, mask)
        resulter['outputs'] = pred

        return resulter, debugger

    def restore(self, x, mask, arguments):
        """Call the wrapped network's ``restore_image`` with autograd disabled."""
        with torch.no_grad():
            return self.model.restore_image(x, mask, arguments)

    def adjust(self, x, mask, arguments):
        """Call the wrapped network's ``adjust_image`` with autograd disabled."""
        with torch.no_grad():
            return self.model.adjust_image(x, mask, arguments)