""" Implementation of YOLOv3 architecture """ import torch import torch.nn as nn class CNNBlock(nn.Module): """Convolutional block with BatchNorm and LeakyReLU""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bn_act=True): super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not bn_act) self.bn = nn.BatchNorm2d(out_channels) self.leaky = nn.LeakyReLU(0.1) self.use_bn_act = bn_act def forward(self, x): if self.use_bn_act: return self.leaky(self.bn(self.conv(x))) else: return self.conv(x) class ResidualBlock(nn.Module): """Residual block with skip connection""" def __init__(self, channels, num_repeats=1): super().__init__() self.layers = nn.ModuleList() for _ in range(num_repeats): self.layers.append( nn.Sequential( CNNBlock(channels, channels // 2, kernel_size=1), CNNBlock(channels // 2, channels, kernel_size=3, padding=1), ) ) def forward(self, x): for layer in self.layers: x = x + layer(x) return x class ScalePrediction(nn.Module): """Scale prediction block for YOLO output""" def __init__(self, in_channels, num_classes): super().__init__() self.pred = nn.Sequential( CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1), CNNBlock(2 * in_channels, (num_classes + 5) * 3, kernel_size=1, bn_act=False), ) self.num_classes = num_classes def forward(self, x): return ( self.pred(x) .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3]) .permute(0, 1, 3, 4, 2) ) class YOLOv3(nn.Module): """YOLOv3 architecture with Darknet-53 backbone""" def __init__(self, in_channels=3, num_classes=20): super().__init__() self.num_classes = num_classes # Darknet-53 Backbone self.conv1 = CNNBlock(in_channels, 32, kernel_size=3, stride=1, padding=1) self.conv2 = CNNBlock(32, 64, kernel_size=3, stride=2, padding=1) self.residual1 = ResidualBlock(64, num_repeats=1) self.conv3 = CNNBlock(64, 128, kernel_size=3, stride=2, padding=1) self.residual2 = ResidualBlock(128, num_repeats=2) self.conv4 = CNNBlock(128, 256, kernel_size=3, stride=2, padding=1) self.residual3 = ResidualBlock(256, num_repeats=8) self.conv5 = CNNBlock(256, 512, kernel_size=3, stride=2, padding=1) self.residual4 = ResidualBlock(512, num_repeats=8) self.conv6 = CNNBlock(512, 1024, kernel_size=3, stride=2, padding=1) self.residual5 = ResidualBlock(1024, num_repeats=4) # First scale prediction (13x13 - large objects) self.conv7 = CNNBlock(1024, 512, kernel_size=1, stride=1, padding=0) self.conv8 = CNNBlock(512, 1024, kernel_size=3, stride=1, padding=1) self.residual6 = ResidualBlock(1024, num_repeats=1) self.conv9 = CNNBlock(1024, 512, kernel_size=1, stride=1, padding=0) self.scale_pred1 = ScalePrediction(512, num_classes=num_classes) # Second scale (26x26 - medium objects) self.conv10 = CNNBlock(512, 256, kernel_size=1, stride=1, padding=0) self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest') self.conv11 = CNNBlock(768, 256, kernel_size=1, stride=1, padding=0) self.conv12 = CNNBlock(256, 512, kernel_size=3, stride=1, padding=1) self.residual7 = ResidualBlock(512, num_repeats=1) self.conv13 = CNNBlock(512, 256, kernel_size=1, stride=1, padding=0) self.scale_pred2 = ScalePrediction(256, num_classes=num_classes) # Third scale (52x52 - small objects) self.conv14 = CNNBlock(256, 128, kernel_size=1, stride=1, padding=0) self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest') self.conv15 = CNNBlock(384, 128, kernel_size=1, stride=1, padding=0) self.conv16 = CNNBlock(128, 256, kernel_size=3, stride=1, padding=1) self.residual8 = ResidualBlock(256, num_repeats=1) self.conv17 = CNNBlock(256, 128, kernel_size=1, stride=1, padding=0) self.scale_pred3 = ScalePrediction(128, num_classes=num_classes) def forward(self, x): # Darknet-53 feature extraction x = self.conv1(x) x = self.conv2(x) x = self.residual1(x) x = self.conv3(x) x = self.residual2(x) x = self.conv4(x) route1 = self.residual3(x) x = self.conv5(route1) route2 = self.residual4(x) x = self.conv6(route2) x = self.residual5(x) # First scale (13x13) x = self.conv7(x) x = self.conv8(x) x = self.residual6(x) x = self.conv9(x) out1 = self.scale_pred1(x) # Second scale (26x26) x = self.conv10(x) x = self.upsample1(x) x = torch.cat([x, route2], dim=1) x = self.conv11(x) x = self.conv12(x) x = self.residual7(x) x = self.conv13(x) out2 = self.scale_pred2(x) # Third scale (52x52) x = self.conv14(x) x = self.upsample2(x) x = torch.cat([x, route1], dim=1) x = self.conv15(x) x = self.conv16(x) x = self.residual8(x) x = self.conv17(x) out3 = self.scale_pred3(x) return [out1, out2, out3]