Upload 5 files
- nets/CSPdarknet_tiny.py +101 -0
- nets/FAENet.py +248 -0
- nets/__init__.py +1 -0
- nets/yolo.py +241 -0
- nets/yolo_training.py +277 -0
nets/CSPdarknet_tiny.py
ADDED
@@ -0,0 +1,101 @@
import math

import torch
import torch.nn as nn

from nets.FAENet import FAENet


class BasicConv(nn.Module):
    # Conv2d + BatchNorm2d + LeakyReLU
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(BasicConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x


class Resblock_body(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Resblock_body, self).__init__()
        self.out_channels = out_channels

        self.conv1 = BasicConv(in_channels, out_channels, 3)
        self.conv2 = BasicConv(out_channels // 2, out_channels // 2, 3)
        self.conv3 = BasicConv(out_channels // 2, out_channels // 2, 3)
        self.conv4 = BasicConv(out_channels, out_channels, 1)
        self.maxpool = nn.MaxPool2d([2, 2], [2, 2])

    def forward(self, x):
        x = self.conv1(x)
        route = x

        c = self.out_channels
        # CSP split: keep the second half of the channels for the residual path
        x = torch.split(x, c // 2, dim=1)[1]
        x = self.conv2(x)
        route1 = x
        x = self.conv3(x)
        x = torch.cat([x, route1], dim=1)
        x = self.conv4(x)
        feat = x

        x = torch.cat([route, x], dim=1)
        x = self.maxpool(x)
        return x, feat


class CSPDarkNet(nn.Module):
    def __init__(self):
        super(CSPDarkNet, self).__init__()

        self.faenet = FAENet()
        # 416,416,3 -> 208,208,32 -> 104,104,64
        self.conv1 = BasicConv(3, 32, kernel_size=3, stride=2)
        self.conv2 = BasicConv(32, 64, kernel_size=3, stride=2)
        # 104,104,64 -> 52,52,128
        self.resblock_body1 = Resblock_body(64, 64)
        # 52,52,128 -> 26,26,256
        self.resblock_body2 = Resblock_body(128, 128)
        # 26,26,256 -> 13,13,512
        self.resblock_body3 = Resblock_body(256, 256)
        # 13,13,512 -> 13,13,512
        self.conv3 = BasicConv(512, 512, kernel_size=3)
        self.num_features = 1
        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        # Image enhancement front end before the backbone proper
        x = self.faenet(x)
        # 416,416,3 -> 208,208,32 -> 104,104,64
        x = self.conv1(x)
        x = self.conv2(x)
        # 104,104,64 -> 52,52,128
        x, _ = self.resblock_body1(x)
        # 52,52,128 -> 26,26,256
        x, _ = self.resblock_body2(x)
        # 26,26,256 -> 13,13,512, feat1: 26,26,256
        x, feat1 = self.resblock_body3(x)
        # 13,13,512 -> 13,13,512
        x = self.conv3(x)
        feat2 = x

        return feat1, feat2


def darknet_tiny(pretrained, **kwargs):
    model = CSPDarkNet()
    if pretrained:
        model.load_state_dict(torch.load("model_data/CSPdarknet53_tiny_backbone_weights.pth"))
    return model
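For reference, a minimal smoke test of this backbone (my own sketch, not part of the upload): it assumes the 416x416 input the shape comments describe and checks the two feature maps the head expects.

import torch
from nets.CSPdarknet_tiny import darknet_tiny

# pretrained=False avoids needing the weights file under model_data/
model = darknet_tiny(pretrained=False).eval()
with torch.no_grad():
    feat1, feat2 = model(torch.randn(1, 3, 416, 416))
print(feat1.shape)  # expected: torch.Size([1, 256, 26, 26])
print(feat2.shape)  # expected: torch.Size([1, 512, 13, 13])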
nets/FAENet.py
ADDED
@@ -0,0 +1,248 @@
import math

import cv2
import torch
import torch.nn as nn


class eca_block(nn.Module):
    # Efficient Channel Attention: the 1D-conv kernel size is derived
    # adaptively from the channel count passed at construction time.
    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class DilatedConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
        super(DilatedConvNet, self).__init__()
        self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.dilated_conv(x)
        x = self.relu(x)
        return x


class LAM(nn.Module):
    def __init__(self, ch=16):
        super().__init__()
        # note: the ECA kernel size is derived from ch, while the actual
        # input here is the 6-channel concat of input and enhanced output
        self.eca = eca_block(ch)
        self.conv1 = nn.Conv2d(6, 3, 3, padding=1)

    def forward(self, x):
        x = self.eca(x)
        x = self.conv1(x)
        return x


class RFEM(nn.Module):
    def __init__(self, ch_blocks=64, ch_mask=16):
        super().__init__()

        # note: upstream passed True positionally to nn.LeakyReLU, which sets
        # negative_slope=1; inplace=True is the evident intent and is used here
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                                     nn.LeakyReLU(inplace=True),
                                     nn.Conv2d(16, ch_blocks, 3, padding=1),
                                     nn.LeakyReLU(inplace=True))

        # Four parallel branches with growing receptive fields
        self.dconv1 = DilatedConvNet(ch_blocks, ch_blocks // 4, kernel_size=3, padding=1, dilation=1)
        self.dconv2 = DilatedConvNet(ch_blocks, ch_blocks // 4, kernel_size=3, padding=2, dilation=2)
        self.dconv3 = DilatedConvNet(ch_blocks, ch_blocks // 4, kernel_size=3, padding=3, dilation=3)
        self.dconv4 = nn.Conv2d(ch_blocks, ch_blocks // 4, kernel_size=7, padding=3)

        self.decoder = nn.Sequential(nn.Conv2d(ch_blocks, 16, 3, padding=1),
                                     nn.LeakyReLU(inplace=True),
                                     nn.Conv2d(16, 3, 3, padding=1),
                                     nn.LeakyReLU(inplace=True))

        self.lam = LAM(ch_mask)

    def forward(self, x):
        x1 = self.encoder(x)
        x1_1 = self.dconv1(x1)
        x1_2 = self.dconv2(x1)
        x1_3 = self.dconv3(x1)
        x1_4 = self.dconv4(x1)
        x1 = torch.cat([x1_1, x1_2, x1_3, x1_4], dim=1)
        x1 = self.decoder(x1)
        out = x + x1
        out = torch.relu(out)
        mask = self.lam(torch.cat([x, out], dim=1))
        return out, mask


class ATEM(nn.Module):
    def __init__(self, in_ch=3, inter_ch=32, out_ch=3, kernel_size=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2),
            nn.LeakyReLU(inplace=True),
        )
        self.shift_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))
        self.scale_conv = nn.Sequential(
            nn.Conv2d(in_ch, inter_ch, kernel_size, padding=kernel_size // 2))

        self.decoder = nn.Sequential(
            nn.Conv2d(inter_ch, out_ch, kernel_size, padding=kernel_size // 2))

    def forward(self, x, tag):
        # Affine modulation of the encoded features by the tag (scale and shift)
        x = self.encoder(x)
        scale = self.scale_conv(tag)
        shift = self.shift_conv(tag)
        x = x + (x * scale + shift)
        x = self.decoder(x)
        return x


class Trans_high(nn.Module):
    def __init__(self, in_ch=3, inter_ch=16, out_ch=3, kernel_size=3):
        super().__init__()
        self.atem = ATEM(in_ch, inter_ch, out_ch, kernel_size)

    def forward(self, x, tag):
        x = x + self.atem(x, tag)
        return x


class Up_tag(nn.Module):
    def __init__(self, kernel_size=1, ch=3):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(ch, ch, kernel_size, stride=1, padding=kernel_size // 2, bias=False))

    def forward(self, x):
        x = self.up(x)
        return x


class Lap_Pyramid_Conv(nn.Module):
    def __init__(self, num_high=3, kernel_size=5, channels=3):
        super().__init__()

        self.num_high = num_high
        self.kernel = self.gauss_kernel(kernel_size, channels)

    def gauss_kernel(self, kernel_size, channels):
        kernel = cv2.getGaussianKernel(kernel_size, 0).dot(
            cv2.getGaussianKernel(kernel_size, 0).T)
        kernel = torch.FloatTensor(kernel).unsqueeze(0).repeat(
            channels, 1, 1, 1)
        kernel = torch.nn.Parameter(data=kernel, requires_grad=False)
        return kernel

    def conv_gauss(self, x, kernel):
        n_channels, _, kw, kh = kernel.shape
        x = torch.nn.functional.pad(x, (kw // 2, kh // 2, kw // 2, kh // 2),
                                    mode='reflect')
        x = torch.nn.functional.conv2d(x, kernel, groups=n_channels)
        return x

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def pyramid_down(self, x):
        return self.downsample(self.conv_gauss(x, self.kernel))

    def upsample(self, x):
        # Zero-insertion upsampling followed by Gaussian smoothing
        # (the factor 4 preserves the signal energy)
        up = torch.zeros((x.size(0), x.size(1), x.size(2) * 2, x.size(3) * 2),
                         device=x.device)
        up[:, :, ::2, ::2] = x * 4
        return self.conv_gauss(up, self.kernel)

    def pyramid_decom(self, img):
        self.kernel = self.kernel.to(img.device)
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.pyramid_down(current)
            up = self.upsample(down)
            diff = current - up
            pyr.append(diff)
            current = down
        pyr.append(current)
        return pyr

    def pyramid_recons(self, pyr):
        # Expects [low-frequency image, high-frequency residuals...]
        image = pyr[0]
        for level in pyr[1:]:
            up = self.upsample(image)
            image = up + level
        return image


class FAENet(nn.Module):
    def __init__(self,
                 num_high=1,
                 ch_blocks=32,
                 up_ksize=1,
                 high_ch=32,
                 high_ksize=3,
                 ch_mask=32,
                 gauss_kernel=7):
        super().__init__()
        self.num_high = num_high
        self.lap_pyramid = Lap_Pyramid_Conv(num_high, gauss_kernel)
        self.rfem = RFEM(ch_blocks, ch_mask)

        for i in range(0, self.num_high):
            self.__setattr__('up_tag_layer_{}'.format(i), Up_tag(up_ksize, ch=3))
            self.__setattr__('trans_high_layer_{}'.format(i), Trans_high(3, high_ch, 3, high_ksize))

    def forward(self, x):
        pyrs = self.lap_pyramid.pyramid_decom(img=x)

        # Enhance the low-frequency level first; it also yields the guidance tag
        trans_pyrs = []
        trans_pyr, tag = self.rfem(pyrs[-1])
        trans_pyrs.append(trans_pyr)

        # Upsample the tag once per high-frequency level
        common_tag = []
        for i in range(self.num_high):
            tag = self.__getattr__('up_tag_layer_{}'.format(i))(tag)
            common_tag.append(tag)

        for i in range(self.num_high):
            trans_pyr = self.__getattr__('trans_high_layer_{}'.format(i))(
                pyrs[-2 - i], common_tag[i])
            trans_pyrs.append(trans_pyr)

        out = self.lap_pyramid.pyramid_recons(trans_pyrs)

        return out


if __name__ == "__main__":
    # Parameter count, guarded so it does not run when this module is imported
    faenet = FAENet()
    num_params = sum(p.numel() for p in faenet.parameters())
    print("FAENet parameters: {:.2f}K ".format(num_params / 1024) + "{:.2f} MB".format(num_params / (1024 * 1024)))
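A quick sanity check of the module above (my own sketch, not part of the upload): FAENet is shape-preserving, which is what lets CSPdarknet_tiny drop it in front of its first convolution, and the Laplacian pyramid round-trips up to floating-point error.

import torch
from nets.FAENet import FAENet, Lap_Pyramid_Conv

x = torch.randn(1, 3, 416, 416)
with torch.no_grad():
    out = FAENet()(x)
print(out.shape)  # torch.Size([1, 3, 416, 416]), same as the input

# Decompose then reconstruct: recons expects [low, high...], decom returns
# [high..., low], hence the reversal.
pyr = Lap_Pyramid_Conv(num_high=1, kernel_size=7)
levels = pyr.pyramid_decom(x)
recon = pyr.pyramid_recons(list(reversed(levels)))
print(torch.allclose(recon, x, atol=1e-4))  # True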
nets/__init__.py
ADDED
@@ -0,0 +1 @@
#
nets/yolo.py
ADDED
@@ -0,0 +1,241 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from nets.CSPdarknet import darknet
from nets.CSPdarknet_tiny import darknet_tiny
from nets.mobilenetv2 import mobilenet_v2
from nets.shufflenet_v2 import shufflenet_v2
from nets.ghostnet import ghostnet
from nets.attention import cbam_block, eca_block, se_block, CA_Block

attention_block = [se_block, cbam_block, eca_block, CA_Block]


#-------------------------------------------------#
#   Convolution block:
#   Conv2d + BatchNormalization + LeakyReLU
#-------------------------------------------------#
class BasicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super(BasicConv, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x


#---------------------------------------------------#
#   1x1 convolution + nearest-neighbour upsampling
#---------------------------------------------------#
class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()

        self.upsample = nn.Sequential(
            BasicConv(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        x = self.upsample(x)
        return x


#---------------------------------------------------#
#   Produces the final YOLOv4 head output
#---------------------------------------------------#
def yolo_head(filters_list, in_filters):
    m = nn.Sequential(
        BasicConv(in_filters, filters_list[0], 3),
        nn.Conv2d(filters_list[0], filters_list[1], 1),
    )
    return m


class ConvBNReLU(nn.Module):
    '''Module for the Conv-BN-ReLU tuple.'''

    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation,
                 use_relu=True):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            c_in, c_out, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        if use_relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CARAFE(nn.Module):
    def __init__(self, c, c_mid=64, scale=2, k_up=5, k_enc=3):
        """ The unofficial implementation of the CARAFE module.

        The details are in "https://arxiv.org/abs/1905.02188".

        Args:
            c: The channel number of the input and the output.
            c_mid: The channel number after compression.
            scale: The expected upsample scale.
            k_up: The size of the reassembly kernel.
            k_enc: The kernel size of the encoder.

        Returns:
            X: The upsampled feature map.
        """
        super(CARAFE, self).__init__()
        self.scale = scale

        self.comp = ConvBNReLU(c, c_mid, kernel_size=1, stride=1,
                               padding=0, dilation=1)
        self.enc = ConvBNReLU(c_mid, (scale * k_up) ** 2, kernel_size=k_enc,
                              stride=1, padding=k_enc // 2, dilation=1,
                              use_relu=False)
        self.pix_shf = nn.PixelShuffle(scale)

        self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest')
        self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale,
                                padding=k_up // 2 * scale)

    def forward(self, X):
        b, c, h, w = X.size()
        h_, w_ = h * self.scale, w * self.scale

        W = self.comp(X)                              # b * m * h * w
        W = self.enc(W)                               # b * 100 * h * w
        W = self.pix_shf(W)                           # b * 25 * h_ * w_
        W = F.softmax(W, dim=1)                       # b * 25 * h_ * w_

        X = self.upsmp(X)                             # b * c * h_ * w_
        X = self.unfold(X)                            # b * 25c * h_ * w_
        X = X.view(b, c, -1, h_, w_)                  # b * c * 25 * h_ * w_

        X = torch.einsum('bkhw,bckhw->bchw', [W, X])  # b * c * h_ * w_
        return X


#---------------------------------------------------#
#   yolo_body -- MSFNet
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi=0, backbone='', pretrained=False):
        super(YoloBody, self).__init__()

        self.phi = phi

        if backbone == 'cspdarknet':
            self.backbone = darknet(pretrained)
            self.conv_for_P5 = BasicConv(512, 256, 1)
            self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)], 256)
            self.upsample_1 = Upsample(256, 128)
            self.conv1 = BasicConv(256, 128, 1)
            self.upsample_2 = CARAFE(128)
            self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)], 384)

            if 1 <= self.phi <= 4:
                self.feat1_att = attention_block[self.phi - 1](256)
                self.feat2_att = attention_block[self.phi - 1](512)
                self.upsample_att = attention_block[self.phi - 1](128)

        elif backbone == 'tiny':
            self.backbone = darknet_tiny(pretrained)
            self.conv_for_P5 = BasicConv(512, 256, 1)
            self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)], 256)
            self.upsample_1 = Upsample(256, 128)
            self.conv1 = BasicConv(256, 128, 1)
            self.upsample_2 = CARAFE(128)
            self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)], 384)

            if 1 <= self.phi <= 4:
                self.feat1_att = attention_block[self.phi - 1](256)
                self.feat2_att = attention_block[self.phi - 1](512)
                self.upsample_att = attention_block[self.phi - 1](128)

        elif backbone == 'mobilenetv2':
            self.backbone = mobilenet_v2(pretrained)
            self.conv_for_P5 = BasicConv(320, 256, 1)
            self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)], 256)
            self.upsample_1 = Upsample(256, 128)
            self.conv1 = BasicConv(256, 128, 1)
            self.upsample_2 = CARAFE(128)
            self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)], 224)

            if 1 <= self.phi <= 4:
                self.feat1_att = attention_block[self.phi - 1](256)
                self.feat2_att = attention_block[self.phi - 1](512)
                self.upsample_att = attention_block[self.phi - 1](128)

        elif backbone == 'shufflenetv2':
            self.backbone = shufflenet_v2()
            self.conv_for_P5 = BasicConv(1024, 256, 1)
            self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)], 256)
            self.upsample_1 = Upsample(256, 128)
            self.conv1 = BasicConv(256, 128, 1)
            self.upsample_2 = CARAFE(128)
            self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)], 592)

            if 1 <= self.phi <= 4:
                self.feat1_att = attention_block[self.phi - 1](256)
                self.feat2_att = attention_block[self.phi - 1](512)
                self.upsample_att = attention_block[self.phi - 1](128)

        elif backbone == 'ghostnet':
            self.backbone = ghostnet()
            self.conv_for_P5 = BasicConv(160, 256, 1)
            self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)], 256)
            self.upsample_1 = Upsample(256, 128)
            self.conv1 = BasicConv(256, 128, 1)
            self.upsample_2 = CARAFE(128)
            self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)], 240)

            if 1 <= self.phi <= 4:
                self.feat1_att = attention_block[self.phi - 1](256)
                self.feat2_att = attention_block[self.phi - 1](512)
                self.upsample_att = attention_block[self.phi - 1](128)

    def forward(self, x):
        #---------------------------------------------------#
        #   CSPdarknet53_tiny backbone
        #   feat1 shape: 26,26,256
        #   feat2 shape: 13,13,512
        #---------------------------------------------------#
        feat1, feat2 = self.backbone(x)
        if 1 <= self.phi <= 4:
            feat1 = self.feat1_att(feat1)
            feat2 = self.feat2_att(feat2)

        # 13,13,512 -> 13,13,256
        P5 = self.conv_for_P5(feat2)
        # 13,13,256 -> 13,13,512 -> 13,13,255
        out0 = self.yolo_headP5(P5)

        # Second branch from feat2: 13,13,256 -> 26,26,128 (conv + nearest upsample)
        P6 = self.conv_for_P5(feat2)
        P6_Upsample = self.upsample_1(P6)

        # 13,13,256 -> 13,13,128 -> 26,26,128 (conv + lightweight CARAFE upsampling)
        P5 = self.conv1(P5)
        P5_Upsample = self.upsample_2(P5)

        # Fuse the two 26,26,128 maps, then concat with feat1: 26,26,(128+256) = 26,26,384
        fused = P5_Upsample + P6_Upsample
        # if 1 <= self.phi <= 4:
        #     P5_Upsample = self.upsample_att(P5_Upsample)
        # P4 = torch.cat([P5_Upsample, feat1], axis=1)
        P4 = torch.cat([fused, feat1], axis=1)
        # 26,26,384 -> 26,26,256 -> 26,26,255
        out1 = self.yolo_headP4(P4)

        return out0, out1
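A usage sketch for the 'tiny' variant (my addition; it assumes the sibling modules imported at the top of nets/yolo.py, such as nets.CSPdarknet and nets.mobilenetv2, exist elsewhere in the repo, and the anchor mask / class count are illustrative):

import torch
from nets.yolo import YoloBody

anchors_mask = [[3, 4, 5], [0, 1, 2]]   # two scales, three anchors each
model = YoloBody(anchors_mask, num_classes=20, phi=0, backbone='tiny').eval()
with torch.no_grad():
    out0, out1 = model(torch.randn(1, 3, 416, 416))
print(out0.shape)  # torch.Size([1, 75, 13, 13]); 3 * (5 + 20) = 75 channels
print(out1.shape)  # torch.Size([1, 75, 26, 26])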
nets/yolo_training.py
ADDED
@@ -0,0 +1,277 @@
import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn


class YOLOLoss(nn.Module):
    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], label_smoothing=0):
        super(YOLOLoss, self).__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        self.anchors_mask = anchors_mask
        self.label_smoothing = label_smoothing

        # Per-scale confidence weights and ratios balancing the loss terms
        self.balance = [0.4, 1.0, 4]
        self.box_ratio = 0.05
        self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
        self.cls_ratio = 1 * (num_classes / 80)

        self.ignore_threshold = 0.5
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        t = t.float()
        result = (t >= t_min).float() * t + (t < t_min).float() * t_min
        result = (result <= t_max).float() * result + (result > t_max).float() * t_max
        return result

    def MSELoss(self, pred, target):
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        epsilon = 1e-7
        pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        output = -target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
        return output

    def box_ciou(self, b1, b2):
        # Boxes are (cx, cy, w, h); convert to corner coordinates
        b1_xy = b1[..., :2]
        b1_wh = b1[..., 2:4]
        b1_wh_half = b1_wh / 2.
        b1_mins = b1_xy - b1_wh_half
        b1_maxes = b1_xy + b1_wh_half
        b2_xy = b2[..., :2]
        b2_wh = b2[..., 2:4]
        b2_wh_half = b2_wh / 2.
        b2_mins = b2_xy - b2_wh_half
        b2_maxes = b2_xy + b2_wh_half

        intersect_mins = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area = b2_wh[..., 0] * b2_wh[..., 1]
        union_area = b1_area + b2_area - intersect_area
        iou = intersect_area / torch.clamp(union_area, min=1e-6)

        # Distance penalty: squared center distance over enclosing-box diagonal
        center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
        enclose_mins = torch.min(b1_mins, b2_mins)
        enclose_maxes = torch.max(b1_maxes, b2_maxes)
        enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), axis=-1)
        ciou = iou - 1.0 * center_distance / torch.clamp(enclose_diagonal, min=1e-6)

        # Aspect-ratio consistency penalty
        v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min=1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min=1e-6))), 2)
        alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
        ciou = ciou - alpha * v
        return ciou

    def smooth_labels(self, y_true, label_smoothing, num_classes):
        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

    def forward(self, l, input, targets=None):
        bs = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]

        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()

        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        w = prediction[..., 2]
        h = prediction[..., 3]
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)

        noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)

        if self.cuda:
            y_true = y_true.type_as(x)
            noobj_mask = noobj_mask.type_as(x)
            box_loss_scale = box_loss_scale.type_as(x)
        box_loss_scale = 2 - box_loss_scale

        loss = 0
        obj_mask = y_true[..., 4] == 1
        n = torch.sum(obj_mask)
        if n != 0:
            # Localization (CIoU) and classification losses over positive cells
            ciou = self.box_ciou(pred_boxes, y_true[..., :4]).type_as(x)
            loss_loc = torch.mean((1 - ciou)[obj_mask])
            loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio

        # Confidence loss over positives and non-ignored negatives
        loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
        loss += loss_conf * self.balance[l] * self.obj_ratio
        return loss

    def calculate_iou(self, _box_a, _box_b):
        # Convert (cx, cy, w, h) to corners, then compute the pairwise A x B IoU
        b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
        b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
        b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
        b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2

        box_a = torch.zeros_like(_box_a)
        box_b = torch.zeros_like(_box_b)
        box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
        box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2

        A = box_a.size(0)
        B = box_b.size(0)

        max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
        min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
        inter = torch.clamp((max_xy - min_xy), min=0)
        inter = inter[:, :, 0] * inter[:, :, 1]
        area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)
        area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)
        union = area_a + area_b - inter
        return inter / union

    def get_target(self, l, targets, anchors, in_h, in_w):
        bs = len(targets)
        noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad=False)
        box_loss_scale = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad=False)
        y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad=False)
        for b in range(bs):
            if len(targets[b]) == 0:
                continue
            # Scale normalized targets to feature-map coordinates
            batch_target = torch.zeros_like(targets[b])
            batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
            batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            batch_target = batch_target.cpu()

            # Match each ground truth to anchors by IoU of origin-centered (w, h) boxes
            gt_box = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
            anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
            iou = self.calculate_iou(gt_box, anchor_shapes)
            best_ns = torch.argmax(iou, dim=-1)
            sort_ns = torch.argsort(iou, dim=-1, descending=True)

            def check_in_anchors_mask(index, anchors_mask):
                for sub_anchors_mask in anchors_mask:
                    if index in sub_anchors_mask:
                        return True
                return False

            for t, best_n in enumerate(best_ns):
                # Fall back to the best-matching anchor that appears in some mask
                if not check_in_anchors_mask(best_n, self.anchors_mask):
                    for index in sort_ns[t]:
                        if check_in_anchors_mask(index, self.anchors_mask):
                            best_n = index
                            break

                if best_n not in self.anchors_mask[l]:
                    continue
                k = self.anchors_mask[l].index(best_n)
                i = torch.floor(batch_target[t, 0]).long()
                j = torch.floor(batch_target[t, 1]).long()
                c = batch_target[t, 4].long()

                noobj_mask[b, k, j, i] = 0
                y_true[b, k, j, i, 0] = batch_target[t, 0]
                y_true[b, k, j, i, 1] = batch_target[t, 1]
                y_true[b, k, j, i, 2] = batch_target[t, 2]
                y_true[b, k, j, i, 3] = batch_target[t, 3]
                y_true[b, k, j, i, 4] = 1
                y_true[b, k, j, i, c + 5] = 1
                box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
        return y_true, noobj_mask, box_loss_scale

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        bs = len(targets)

        # Grid offsets and per-cell anchor sizes for decoding the raw predictions
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)

        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
        anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)

        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim=-1)
        for b in range(bs):
            # Ignore predictions overlapping any ground truth above the threshold
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
                batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
                batch_target = batch_target[:, :4].type_as(x)
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
                anch_ious_max, _ = torch.max(anch_ious, dim=0)
                anch_ious_max = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes


def weights_init(net, init_type='normal', init_gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)


def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.05, warmup_lr_ratio=0.1, no_aug_iter_ratio=0.05, step_num=10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        # Quadratic warmup, cosine decay, then a constant min-lr tail
        if iters <= warmup_total_iters:
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must be at least 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func


def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
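How the scheduler utilities fit together (my own sketch with illustrative values, not part of the upload): get_lr_scheduler returns a function of the epoch index, and set_optimizer_lr applies it to every parameter group once per epoch.

import torch
from nets.yolo_training import get_lr_scheduler, set_optimizer_lr

model = torch.nn.Linear(10, 10)  # stand-in for the detector
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.937)

# Cosine decay from 1e-2 to 1e-4 over 100 epochs, with a short warmup
lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)

for epoch in range(100):
    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
    # ... one training epoch ...
print(optimizer.param_groups[0]['lr'])  # ends at min_lr = 1e-4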