File size: 1,272 Bytes
4c62147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import torch.nn.functional as F
import torchtask
from module import harmonizer as _harmonizer
def add_parser_arguments(parser):
    """Register this model's command-line options on *parser*.

    The model adds no options of its own; it only forwards to the generic
    ``torchtask.model_template`` registration.
    """
    template = torchtask.model_template
    template.add_parser_arguments(parser)
def harmonizer():
    """Return the ``Harmonizer`` model class itself (not an instance)."""
    model_cls = Harmonizer
    return model_cls
class Harmonizer(torchtask.model_template.TaskModel):
    """torchtask model wrapper around ``module.harmonizer.Harmonizer``.

    The wrapped network is exposed in three ways:

    * ``forward``, which returns the torchtask ``(resulter, debugger)`` pair;
    * ``restore``, which runs gradient-free inference;
    * ``adjust``, which also runs gradient-free inference.
    """

    def __init__(self, args):
        super().__init__(args)
        self.model = _harmonizer.Harmonizer()

        # Build one parameter group per sub-network, all at the same
        # learning rate.
        # Each 'params' entry is a lazy, single-use generator:
        # - the requires_grad filter is applied when it is consumed;
        # - it can only be iterated once.
        # NOTE(review): presumably consumed by torchtask's optimizer
        # setup; confirm against TaskModel.
        self.param_groups = [
            {
                'params': (
                    param
                    for param in getattr(self.model, part).parameters()
                    if param.requires_grad
                ),
                'lr': self.args.lr,
            }
            for part in ('backbone', 'regressor', 'performer')
        ]

    def forward(self, inp):
        """Run the wrapped network on an ``(x, mask)`` pair.

        Args:
            inp: a 2-item iterable unpacked as ``(x, mask)``.

        Returns:
            ``(resulter, debugger)``:
            - ``resulter['outputs']`` holds the result of ``self.model(x, mask)``;
            - ``debugger`` is always an empty dict.
        """
        x, mask = inp
        resulter = {'outputs': self.model(x, mask)}
        debugger = {}
        return resulter, debugger

    def restore(self, x, mask, arguments):
        """Return ``self.model.restore_image(x, mask, arguments)``, computed without autograd.

        NOTE(review): the meaning and format of *arguments* are defined by
        ``module.harmonizer``, which is not visible here.
        """
        with torch.no_grad():
            restored = self.model.restore_image(x, mask, arguments)
        return restored

    def adjust(self, x, mask, arguments):
        """Return ``self.model.adjust_image(x, mask, arguments)``, computed without autograd.

        NOTE(review): the meaning and format of *arguments* are defined by
        ``module.harmonizer``, which is not visible here.
        """
        with torch.no_grad():
            adjusted = self.model.adjust_image(x, mask, arguments)
        return adjusted
|