File size: 9,813 Bytes
e9f9fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
    return p

import torchvision.models as models

# Normalization flavours accepted by the layer builders (names/values match fastai v1).
# Built with the functional Enum API; member order and values 1..7 are explicit.
NormType = Enum('NormType', [
    ('Batch', 1),
    ('BatchZero', 2),
    ('Weight', 3),
    ('Spectral', 4),
    ('Instance', 5),
    ('InstanceZero', 6),
    ('Pixel', 7),
])

# Type alias for a layer-group split spec: a callable, a list of layer indices, or None.
SplitFuncOrIdxList = Union[Callable, List[int], None]

def to_device(m, device):
    "Move `m` to `device` via its `.to` method and return whatever `.to` returns."
    moved = m.to(device)
    return moved

def apply_init(m, init_func):
    """Apply `init_func` to the `.weight` of `m` and all its submodules, except batchnorm layers.

    Batchnorm layers are skipped because their weights are 1-D, and fan-based initializers
    such as `nn.init.kaiming_normal_` raise on tensors with fewer than 2 dimensions.
    Modules without a `weight`, or whose `weight` has no `.data` (e.g. `None`), are skipped.
    Objects without an `apply` method are left untouched.

    Args:
        m: module (or any object; only objects with `.apply` are processed).
        init_func: in-place initializer called as `init_func(weight)`.
    """
    if not hasattr(m, 'apply'):
        return
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def _init_one(layer):
        # Called by `nn.Module.apply` on every submodule and on `m` itself.
        if isinstance(layer, bn_types):
            return
        weight = getattr(layer, 'weight', None)
        if weight is not None and hasattr(weight, 'data'):
            init_func(weight)

    m.apply(_init_one)

# Alias for the `DataBunch` shim, presumably so code using the fastai name `ImageDataBunch`
# keeps working.
# NOTE(review): the `DataBunch` class is defined further down in this file. Unless the name is
# already bound earlier in the module, evaluating this line at import time raises NameError;
# consider moving this alias below the `DataBunch` class definition. TODO confirm.
ImageDataBunch = DataBunch


def camel2snake(name):
    "Convert a CamelCase identifier to snake_case, e.g. 'SimpleTest' -> 'simple_test'."
    import re
    # Pass 1: insert '_' before each Capitalized word that follows any character.
    step1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Pass 2: insert '_' between a lowercase letter/digit and a following capital.
    step2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', step1)
    return step2.lower()

class SequentialEx(nn.Module):
    """Sequential container that exposes the block's input to every layer as `.orig`.

    Before each layer runs, its input tensor gets an `orig` attribute pointing at the tensor
    originally passed to `forward` (this is what `MergeLayer` reads for its shortcut).
    Supports indexing plus list-style `append`/`extend`/`insert`.
    """
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out.orig = x
            nxt = layer(out)
            # Clear the attribute immediately so intermediate tensors do not keep `x`
            # alive (hanging reference / memory leak).
            out.orig = None
            out = nxt
        return out

    def __getitem__(self, i):
        return self.layers[i]

    def append(self, l):
        return self.layers.append(l)

    def extend(self, l):
        return self.layers.extend(l)

    def insert(self, i, l):
        return self.layers.insert(i, l)

class MergeLayer(nn.Module):
    """Combine the input with its `orig` shortcut (set by `SequentialEx`).

    Adds the two tensors, or concatenates them along dim 1 when `dense=True`.
    """
    def __init__(self, dense:bool=False):
        super().__init__()
        self.dense = dense

    def forward(self, x):
        shortcut = x.orig
        if not self.dense:
            return x + shortcut
        return torch.cat([x, shortcut], dim=1)

class SigmoidRange(nn.Module):
    """Sigmoid module with output range `(low, high)`.

    Args:
        low: lower bound of the output range.
        high: upper bound of the output range.
    """
    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, x):
        # Rescale sigmoid's (0, 1) output linearly onto (low, high).
        return torch.sigmoid(x) * (self.high - self.low) + self.low

