| import torch | |
| import torch.nn as nn | |
class ConvRep5(nn.Module):
    """Re-parameterisable 5x5 convolution.

    During training the input goes through ten parallel branches, each
    producing ``out_channels * rep_scale`` channels:
    plain 5x5, 1x1, 3x3, 3x1 and 1x3 convolutions, plus a copy of each followed
    by BatchNorm2d. The concatenated branch outputs are mixed down to
    ``out_channels`` by a 1x1 convolution (``conv_out``).

    All branches use stride 1 and "same" padding, so the spatial size is
    preserved. :meth:`slim` collapses the whole block into one 5x5 convolution
    (use it with stride 1, padding 2). That convolution reproduces ``forward``
    when BatchNorm uses its running statistics, i.e. in eval mode.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        rep_scale: width multiplier applied to every branch.
    """

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRep5, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        hidden = out_channels * rep_scale
        self.conv = nn.Conv2d(in_channels, hidden, 5, 1, 2)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 5, 1, 2),
            nn.BatchNorm2d(hidden)
        )
        self.conv1 = nn.Conv2d(in_channels, hidden, 1)
        self.conv1_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden)
        )
        self.conv2 = nn.Conv2d(in_channels, hidden, 3, 1, 1)
        self.conv2_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden)
        )
        self.conv_crossh = nn.Conv2d(in_channels, hidden, (3, 1), 1, (1, 0))
        self.conv_crossh_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, (3, 1), 1, (1, 0)),
            nn.BatchNorm2d(hidden)
        )
        self.conv_crossv = nn.Conv2d(in_channels, hidden, (1, 3), 1, (0, 1))
        self.conv_crossv_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, (1, 3), 1, (0, 1)),
            nn.BatchNorm2d(hidden)
        )
        self.conv_out = nn.Conv2d(hidden * 10, out_channels, 1)

    def _branches(self):
        """Return the parallel branches in the channel order shared by forward() and slim()."""
        return [
            self.conv, self.conv1, self.conv2, self.conv_crossh, self.conv_crossv,
            self.conv_bn, self.conv1_bn, self.conv2_bn, self.conv_crossh_bn, self.conv_crossv_bn,
        ]

    def forward(self, inp):
        """Run every branch on ``inp``, concatenate along channels, then mix with ``conv_out``."""
        x = torch.cat([branch(inp) for branch in self._branches()], 1)
        return self.conv_out(x)

    @staticmethod
    def _fuse_bn(conv_weight, conv_bias, bn):
        """Fold an eval-mode BatchNorm2d into the preceding conv's (weight, bias)."""
        scale = bn.weight / (bn.running_var + bn.eps) ** .5
        weight = conv_weight * scale.reshape(-1, 1, 1, 1)
        bias = (conv_bias - bn.running_mean) * scale + bn.bias
        return weight, bias

    @staticmethod
    def _pad_kernel(weight, size):
        """Zero-pad an (out, in, kh, kw) kernel symmetrically to (out, in, size, size)."""
        pad_h = (size - weight.shape[2]) // 2
        pad_w = (size - weight.shape[3]) // 2
        return nn.functional.pad(weight, (pad_w, pad_w, pad_h, pad_h))

    def slim(self):
        """Collapse the block into a single 5x5 convolution.

        Returns:
            ``(weight, bias)`` with shapes ``(out_channels, in_channels, 5, 5)``
            and ``(out_channels,)``. Applying them with stride 1 and padding 2
            matches ``forward`` in eval mode.

        The result is computed from the live parameters without
        ``torch.no_grad()``. Wrap the call if a detached result is wanted.
        """
        size = self.conv.kernel_size[0]
        weights, biases = [], []
        for branch in self._branches():
            if isinstance(branch, nn.Sequential):
                w, b = self._fuse_bn(branch[0].weight, branch[0].bias, branch[1])
            else:
                w, b = branch.weight, branch.bias
            weights.append(self._pad_kernel(w, size))
            biases.append(b)
        weight = torch.cat(weights, 0)
        bias_ = torch.cat(biases, 0)
        # conv_out is 1x1. flatten(1) keeps an (out, C) matrix even when
        # out_channels == 1, where squeeze() would collapse it to 1-D.
        weight_compress = self.conv_out.weight.flatten(1)
        # (C, in, k, k) -> (k, k, C, in) so matmul mixes the C axis, then back.
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.matmul(weight_compress, bias_)
        if isinstance(self.conv_out.bias, torch.Tensor):
            bias = bias + self.conv_out.bias
        return weight, bias
