Inmental's picture
Upload folder using huggingface_hub
4c62147 verified
import torch
import torch.nn.functional as F
import torchtask
from module import harmonizer as _harmonizer
def add_parser_arguments(parser):
    """Register this model's command-line options on *parser*.

    Delegates entirely to ``torchtask.model_template.add_parser_arguments``;
    this module adds no options of its own.
    """
    template = torchtask.model_template
    template.add_parser_arguments(parser)
def harmonizer():
    """Return the ``Harmonizer`` task-model class itself (not an instance)."""
    model_cls = Harmonizer
    return model_cls
class Harmonizer(torchtask.model_template.TaskModel):
    """torchtask ``TaskModel`` wrapper around ``_harmonizer.Harmonizer``.

    Builds the underlying network, declares one optimizer parameter group per
    sub-module (``backbone``, ``regressor``, ``performer``), and forwards
    inference calls to the network.
    """

    def __init__(self, args):
        """Create the network and its optimizer parameter groups.

        Args:
            args: parsed command-line arguments, handed to the ``TaskModel``
                base. The learning rate of every parameter group is read
                from ``self.args.lr`` (presumably populated by the base from
                ``args`` -- confirm in torchtask).
        """
        super().__init__(args)

        self.model = _harmonizer.Harmonizer()

        # One group per sub-module, all at the same learning rate.
        # The trainable parameters are materialized as lists: a bare
        # ``filter`` object is a one-shot iterator, so any second pass over
        # ``param_groups`` (e.g. rebuilding the optimizer) would otherwise
        # see no parameters at all.
        submodule_names = ('backbone', 'regressor', 'performer')
        self.param_groups = [
            {
                'params': [
                    p for p in getattr(self.model, name).parameters()
                    if p.requires_grad
                ],
                'lr': self.args.lr,
            }
            for name in submodule_names
        ]

    def forward(self, inp):
        """Run the network on an ``(x, mask)`` pair.

        Args:
            inp: a two-item sequence unpacked as ``(x, mask)``; both items
                are passed positionally to ``self.model``.

        Returns:
            ``(resulter, debugger)``: ``resulter['outputs']`` holds the
            network's prediction; ``debugger`` is always an empty dict here.
        """
        x, mask = inp
        pred = self.model(x, mask)
        resulter = {'outputs': pred}
        debugger = {}
        return resulter, debugger

    def restore(self, x, mask, arguments):
        """Return ``self.model.restore_image(x, mask, arguments)``.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.
        The meaning and structure of ``arguments`` are defined by
        ``_harmonizer`` (not visible here).
        """
        with torch.no_grad():
            return self.model.restore_image(x, mask, arguments)

    def adjust(self, x, mask, arguments):
        """Return ``self.model.adjust_image(x, mask, arguments)``.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.
        The meaning and structure of ``arguments`` are defined by
        ``_harmonizer`` (not visible here).
        """
        with torch.no_grad():
            return self.model.adjust_image(x, mask, arguments)