def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
               norm_type:Optional[str]=None,  use_activ:bool=True, leaky:float=None,
               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):
    """Create a convolution (`ni` -> `nf`) followed by optional activation, batchnorm and self-attention.

    Layers are returned as an `nn.Sequential` in this order:
    conv -> activation (if `use_activ`) -> batchnorm (if norm is 'Batch'/'BatchZero')
    -> `SelfAttention(nf)` (if `self_attention`).

    Args:
        ni, nf: input / output channels.
        ks, stride: kernel size and stride of the convolution.
        padding: defaults to `(ks-1)//2`, or 0 when `transpose=True`.
        bias: defaults to True unless a batchnorm layer follows.
        is_1d: use `nn.Conv1d` / `nn.BatchNorm1d` (the conv is still `nn.ConvTranspose2d`
            when `transpose=True`).
        norm_type: 'Batch', 'BatchZero', 'Weight' or 'Spectral', or an enum member (such as a
            `NormType`) whose `.name` is one of those; any other value (incl. None) adds no
            normalization.
        use_activ, leaky: append `relu(leaky=leaky)` after the conv.
        transpose: use `nn.ConvTranspose2d`.
        init: called on the conv weight before any weight/spectral-norm wrapping; falsy skips it.
        self_attention: append a `SelfAttention(nf)` layer.
    """
    # Accept enum members (e.g. NormType.Spectral) as well as plain strings; otherwise an enum
    # would silently match none of the string comparisons below and get no normalization.
    norm = getattr(norm_type, 'name', norm_type)
    if padding is None: padding = (ks-1)//2 if not transpose else 0
    bn = norm in ('Batch', 'BatchZero')
    if bias is None: bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)

    if init: init(conv.weight)

    if norm == 'Weight': conv = weight_norm(conv)
    elif norm == 'Spectral': conv = spectral_norm(conv)

    layers = [conv]
    if use_activ: layers.append(relu(leaky=leaky))
    if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention: layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)

def relu(leaky:float=None):
    """Return an in-place activation module.

    Args:
        leaky: negative slope for `nn.LeakyReLU`; when None a plain `nn.ReLU` is returned.
            Note that `leaky=0.0` still yields a `nn.LeakyReLU` (with slope 0).
    """
    if leaky is None:
        return nn.ReLU(inplace=True)
    return nn.LeakyReLU(leaky, inplace=True)

def batchnorm_2d(nf:int, norm_type:str='Batch'):
    """Return a default-initialized `nn.BatchNorm2d` over `nf` channels.

    `norm_type` is accepted for signature compatibility but is currently ignored.
    """
    layer = nn.BatchNorm2d(nf)
    return layer