class ConvRep3(nn.Module):
    """Re-parameterisable 3x3 convolution.

    During training the input goes through eight parallel branches, each
    producing ``out_channels * rep_scale`` channels:
    plain 3x3, 1x1, 3x1 and 1x3 convolutions, plus a copy of each followed by
    BatchNorm2d. The concatenated branch outputs are mixed down to
    ``out_channels`` by a 1x1 convolution (``conv_out``).

    All branches use stride 1 and "same" padding, so the spatial size is
    preserved. :meth:`slim` collapses the whole block into one 3x3 convolution
    (use it with stride 1, padding 1). That convolution reproduces ``forward``
    when BatchNorm uses its running statistics, i.e. in eval mode.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        rep_scale: width multiplier applied to every branch.
    """

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRep3, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        hidden = out_channels * rep_scale
        self.conv = nn.Conv2d(in_channels, hidden, 3, 1, 1)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden)
        )
        self.conv1 = nn.Conv2d(in_channels, hidden, 1)
        self.conv1_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden)
        )
        self.conv_crossh = nn.Conv2d(in_channels, hidden, (3, 1), 1, (1, 0))
        self.conv_crossh_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, (3, 1), 1, (1, 0)),
            nn.BatchNorm2d(hidden)
        )
        self.conv_crossv = nn.Conv2d(in_channels, hidden, (1, 3), 1, (0, 1))
        self.conv_crossv_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, (1, 3), 1, (0, 1)),
            nn.BatchNorm2d(hidden)
        )
        self.conv_out = nn.Conv2d(hidden * 8, out_channels, 1)

    def _branches(self):
        """Return the parallel branches in the channel order shared by forward() and slim()."""
        return [
            self.conv, self.conv1, self.conv_crossh, self.conv_crossv,
            self.conv_bn, self.conv1_bn, self.conv_crossh_bn, self.conv_crossv_bn,
        ]

    def forward(self, inp):
        """Run every branch on ``inp``, concatenate along channels, then mix with ``conv_out``."""
        x = torch.cat([branch(inp) for branch in self._branches()], 1)
        return self.conv_out(x)

    @staticmethod
    def _fuse_bn(conv_weight, conv_bias, bn):
        """Fold an eval-mode BatchNorm2d into the preceding conv's (weight, bias)."""
        scale = bn.weight / (bn.running_var + bn.eps) ** .5
        weight = conv_weight * scale.reshape(-1, 1, 1, 1)
        bias = (conv_bias - bn.running_mean) * scale + bn.bias
        return weight, bias

    @staticmethod
    def _pad_kernel(weight, size):
        """Zero-pad an (out, in, kh, kw) kernel symmetrically to (out, in, size, size)."""
        pad_h = (size - weight.shape[2]) // 2
        pad_w = (size - weight.shape[3]) // 2
        return nn.functional.pad(weight, (pad_w, pad_w, pad_h, pad_h))

    def slim(self):
        """Collapse the block into a single 3x3 convolution.

        Returns:
            ``(weight, bias)`` with shapes ``(out_channels, in_channels, 3, 3)``
            and ``(out_channels,)``. Applying them with stride 1 and padding 1
            matches ``forward`` in eval mode.

        The result is computed from the live parameters without
        ``torch.no_grad()``. Wrap the call if a detached result is wanted.
        """
        size = self.conv.kernel_size[0]
        weights, biases = [], []
        for branch in self._branches():
            if isinstance(branch, nn.Sequential):
                w, b = self._fuse_bn(branch[0].weight, branch[0].bias, branch[1])
            else:
                w, b = branch.weight, branch.bias
            weights.append(self._pad_kernel(w, size))
            biases.append(b)
        weight = torch.cat(weights, 0)
        bias_ = torch.cat(biases, 0)
        # conv_out is 1x1. flatten(1) keeps an (out, C) matrix even when
        # out_channels == 1, where squeeze() would collapse it to 1-D.
        weight_compress = self.conv_out.weight.flatten(1)
        # (C, in, k, k) -> (k, k, C, in) so matmul mixes the C axis, then back.
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.matmul(weight_compress, bias_)
        if isinstance(self.conv_out.bias, torch.Tensor):
            bias = bias + self.conv_out.bias
        return weight, bias
