KhadgaA commited on
Commit
dc604d9
·
1 Parent(s): a245d8a

moonknight

Browse files
__pycache__/lle.cpython-39.pyc ADDED
Binary file (2.84 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (11.2 kB). View file
 
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ from lle import SYELLENetS
8
+
9
+ kwargs = {'channels': 12}
10
+ model = SYELLENetS(**kwargs)
11
+ model.load_state_dict(torch.load('./model_best_slim.pkl', map_location='cpu'))
12
+ model.eval()
13
+ def predict(input_img, ver):
14
+ input_img = Image.open(input_img)
15
+
16
+
17
+ # transform = transforms.Compose([transforms.Resize((400,60), Image.BICUBIC)])
18
+
19
+ input_img = np.array(input_img).transpose([2, 0, 1])
20
+ input_img = input_img.astype(np.float32) / 255.0
21
+ input_img = torch.from_numpy(input_img).unsqueeze(0)
22
+ with torch.no_grad():
23
+ outputs = model(input_img)
24
+ out_img = (outputs.clip(0, 1)[0] * 255).permute([1, 2, 0]).cpu().numpy().astype(np.uint8)[..., ::-1]
25
+ return transforms.ToPILImage()(out_img)
26
+
27
+ title="Image to Line Drawings - Complex and Simple Portraits and Landscapes"
28
+
29
+ examples=['./examples/1.png', './examples/22.png', './examples/23.png', './examples/55.png', './examples/79.png'
30
+ ]
31
+
32
+ iface = gr.Interface(predict, inputs=gr.Image(type='filepath'),
33
+ outputs='image',
34
+ title=title,
35
+ examples=examples)
36
+
37
+ iface.launch()
examples/1.png ADDED
examples/111.png ADDED
examples/146.png ADDED
examples/179.png ADDED
examples/22.png ADDED
examples/23.png ADDED
examples/493.png ADDED
examples/547.png ADDED
examples/55.png ADDED
examples/665.png ADDED
examples/669.png ADDED
examples/748.png ADDED
examples/778.png ADDED
examples/780.png ADDED
examples/79.png ADDED
lle.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from utils import (
3
+ ConvRep5,
4
+ ConvRep3,
5
+ ConvRepPoint,
6
+ DropBlock,
7
+ QuadraticConnectionUnit,
8
+ QuadraticConnectionUnitS,
9
+ )
10
+
11
+
12
+ class SYELLENet(nn.Module):
13
+ def __init__(self, channels, rep_scale=4):
14
+ super(SYELLENet, self).__init__()
15
+ self.channels = channels
16
+ self.head = QuadraticConnectionUnit(
17
+ nn.Sequential(
18
+ ConvRep5(3, channels, rep_scale=rep_scale),
19
+ nn.PReLU(channels),
20
+ ConvRep3(channels, channels, rep_scale=rep_scale)
21
+ ),
22
+ ConvRep5(3, channels, rep_scale=rep_scale),
23
+ channels
24
+ )
25
+ self.body = QuadraticConnectionUnit(
26
+ ConvRep3(channels, channels, rep_scale=rep_scale),
27
+ ConvRepPoint(channels, channels, rep_scale=rep_scale),
28
+ 12
29
+ )
30
+ self.att = nn.Sequential(
31
+ nn.AdaptiveAvgPool2d(1),
32
+ ConvRepPoint(channels, channels, rep_scale=rep_scale),
33
+ nn.PReLU(channels),
34
+ ConvRepPoint(channels, channels, rep_scale=rep_scale),
35
+ nn.Sigmoid()
36
+ )
37
+ self.tail = ConvRep3(channels, 3, rep_scale=rep_scale)
38
+
39
+ self.tail_warm = ConvRep3(channels, 3, rep_scale=rep_scale)
40
+ self.drop = DropBlock(3)
41
+
42
+ def forward(self, x):
43
+ x = self.head(x)
44
+ x = self.body(x)
45
+ x = self.att(x) * x
46
+ return self.tail(x)
47
+
48
+ def forward_warm(self, x):
49
+ x = self.drop(x)
50
+ x = self.head(x)
51
+ x = self.body(x)
52
+ return self.tail(x), self.tail_warm(x)
53
+
54
+ def slim(self):
55
+ net_slim = SYELLENetS(self.channels)
56
+ weight_slim = net_slim.state_dict()
57
+ for name, mod in self.named_modules():
58
+ if isinstance(mod, ConvRep3) or isinstance(mod, ConvRep5) or isinstance(mod, ConvRepPoint):
59
+ if '%s.weight' % name in weight_slim:
60
+ w, b = mod.slim()
61
+ weight_slim['%s.weight' % name] = w
62
+ weight_slim['%s.bias' % name] = b
63
+ if 'block2' in name:
64
+ weight_slim['%s.weight' % name] = weight_slim['%s.weight' % name] * 0.1
65
+ weight_slim['%s.bias' % name] = weight_slim['%s.bias' % name] * 0.1
66
+ elif isinstance(mod, QuadraticConnectionUnit):
67
+ weight_slim['%s.bias' % name] = mod.bias
68
+ elif isinstance(mod, nn.PReLU):
69
+ weight_slim['%s.weight' % name] = mod.weight
70
+
71
+ net_slim.load_state_dict(weight_slim)
72
+ return net_slim
73
+
74
+
75
+ class SYELLENetS(nn.Module):
76
+ def __init__(self, channels):
77
+ super(SYELLENetS, self).__init__()
78
+ self.head = QuadraticConnectionUnitS(
79
+ nn.Sequential(
80
+ nn.Conv2d(3, channels, 5, 1, 2),
81
+ nn.PReLU(channels),
82
+ nn.Conv2d(channels, channels, 3, 1, 1)
83
+ ),
84
+ nn.Conv2d(3, channels, 5, 1, 2),
85
+ channels
86
+ )
87
+ self.body = QuadraticConnectionUnitS(
88
+ nn.Conv2d(channels, channels, 3, 1, 1),
89
+ nn.Conv2d(channels, channels, 1, ),
90
+ 12
91
+ )
92
+ self.att = nn.Sequential(
93
+ nn.AdaptiveAvgPool2d(1),
94
+ nn.Conv2d(channels, channels, 1),
95
+ nn.PReLU(channels),
96
+ nn.Conv2d(channels, channels, 1),
97
+ nn.Sigmoid()
98
+ )
99
+ self.tail = nn.Conv2d(channels, 3, 3, 1, 1)
100
+
101
+ def forward(self, x):
102
+ x = self.head(x)
103
+ x = self.body(x)
104
+ x = self.att(x) * x
105
+ return self.tail(x)
model_best_slim.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d1f9547b3982465874c181e3aa4891312825fe58c54aef60b28fe888494328f
3
+ size 27674
model_best_slim_lol.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7eb128342cba6762455e5804317237bce9c44cd0922827db8994ffca2fcb760
3
+ size 24294
utils.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class ConvRep5(nn.Module):
6
+ def __init__(self, in_channels, out_channels, rep_scale=4):
7
+ super(ConvRep5, self).__init__()
8
+ self.in_channels = in_channels
9
+ self.out_channels = out_channels
10
+ self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2)
11
+ self.conv_bn = nn.Sequential(
12
+ nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2),
13
+ nn.BatchNorm2d(out_channels * rep_scale)
14
+ )
15
+ self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
16
+ self.conv1_bn = nn.Sequential(
17
+ nn.Conv2d(in_channels, out_channels * rep_scale, 1),
18
+ nn.BatchNorm2d(out_channels * rep_scale)
19
+ )
20
+ self.conv2 = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
21
+ self.conv2_bn = nn.Sequential(
22
+ nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1),
23
+ nn.BatchNorm2d(out_channels * rep_scale)
24
+ )
25
+ self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
26
+ self.conv_crossh_bn = nn.Sequential(
27
+ nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0)),
28
+ nn.BatchNorm2d(out_channels * rep_scale)
29
+ )
30
+ self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
31
+ self.conv_crossv_bn = nn.Sequential(
32
+ nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1)),
33
+ nn.BatchNorm2d(out_channels * rep_scale)
34
+ )
35
+ self.conv_out = nn.Conv2d(out_channels * rep_scale * 10, out_channels, 1)
36
+
37
+ def forward(self, inp):
38
+ x = torch.cat(
39
+ [self.conv(inp),
40
+ self.conv1(inp),
41
+ self.conv2(inp),
42
+ self.conv_crossh(inp),
43
+ self.conv_crossv(inp),
44
+ self.conv_bn(inp),
45
+ self.conv1_bn(inp),
46
+ self.conv2_bn(inp),
47
+ self.conv_crossh_bn(inp),
48
+ self.conv_crossv_bn(inp)],
49
+ 1
50
+ )
51
+
52
+ out = self.conv_out(x)
53
+
54
+ return out
55
+
56
+ def slim(self):
57
+ conv_weight = self.conv.weight
58
+ conv_bias = self.conv.bias
59
+ conv1_weight = self.conv1.weight
60
+ conv1_bias = self.conv1.bias
61
+ conv1_weight = nn.functional.pad(conv1_weight, (2, 2, 2, 2))
62
+ conv2_weight = self.conv2.weight
63
+ conv2_weight = nn.functional.pad(conv2_weight, (1, 1, 1, 1))
64
+ conv2_bias = self.conv2.bias
65
+ conv_crossv_weight = self.conv_crossv.weight
66
+ conv_crossv_weight = nn.functional.pad(conv_crossv_weight, (1, 1, 2, 2))
67
+ conv_crossv_bias = self.conv_crossv.bias
68
+ conv_crossh_weight = self.conv_crossh.weight
69
+ conv_crossh_weight = nn.functional.pad(conv_crossh_weight, (2, 2, 1, 1))
70
+ conv_crossh_bias = self.conv_crossh.bias
71
+ conv1_bn_weight = self.conv1_bn[0].weight
72
+ conv1_bn_weight = nn.functional.pad(conv1_bn_weight, (2, 2, 2, 2))
73
+ conv2_bn_weight = self.conv2_bn[0].weight
74
+ conv2_bn_weight = nn.functional.pad(conv2_bn_weight, (1, 1, 1, 1))
75
+ conv_crossv_bn_weight = self.conv_crossv_bn[0].weight
76
+ conv_crossv_bn_weight = nn.functional.pad(conv_crossv_bn_weight, (1, 1, 2, 2))
77
+ conv_crossh_bn_weight = self.conv_crossh_bn[0].weight
78
+ conv_crossh_bn_weight = nn.functional.pad(conv_crossh_bn_weight, (2, 2, 1, 1))
79
+ bn = self.conv_bn[1]
80
+ k = 1 / (bn.running_var + bn.eps) ** .5
81
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
82
+ conv_bn_weight = self.conv_bn[0].weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
83
+ conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
84
+ conv_bn_bias = self.conv_bn[0].bias * k + b
85
+ conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
86
+ bn = self.conv1_bn[1]
87
+ k = 1 / (bn.running_var + bn.eps) ** .5
88
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
89
+ conv1_bn_weight = conv1_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
90
+ conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
91
+ conv1_bn_bias = self.conv1_bn[0].bias * k + b
92
+ conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
93
+ bn = self.conv2_bn[1]
94
+ k = 1 / (bn.running_var + bn.eps) ** .5
95
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
96
+ conv2_bn_weight = conv2_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
97
+ conv2_bn_weight = conv2_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
98
+ conv2_bn_bias = self.conv2_bn[0].bias * k + b
99
+ conv2_bn_bias = conv2_bn_bias * bn.weight + bn.bias
100
+ bn = self.conv_crossv_bn[1]
101
+ k = 1 / (bn.running_var + bn.eps) ** .5
102
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
103
+ conv_crossv_bn_weight = conv_crossv_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
104
+ conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
105
+ conv_crossv_bn_bias = self.conv_crossv_bn[0].bias * k + b
106
+ conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
107
+ bn = self.conv_crossh_bn[1]
108
+ k = 1 / (bn.running_var + bn.eps) ** .5
109
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
110
+ conv_crossh_bn_weight = conv_crossh_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
111
+ conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
112
+ conv_crossh_bn_bias = self.conv_crossh_bn[0].bias * k + b
113
+ conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
114
+ weight = torch.cat(
115
+ [conv_weight, conv1_weight, conv2_weight,
116
+ conv_crossh_weight, conv_crossv_weight,
117
+ conv_bn_weight, conv1_bn_weight, conv2_bn_weight,
118
+ conv_crossh_bn_weight, conv_crossv_bn_weight],
119
+ 0
120
+ )
121
+ weight_compress = self.conv_out.weight.squeeze()
122
+ weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
123
+ bias_ = torch.cat(
124
+ [conv_bias, conv1_bias, conv2_bias,
125
+ conv_crossh_bias, conv_crossv_bias,
126
+ conv_bn_bias, conv1_bn_bias, conv2_bn_bias,
127
+ conv_crossh_bn_bias, conv_crossv_bn_bias],
128
+ 0
129
+ )
130
+ bias = torch.matmul(weight_compress, bias_)
131
+ if isinstance(self.conv_out.bias, torch.Tensor):
132
+ bias = bias + self.conv_out.bias
133
+ return weight, bias
134
+
135
+
136
+ class ConvRep3(nn.Module):
137
+ def __init__(self, in_channels, out_channels, rep_scale=4):
138
+ super(ConvRep3, self).__init__()
139
+ self.in_channels = in_channels
140
+ self.out_channels = out_channels
141
+ self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
142
+ self.conv_bn = nn.Sequential(
143
+ nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1),
144
+ nn.BatchNorm2d(out_channels * rep_scale)
145
+ )
146
+ self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
147
+ self.conv1_bn = nn.Sequential(
148
+ nn.Conv2d(in_channels, out_channels * rep_scale, 1),
149
+ nn.BatchNorm2d(out_channels * rep_scale)
150
+ )
151
+ self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
152
+ self.conv_crossh_bn = nn.Sequential(
153
+ nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0)),
154
+ nn.BatchNorm2d(out_channels * rep_scale)
155
+ )
156
+ self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
157
+ self.conv_crossv_bn = nn.Sequential(
158
+ nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1)),
159
+ nn.BatchNorm2d(out_channels * rep_scale)
160
+ )
161
+ self.conv_out = nn.Conv2d(out_channels * rep_scale * 8, out_channels, 1)
162
+
163
+ def forward(self, inp):
164
+ x = torch.cat(
165
+ [self.conv(inp),
166
+ self.conv1(inp),
167
+ self.conv_crossh(inp),
168
+ self.conv_crossv(inp),
169
+ self.conv_bn(inp),
170
+ self.conv1_bn(inp),
171
+ self.conv_crossh_bn(inp),
172
+ self.conv_crossv_bn(inp)],
173
+ 1
174
+ )
175
+
176
+ out = self.conv_out(x)
177
+
178
+ return out
179
+
180
+ def slim(self):
181
+ conv_weight = self.conv.weight
182
+ conv_bias = self.conv.bias
183
+ conv1_weight = self.conv1.weight
184
+ conv1_bias = self.conv1.bias
185
+ conv1_weight = nn.functional.pad(conv1_weight, (1, 1, 1, 1))
186
+ conv_crossv_weight = self.conv_crossv.weight
187
+ conv_crossv_weight = nn.functional.pad(conv_crossv_weight, (0, 0, 1, 1))
188
+ conv_crossv_bias = self.conv_crossv.bias
189
+ conv_crossh_weight = self.conv_crossh.weight
190
+ conv_crossh_weight = nn.functional.pad(conv_crossh_weight, (1, 1, 0, 0))
191
+ conv_crossh_bias = self.conv_crossh.bias
192
+ conv1_bn_weight = self.conv1_bn[0].weight
193
+ conv1_bn_weight = nn.functional.pad(conv1_bn_weight, (1, 1, 1, 1))
194
+ conv_crossv_bn_weight = self.conv_crossv_bn[0].weight
195
+ conv_crossv_bn_weight = nn.functional.pad(conv_crossv_bn_weight, (0, 0, 1, 1))
196
+ conv_crossh_bn_weight = self.conv_crossh_bn[0].weight
197
+ conv_crossh_bn_weight = nn.functional.pad(conv_crossh_bn_weight, (1, 1, 0, 0))
198
+ bn = self.conv_bn[1]
199
+ k = 1 / (bn.running_var + bn.eps) ** .5
200
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
201
+ conv_bn_weight = self.conv_bn[0].weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
202
+ conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
203
+ conv_bn_bias = self.conv_bn[0].bias * k + b
204
+ conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
205
+ bn = self.conv1_bn[1]
206
+ k = 1 / (bn.running_var + bn.eps) ** .5
207
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
208
+ conv1_bn_weight = conv1_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
209
+ conv1_bn_weight = conv1_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
210
+ conv1_bn_bias = self.conv1_bn[0].bias * k + b
211
+ conv1_bn_bias = conv1_bn_bias * bn.weight + bn.bias
212
+ bn = self.conv_crossv_bn[1]
213
+ k = 1 / (bn.running_var + bn.eps) ** .5
214
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
215
+ conv_crossv_bn_weight = conv_crossv_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
216
+ conv_crossv_bn_weight = conv_crossv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
217
+ conv_crossv_bn_bias = self.conv_crossv_bn[0].bias * k + b
218
+ conv_crossv_bn_bias = conv_crossv_bn_bias * bn.weight + bn.bias
219
+ bn = self.conv_crossh_bn[1]
220
+ k = 1 / (bn.running_var + bn.eps) ** .5
221
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
222
+ conv_crossh_bn_weight = conv_crossh_bn_weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
223
+ conv_crossh_bn_weight = conv_crossh_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
224
+ conv_crossh_bn_bias = self.conv_crossh_bn[0].bias * k + b
225
+ conv_crossh_bn_bias = conv_crossh_bn_bias * bn.weight + bn.bias
226
+ weight = torch.cat(
227
+ [conv_weight, conv1_weight,
228
+ conv_crossh_weight, conv_crossv_weight,
229
+ conv_bn_weight, conv1_bn_weight,
230
+ conv_crossh_bn_weight, conv_crossv_bn_weight],
231
+ 0
232
+ )
233
+ weight_compress = self.conv_out.weight.squeeze()
234
+ weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
235
+ bias_ = torch.cat(
236
+ [conv_bias, conv1_bias,
237
+ conv_crossh_bias, conv_crossv_bias,
238
+ conv_bn_bias, conv1_bn_bias,
239
+ conv_crossh_bn_bias, conv_crossv_bn_bias],
240
+ 0
241
+ )
242
+ bias = torch.matmul(weight_compress, bias_)
243
+ if isinstance(self.conv_out.bias, torch.Tensor):
244
+ bias = bias + self.conv_out.bias
245
+ return weight, bias
246
+
247
+
248
+ class ConvRepPoint(nn.Module):
249
+ def __init__(self, in_channels, out_channels, rep_scale=4):
250
+ super(ConvRepPoint, self).__init__()
251
+ self.in_channels = in_channels
252
+ self.out_channels = out_channels
253
+ self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
254
+ self.conv_bn = nn.Sequential(
255
+ nn.Conv2d(in_channels, out_channels * rep_scale, 1),
256
+ nn.BatchNorm2d(out_channels * rep_scale)
257
+ )
258
+ self.conv_out = nn.Conv2d(out_channels * rep_scale * 2, out_channels, 1)
259
+
260
+ def forward(self, inp):
261
+ x = torch.cat([self.conv(inp), self.conv_bn(inp)], 1)
262
+ out = self.conv_out(x)
263
+ return out
264
+
265
+ def slim(self):
266
+ conv_weight = self.conv.weight
267
+ conv_bias = self.conv.bias
268
+ bn = self.conv_bn[1]
269
+ k = 1 / (bn.running_var + bn.eps) ** .5
270
+ b = - bn.running_mean / (bn.running_var + bn.eps) ** .5
271
+ conv_bn_weight = self.conv_bn[0].weight * k.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
272
+ conv_bn_weight = conv_bn_weight * bn.weight.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
273
+ conv_bn_bias = self.conv_bn[0].bias * k + b
274
+ conv_bn_bias = conv_bn_bias * bn.weight + bn.bias
275
+ weight = torch.cat([conv_weight, conv_bn_weight], 0)
276
+ weight_compress = self.conv_out.weight.squeeze()
277
+ weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
278
+ bias = torch.cat([conv_bias, conv_bn_bias], 0)
279
+ bias = torch.matmul(weight_compress, bias)
280
+ if isinstance(self.conv_out.bias, torch.Tensor):
281
+ bias = bias + self.conv_out.bias
282
+ return weight, bias
283
+
284
+
285
+ class QuadraticConnectionUnit(nn.Module):
286
+ def __init__(self, block1, block2, channels):
287
+ super(QuadraticConnectionUnit, self).__init__()
288
+ self.block1 = block1
289
+ self.block2 = block2
290
+ self.scale = 0.1
291
+ self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
292
+
293
+ def forward(self, x):
294
+ return self.scale * self.block1(x) * self.block2(x) + self.bias
295
+
296
+
297
+ class QuadraticConnectionUnitS(nn.Module):
298
+ def __init__(self, block1, block2, channels):
299
+ super(QuadraticConnectionUnitS, self).__init__()
300
+ self.block1 = block1
301
+ self.block2 = block2
302
+ self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
303
+
304
+ def forward(self, x):
305
+ return self.block1(x) * self.block2(x) + self.bias
306
+
307
+
308
+ class AdditionFusion(nn.Module):
309
+ def __init__(self, addend1, addend2, channels):
310
+ super(AdditionFusion, self).__init__()
311
+ self.addend1 = addend1
312
+ self.addend2 = addend2
313
+ self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
314
+
315
+ def forward(self, x):
316
+ return self.addend1(x) + self.addend2(x) + self.bias
317
+
318
+
319
+ class AdditionFusionS(nn.Module):
320
+ def __init__(self, addend1, addend2, channels):
321
+ super(AdditionFusionS, self).__init__()
322
+ self.addend1 = addend1
323
+ self.addend2 = addend2
324
+ self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))
325
+
326
+ def forward(self, x):
327
+ return self.addend1(x) + self.addend2(x) + self.bias
328
+
329
+
330
+ class DropBlock(nn.Module):
331
+ def __init__(self, block_size, p=0.5):
332
+ super(DropBlock, self).__init__()
333
+ self.block_size = block_size
334
+ self.p = p / block_size / block_size
335
+
336
+ def forward(self, x):
337
+ mask = 1 - (torch.rand_like(x[:, :1]) >= self.p).float()
338
+ mask = nn.functional.max_pool2d(mask, self.block_size, 1, self.block_size // 2)
339
+ return x * (1 - mask)
340
+
341
+
342
+ class ResBlock(nn.Module):
343
+ def __init__(self, num_feat=4, rep_scale=4):
344
+ super(ResBlock, self).__init__()
345
+ self.conv1 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
346
+ self.conv2 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
347
+ self.relu = nn.ReLU(inplace=True)
348
+
349
+ def forward(self, x):
350
+ identity = x
351
+ out = self.conv2(self.relu(self.conv1(x)))
352
+ return identity + out
353
+
354
+
355
+ class ResBlockS(nn.Module):
356
+ def __init__(self, num_feat=4):
357
+ super(ResBlockS, self).__init__()
358
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
359
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
360
+ self.relu = nn.ReLU(inplace=True)
361
+
362
+ def forward(self, x):
363
+ identity = x
364
+ out = self.conv2(self.relu(self.conv1(x)))
365
+ return identity + out