File size: 2,275 Bytes
4c62147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import cv2
import math
from enum import Enum

import torch
from torch import nn
import torch.nn.functional as F

from .filter import Filter, FILTER_MODULES


class CascadeArgumentRegressor(nn.Module):
    """Regress ``head_num`` output vectors with a cascade of linear heads.

    The input feature map is global-average-pooled to one vector per sample.
    Two linear projections of that vector are computed:

    * ``f`` (width ``feature_channels``), shared unchanged by every head;
    * ``g`` (width ``base_channels``), a hidden state threaded through the
      cascade.

    Head ``k`` maps ``cat(f, g)`` to a new ``g`` and then projects that ``g``
    to ``out_channels`` values. Each head is therefore conditioned on the
    hidden state produced by the previous head.
    (Presumably the outputs are the per-filter arguments consumed by
    ``FilterPerformer`` -- NOTE(review): confirm against the caller.)

    Args:
        in_channels: Channel count ``c`` of the ``(n, c, h, w)`` input.
        base_channels: Width of the cascaded hidden state ``g``.
        out_channels: Number of values predicted by each head.
        head_num: Number of cascaded heads (length of the returned list).
        feature_channels: Width of the shared projection ``f``. The default
            of 160 keeps layer shapes (and thus existing checkpoints)
            identical to the previous fixed-width implementation.
    """

    def __init__(self, in_channels, base_channels, out_channels, head_num,
                 feature_channels=160):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_channels = out_channels
        self.head_num = head_num
        self.feature_channels = feature_channels

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.f = nn.Linear(in_channels, feature_channels)
        self.g = nn.Linear(in_channels, base_channels)

        # Each header is a [hidden layer, output layer] pair. Keeping the pair
        # in a nested ModuleList preserves the parameter names
        # (headers.<k>.0.*, headers.<k>.1.*) seen in state_dict().
        self.headers = nn.ModuleList([
            nn.ModuleList([
                nn.Linear(feature_channels + base_channels, base_channels),
                nn.Linear(base_channels, out_channels),
            ])
            for _ in range(head_num)
        ])

    def forward(self, x):
        """Run the cascade.

        Args:
            x: Feature map of shape ``(n, in_channels, h, w)``.

        Returns:
            A list of ``head_num`` tensors, each of shape ``(n, out_channels)``,
            in head order.
        """
        x = self.pool(x)
        # Unpacking four dims also rejects inputs that are not 4-D.
        n, c, _, _ = x.shape
        x = x.reshape(n, c)

        f = self.f(x)
        g = self.g(x)

        pred_args = []
        for hidden_layer, output_layer in self.headers:
            # Shared features plus the previous head's hidden state.
            g = hidden_layer(torch.cat((f, g), dim=1))
            pred_args.append(output_layer(g))

        return pred_args


class FilterPerformer(nn.Module):
    """Apply a fixed sequence of image filters inside a mask.

    Each entry of ``filter_types`` is looked up in ``FILTER_MODULES`` and
    instantiated with no arguments. Filters are applied cumulatively: each
    filter receives the previous filter's output. After every step the
    running result is composited with the untouched input as
    ``result * mask + input * (1 - mask)``.

    Args:
        filter_types: Iterable of keys into ``FILTER_MODULES``.
    """

    def __init__(self, filter_types):
        super().__init__()

        # Register the filters as submodules. A plain Python list is invisible
        # to nn.Module, so .to()/.cuda(), .train()/.eval(), .parameters() and
        # state_dict() would skip anything the filters own.
        # NOTE(review): if a filter owns persistent parameters/buffers, their
        # keys now appear in state_dict(). Checkpoints saved before this change
        # may need load_state_dict(..., strict=False).
        self.filters = nn.ModuleList(
            [FILTER_MODULES[filter_type]() for filter_type in filter_types]
        )

    def forward(self):
        # No-op returning None; the filters are applied via restore() and
        # adjust() instead of __call__.
        pass

    def _check_arguments(self, arguments):
        """Raise ValueError unless there is exactly one argument per filter."""
        # An explicit check (not `assert`) survives `python -O`. Without it,
        # zip() would silently drop the unmatched filters or arguments.
        if len(arguments) != len(self.filters):
            raise ValueError(
                f"expected {len(self.filters)} filter arguments, "
                f"got {len(arguments)}"
            )

    @staticmethod
    def _apply_chain(image, mask, filters, arguments):
        """Apply ``filters`` cumulatively to ``image`` and composite each step.

        Returns one tensor per filter. Entry ``k`` is the result after the
        first ``k + 1`` filters inside ``mask``, with ``image`` kept outside
        it.
        """
        outputs = []
        current = image
        for image_filter, arg in zip(filters, arguments):
            current = image_filter(current, arg)
            outputs.append(current * mask + image * (1 - mask))
        return outputs

    def restore(self, x, mask, arguments):
        """Apply the filters in construction order.

        Args:
            x: Input image tensor.
            mask: Blend weights; presumably in [0, 1] and broadcastable
                against ``x`` -- NOTE(review): confirm with callers.
            arguments: Sequence with one argument per filter, in filter order.

        Returns:
            A list with one composited tensor per filter (see
            ``_apply_chain``).

        Raises:
            ValueError: If ``len(arguments) != len(self.filters)``.
        """
        self._check_arguments(arguments)
        return self._apply_chain(x, mask, list(self.filters), arguments)

    def adjust(self, image, mask, arguments):
        """Apply the filters in reverse construction order.

        ``arguments`` is given in forward filter order and is reversed
        together with the filters, so each filter still gets its own
        argument.

        Args:
            image: Input image tensor.
            mask: Blend weights; presumably in [0, 1] and broadcastable
                against ``image`` -- NOTE(review): confirm with callers.
            arguments: Reversible sequence with one argument per filter.

        Returns:
            A list with one composited tensor per filter, in application
            (reversed) order.

        Raises:
            ValueError: If ``len(arguments) != len(self.filters)``.
        """
        self._check_arguments(arguments)
        return self._apply_chain(
            image, mask, list(self.filters)[::-1], list(reversed(arguments))
        )