class ConvRepPoint(nn.Module):
    """Re-parameterisable 1x1 (pointwise) convolution.

    During training the input goes through two parallel 1x1 branches, each
    producing ``out_channels * rep_scale`` channels: a plain convolution and
    one followed by BatchNorm2d. The concatenated outputs are mixed down to
    ``out_channels`` by a 1x1 convolution (``conv_out``).

    :meth:`slim` collapses the block into one 1x1 convolution. That convolution
    reproduces ``forward`` when BatchNorm uses its running statistics, i.e. in
    eval mode.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        rep_scale: width multiplier applied to both branches.
    """

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRepPoint, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        hidden = out_channels * rep_scale
        self.conv = nn.Conv2d(in_channels, hidden, 1)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden)
        )
        self.conv_out = nn.Conv2d(hidden * 2, out_channels, 1)

    def forward(self, inp):
        """Concatenate both branches along channels and mix them with ``conv_out``."""
        x = torch.cat([self.conv(inp), self.conv_bn(inp)], 1)
        return self.conv_out(x)

    @staticmethod
    def _fuse_bn(conv_weight, conv_bias, bn):
        """Fold an eval-mode BatchNorm2d into the preceding conv's (weight, bias)."""
        scale = bn.weight / (bn.running_var + bn.eps) ** .5
        weight = conv_weight * scale.reshape(-1, 1, 1, 1)
        bias = (conv_bias - bn.running_mean) * scale + bn.bias
        return weight, bias

    def slim(self):
        """Collapse the block into a single 1x1 convolution.

        Returns:
            ``(weight, bias)`` with shapes ``(out_channels, in_channels, 1, 1)``
            and ``(out_channels,)``. Together they match ``forward`` in eval mode.

        The result is computed from the live parameters without
        ``torch.no_grad()``. Wrap the call if a detached result is wanted.
        """
        conv_bn_weight, conv_bn_bias = self._fuse_bn(
            self.conv_bn[0].weight, self.conv_bn[0].bias, self.conv_bn[1])
        weight = torch.cat([self.conv.weight, conv_bn_weight], 0)
        bias_ = torch.cat([self.conv.bias, conv_bn_bias], 0)
        # conv_out is 1x1. flatten(1) keeps an (out, C) matrix even when
        # out_channels == 1, where squeeze() would collapse it to 1-D.
        weight_compress = self.conv_out.weight.flatten(1)
        # (C, in, 1, 1) -> (1, 1, C, in) so matmul mixes the C axis, then back.
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.matmul(weight_compress, bias_)
        if isinstance(self.conv_out.bias, torch.Tensor):
            bias = bias + self.conv_out.bias
        return weight, bias
class QuadraticConnectionUnit(nn.Module):
    """Multiplicative fusion of two sub-blocks: ``scale * block1(x) * block2(x) + bias``.

    Args:
        block1: callable applied to the input.
        block2: callable applied to the same input; its output must broadcast
            with ``block1``'s.
        channels: channel count of the learnable bias, which has shape
            ``(1, channels, 1, 1)``.
        scale: constant multiplier on the product. The default 0.1 is the
            value previously hard-coded here.
    """

    def __init__(self, block1, block2, channels, scale=0.1):
        super(QuadraticConnectionUnit, self).__init__()
        self.block1 = block1
        self.block2 = block2
        # Plain float attribute: not learnable and not stored in state_dict.
        self.scale = scale
        # Learnable per-channel bias, initialised from a standard normal.
        self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))

    def forward(self, x):
        return self.scale * self.block1(x) * self.block2(x) + self.bias