class SelfAttention(nn.Module):
    """Self-attention layer over all positions of an n-d feature map.

    Notation follows https://arxiv.org/abs/1805.08318. Dims after the first two are flattened
    into one positions axis; the attention result is scaled by the learnable `gamma`
    (initialized to 0) and added back to the flattened input.
    """
    def __init__(self, n_channels:int):
        super().__init__()
        # Query/key project to n_channels//8 channels; value keeps n_channels.
        # Attribute names become state_dict keys, so they must not change.
        self.query = conv1d(n_channels, n_channels//8)
        self.key   = conv1d(n_channels, n_channels//8)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self, x):
        in_shape = x.size()
        # Keep the first two dims, flatten the rest into a single positions axis.
        flat = x.view(*in_shape[:2], -1)
        f = self.query(flat)
        g = self.key(flat)
        h = self.value(flat)
        scores = torch.bmm(f.transpose(1, 2), g)
        beta = F.softmax(scores, dim=1)
        out = self.gamma * torch.bmm(h, beta) + flat
        return out.view(*in_shape)

def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    """Build an `nn.Conv1d(ni, no, ks)` wrapped in `spectral_norm`.

    The weight is Kaiming-normal initialized, and the bias (if requested) is zeroed before wrapping.
    """
    layer = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)

def res_block(nf, dense:bool=False, norm_type:str='Batch', bottle:bool=False, **kwargs):
    """Resnet block of `nf` features: two `conv_layer`s followed by a `MergeLayer` shortcut.

    Args:
        nf: channels in and out of the conv pair.
        dense: if True the shortcut is concatenated (the block outputs `2*nf` channels)
            instead of added.
        norm_type: normalization for the convs; a string such as 'Batch', or an enum member
            (such as a `NormType`) whose `.name` is such a string.
        bottle: if True the inner conv width is `nf//2`.
        **kwargs: forwarded to both `conv_layer` calls.
    """
    # Reduce enum members to their names so the comparison below, and the string checks in
    # `conv_layer`, behave the same for `NormType.Batch` and 'Batch'.
    norm = getattr(norm_type, 'name', norm_type)
    # For an additive (non-dense) block the second conv is given 'BatchZero' instead of 'Batch'.
    norm2 = 'BatchZero' if (not dense and norm == 'Batch') else norm
    nf_inner = nf//2 if bottle else nf
    return SequentialEx(conv_layer(nf, nf_inner, norm_type=norm, **kwargs),
                        conv_layer(nf_inner, nf, norm_type=norm2, **kwargs),
                        MergeLayer(dense))

# --- Hooks & Model Sizes ---

class Hook():
    """Register `hook_func` as a forward (or backward) hook on `m` and keep its last result.

    On each call the hook stores `hook_func(module, input, output)` in `self.stored`. With
    `detach=True`, tensors (and tensor items of tuples/lists) are detached first. Usable as a
    context manager; `remove()` is idempotent.
    """
    def __init__(self, m:nn.Module, hook_func:Callable, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        # NOTE(review): `register_backward_hook` is deprecated in recent PyTorch in favour of
        # `register_full_backward_hook`; kept here to preserve the existing semantics.
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    @staticmethod
    def _detach(x):
        "Detach a tensor; for a tuple/list, return a tuple with its tensor items detached."
        if isinstance(x, Tensor):
            return x.detach()
        if isinstance(x, (tuple, list)):
            # Materialize a tuple (not a one-shot generator) so `hook_func` may iterate it
            # more than once; non-tensor items are passed through untouched.
            return tuple(o.detach() if isinstance(o, Tensor) else o for o in x)
        return x

    def hook_fn(self, module:nn.Module, input:Tensor, output:Tensor):
        "Apply `hook_func` to `module`, `input`, `output` and store the result."
        if self.detach:
            input, output = self._detach(input), self._detach(output)
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the module (safe to call more than once)."
        if not self.removed:
            self.hook.remove()
            self.removed = True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

class Hooks():
    """A `Hook` per module in `ms`, all sharing `hook_func`.

    Indexable, sized and iterable over the individual hooks; `stored` lists each hook's last
    stored value. As a context manager, all hooks are removed on exit.
    """
    def __init__(self, ms:List[nn.Module], hook_func:Callable, is_forward:bool=True, detach:bool=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

    def __getitem__(self,i:int)->Hook:
        return self.hooks[i]

    def __len__(self)->int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        return [hook.stored for hook in self.hooks]

    def remove(self):
        for hook in self.hooks:
            hook.remove()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()

def _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if isinstance(o,list) else list(o)

def hook_outputs(modules:List[nn.Module], detach:bool=True, grad:bool=False)->Hooks:
    """Hook every module in `modules`; their `stored` values collect module outputs.

    With `grad=True` backward hooks are registered instead, so the third hook argument
    (the output gradient) is stored.
    """
    forward = not grad
    return Hooks(modules, _hook_inner, detach=detach, is_forward=forward)

def dummy_eval(m:nn.Module, size:Tuple=(64,64)):
    """Switch `m` to eval mode, run it on one random input, and return the output.

    The input has shape `(1, in_channels(m), *size)`. It is created directly on the device of
    `m`'s first parameter, so models on any device (CUDA, including non-default GPUs, XPU,
    MPS or CPU) work.
    """
    ch_in = in_channels(m)
    # Creating the input on the parameters' device avoids relying on `.cuda()`, which only
    # targets the current default GPU and ignores other accelerators.
    device = next(m.parameters()).device
    x = torch.randn(1, ch_in, *size, device=device)
    return m.eval()(x)

def model_sizes(m:nn.Module, size:Tuple=(64,64)):
    """Return the output shape of each module hooked by `hook_outputs(m)` for a dummy input.

    `hook_outputs` iterates `m`, so `m` is expected to be iterable (e.g. an `nn.Sequential`).
    The dummy pass is done by `dummy_eval`, which leaves `m` in eval mode.
    """
    with hook_outputs(m) as hooks:
        dummy_eval(m, size)
        shapes = [h.stored.shape for h in hooks]
    return shapes

def in_channels(m:nn.Module) -> int:
    """Return the number of input channels of the first convolution in `m`.

    Modules are visited in `m.modules()` order (which includes `m` itself); only
    `nn.Conv1d`, `nn.Conv2d` and `nn.Conv3d` are considered.

    Raises:
        Exception: if `m` contains no such convolution.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    for layer in m.modules():
        if isinstance(layer, conv_types):
            return layer.in_channels
    # Exception type and message kept unchanged for callers that may rely on them.
    raise Exception('No weight layer')

# --- Model Creation ---

def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    """Cut off the head of a typically pretrained `arch`.

    Args:
        arch: model constructor, called as `arch(pretrained=pretrained)`.
        pretrained: forwarded to `arch`.
        cut: an int keeps the first `cut` direct children (as an `nn.Sequential`); a callable
            is applied to the built model and its result returned; None cuts just before the
            last direct child for which `has_pool_type` is true.

    Raises:
        StopIteration: if `cut` is None and no direct child is a pooling layer.
    """
    model = arch(pretrained=pretrained)
    # The annotation allows a callable cut; slicing with it would raise, so apply it instead.
    if callable(cut):
        return cut(model)
    if cut is None:
        # Most torchvision models end in pooling + 'fc'/'classifier'; cut at the last pool
        # (ResNet-style bodies).
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))

    return nn.Sequential(*list(model.children())[:cut])

def has_pool_type(m):
    "Return True iff `m` itself (children are not inspected) is a 2-d avg/max or adaptive pooling layer."
    pool_types = (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d, nn.AvgPool2d, nn.MaxPool2d)
    return isinstance(m, pool_types)

def cnn_config(arch):
    """Get the metadata for `arch` (a fixed ResNet-style config; `arch` is not inspected).

    Returns a dict whose 'split' entry maps a model to its layer-group boundaries.
    """
    def _split(model):
        # Boundaries: child 6 of the body (`model[0]`) and the head (`model[1]`).
        return (model[0][6], model[1])
    return {'split': _split}

# --- Learner Shim ---

class Learner:
    "Minimal Learner shim for inference."
    def __init__(self, data, model, path=None):
        self.data = data
        self.model = model
        self.path = path

        # Use deoldify.device to get the correct device (CUDA/XPU/CPU)
        from deoldify import device as device_settings
        self.device = device_settings.get_torch_device()

        self.model.to(self.device)

    def load(self, name):
        """Load weights from `models/{name}.pth` into `self.model` and return `self`.

        The file is looked up under `self.path` when it is set (a `str` or path-like object),
        otherwise relative to the current working directory. If the loaded object has a
        'model' key (presumably a checkpoint bundling other state), only that entry is used.
        Loading uses `strict=True`, so missing or unexpected keys raise.
        """
        from pathlib import Path
        # Path() accepts both str and Path, so a plain-string `self.path` no longer fails on `/`.
        base = Path(self.path) if self.path else Path()
        path = base / 'models' / f'{name}.pth'

        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints. Newer
        # PyTorch releases default to weights_only=True, which may reject checkpoints holding
        # arbitrary Python objects; confirm against the checkpoints actually used.
        state = torch.load(path, map_location=self.device)
        if 'model' in state: state = state['model']
        self.model.load_state_dict(state, strict=True)
        # Returning self enables `learn = learn.load(name)`; callers ignoring it are unaffected.
        return self

    def split(self, split_on):
        "No-op: layer-group splitting is not needed for inference."
        pass

    def freeze(self):
        "Disable gradients for every parameter of `self.model`."
        for p in self.model.parameters(): p.requires_grad = False

class DataBunch:
    """Minimal DataBunch shim holding `c` and a torch device.

    When `device` is falsy, the device comes from `deoldify.device.get_torch_device()`.
    """
    def __init__(self, c=3, device=None):
        self.c = c
        if not device:
            from deoldify import device as device_settings
            device = device_settings.get_torch_device()
        self.device = device

def get_dummy_databunch():
    "Return a `DataBunch` shim built with its default arguments."
    bunch = DataBunch()
    return bunch

def tensor(x, *args, **kwargs):
    "Thin alias for `torch.tensor`: forwards `x` and any extra arguments unchanged."
    result = torch.tensor(x, *args, **kwargs)
    return result