File size: 1,272 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn.functional as F

import torchtask

from module import harmonizer as _harmonizer


def add_parser_arguments(parser):
    """Register this model's command-line options on *parser*.

    No Harmonizer-specific options are added here; registration is delegated
    entirely to torchtask's generic model template.
    """
    template = torchtask.model_template
    template.add_parser_arguments(parser)


def harmonizer():
    """Return the ``Harmonizer`` task-model class itself (not an instance)."""
    model_cls = Harmonizer
    return model_cls


class Harmonizer(torchtask.model_template.TaskModel):
    """torchtask wrapper around the ``module.harmonizer.Harmonizer`` network.

    Instantiates the network, exposes one optimizer parameter group per
    submodule (backbone, regressor, performer), and forwards the inference
    helpers ``restore`` / ``adjust`` to the network with autograd disabled.
    """

    def __init__(self, args):
        super().__init__(args)

        self.model = _harmonizer.Harmonizer()

        # One group per submodule, each holding only its trainable parameters
        # and all sharing the same learning rate from ``self.args.lr``.
        # NOTE(review): ``filter`` is lazy and a ``filter`` object can only be
        # iterated once -- ``requires_grad`` is checked when the group is first
        # consumed (presumably by the optimizer the template builds), not here.
        submodules = (self.model.backbone, self.model.regressor, self.model.performer)
        self.param_groups = [
            {'params': filter(lambda p: p.requires_grad, sub.parameters()),
             'lr': self.args.lr}
            for sub in submodules
        ]

    def forward(self, inp):
        """Run the network on an ``(x, mask)`` pair.

        Args:
            inp: A 2-item sequence unpacked as ``(x, mask)`` and passed
                positionally to ``self.model`` (presumably an image tensor and
                its mask -- confirm against the data pipeline).

        Returns:
            A ``(resulter, debugger)`` tuple. ``resulter['outputs']`` holds the
            network's prediction; ``debugger`` is always an empty dict here.
        """
        x, mask = inp
        prediction = self.model(x, mask)
        return {'outputs': prediction}, {}

    @torch.no_grad()
    def restore(self, x, mask, arguments):
        """Return ``self.model.restore_image(x, mask, arguments)``, computed without gradient tracking."""
        return self.model.restore_image(x, mask, arguments)

    @torch.no_grad()
    def adjust(self, x, mask, arguments):
        """Return ``self.model.adjust_image(x, mask, arguments)``, computed without gradient tracking."""
        return self.model.adjust_image(x, mask, arguments)