class QuadraticConnectionUnitS(nn.Module):
    """Unscaled multiplicative fusion: ``block1(x) * block2(x) + bias``.

    Args:
        block1: callable applied to the input first.
        block2: callable applied to the same input second.
        channels: size of the learnable bias, shaped ``(1, channels, 1, 1)``.
    """

    def __init__(self, block1, block2, channels):
        super().__init__()
        self.block1 = block1
        self.block2 = block2
        bias_shape = (1, channels, 1, 1)
        # Standard-normal initialisation of the per-channel offset.
        self.bias = nn.Parameter(torch.randn(bias_shape))

    def forward(self, x):
        product = self.block1(x) * self.block2(x)
        return product + self.bias
class AdditionFusion(nn.Module):
    """Additive fusion of two sub-blocks: ``addend1(x) + addend2(x) + bias``.

    Args:
        addend1: callable evaluated first on the input.
        addend2: callable evaluated second on the same input.
        channels: size of the learnable bias, shaped ``(1, channels, 1, 1)``.
    """

    def __init__(self, addend1, addend2, channels):
        super().__init__()
        self.addend1 = addend1
        self.addend2 = addend2
        bias_shape = (1, channels, 1, 1)
        # Standard-normal initialisation of the per-channel offset.
        self.bias = nn.Parameter(torch.randn(bias_shape))

    def forward(self, x):
        total = self.addend1(x) + self.addend2(x)
        return total + self.bias
class AdditionFusionS(nn.Module):
    """Additive fusion ``addend1(x) + addend2(x) + bias`` (same computation as AdditionFusion).

    Args:
        addend1: callable evaluated first on the input.
        addend2: callable evaluated second on the same input.
        channels: size of the learnable bias, shaped ``(1, channels, 1, 1)``.
    """

    def __init__(self, addend1, addend2, channels):
        super().__init__()
        self.addend1 = addend1
        self.addend2 = addend2
        bias_shape = (1, channels, 1, 1)
        # Standard-normal initialisation of the per-channel offset.
        self.bias = nn.Parameter(torch.randn(bias_shape))

    def forward(self, x):
        total = self.addend1(x) + self.addend2(x)
        return total + self.bias
class DropBlock(nn.Module):
    """Zero out random square regions of the input during training.

    Seed pixels are drawn independently with probability
    ``p / block_size**2``, using one mask shared by all channels. Each seed is
    grown into a ``block_size x block_size`` square by max-pooling, and the
    covered positions are zeroed. Kept activations are not rescaled. Like
    ``nn.Dropout``, the module is the identity when not in training mode.

    Args:
        block_size: side length of the dropped squares (odd or even).
        p: nominal drop rate. It is divided by ``block_size**2`` and the
            result is stored as ``self.p``, the per-pixel seed probability.
    """

    def __init__(self, block_size, p=0.5):
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.p = p / block_size / block_size

    def forward(self, x):
        # Regularisation only: leave activations untouched at inference.
        if not self.training:
            return x
        height, width = x.shape[-2:]
        # 1.0 at seed positions; a single channel so every channel shares the mask.
        seeds = (torch.rand_like(x[:, :1]) < self.p).float()
        mask = nn.functional.max_pool2d(seeds, self.block_size, 1, self.block_size // 2)
        # With an even block_size the pooled map is one row/column larger than
        # the input, so crop it back. This is a no-op for odd sizes.
        mask = mask[..., :height, :width]
        return x * (1 - mask)
class ResBlock(nn.Module):
    """Residual block ``x + conv2(relu(conv1(x)))`` built from re-parameterisable ConvRep3 layers.

    Args:
        num_feat: channel count, kept the same for input and output.
        rep_scale: width multiplier forwarded to both ConvRep3 layers.
    """

    def __init__(self, num_feat=4, rep_scale=4):
        super().__init__()
        self.conv1 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
        self.conv2 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
        # In-place is safe: it only overwrites conv1's fresh output, never x.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return x + self.conv2(hidden)
class ResBlockS(nn.Module):
    """Residual block ``x + conv2(relu(conv1(x)))`` using plain 3x3 convolutions.

    This is the slimmed counterpart of ``ResBlock``: each conv keeps the
    channel count and spatial size (stride 1, padding 1).

    Args:
        num_feat: channel count, kept the same for input and output.
    """

    def __init__(self, num_feat=4):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        # In-place is safe: it only overwrites conv1's fresh output, never x.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return x + self.conv2(hidden)