Spaces:
Build error
Build error
| # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license | |
| from __future__ import absolute_import, division | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| __all__ = ["HACNN"] | |
class ConvBlock(nn.Module):
    """Convolution followed by batch normalization and ReLU.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=k,
            stride=s,
            padding=p,
        )
        self.bn = nn.BatchNorm2d(num_features=out_c)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return F.relu(normalised)
class InceptionA(nn.Module):
    """Inception block with four parallel stride-1 streams.

    Every stream emits ``out_channels // 4`` channels; the four results are
    concatenated along the channel dimension in stream order.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): channel budget that is split across the streams.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        width = out_channels // 4

        def reduce_then_conv3():
            # 1x1 channel reduction followed by a padded 3x3 convolution
            return nn.Sequential(
                ConvBlock(in_channels, width, 1),
                ConvBlock(width, width, 3, p=1),
            )

        self.stream1 = reduce_then_conv3()
        self.stream2 = reduce_then_conv3()
        self.stream3 = reduce_then_conv3()
        # average pooling followed by a 1x1 convolution
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, width, 1),
        )

    def forward(self, x):
        """Run all streams on ``x`` and concatenate their outputs on dim 1."""
        branches = (self.stream1, self.stream2, self.stream3, self.stream4)
        return torch.cat([branch(x) for branch in branches], dim=1)
class InceptionB(nn.Module):
    """Inception block with three parallel streams that all use stride 2.

    With ``mid = out_channels // 4``, the two convolutional streams emit
    ``mid`` channels each and the pooling stream emits ``2 * mid``; the three
    results are concatenated along the channel dimension in stream order.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): channel budget that is split across the streams.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        width = out_channels // 4

        # 1x1 reduction, then a strided 3x3 convolution
        short_path = [
            ConvBlock(in_channels, width, 1),
            ConvBlock(width, width, 3, s=2, p=1),
        ]
        # 1x1 reduction, then two 3x3 convolutions (only the last is strided)
        long_path = [
            ConvBlock(in_channels, width, 1),
            ConvBlock(width, width, 3, p=1),
            ConvBlock(width, width, 3, s=2, p=1),
        ]
        # strided max pooling, then a 1x1 convolution to twice the width
        pool_path = [
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, width * 2, 1),
        ]

        self.stream1 = nn.Sequential(*short_path)
        self.stream2 = nn.Sequential(*long_path)
        self.stream3 = nn.Sequential(*pool_path)

    def forward(self, x):
        """Run all streams on ``x`` and concatenate their outputs on dim 1."""
        branches = (self.stream1, self.stream2, self.stream3)
        return torch.cat([branch(x) for branch in branches], dim=1)
class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1).

    Averages the input over its channels, applies a stride-2 3x3 conv block,
    bilinearly resizes the result to twice the conv output's height and
    width, and finishes with a 1x1 conv block. The output has one channel.
    """

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        """Return the single-channel spatial attention map for ``x``."""
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv (stride 2)
        x = self.conv1(x)
        # bilinear resizing; F.interpolate replaces the deprecated F.upsample
        # and is called with the same arguments, so the result is unchanged
        x = F.interpolate(
            x, (x.size(2) * 2, x.size(3) * 2), mode="bilinear", align_corners=True
        )
        # scaling conv
        x = self.conv2(x)
        return x
class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2).

    Global average pooling (squeeze) followed by two 1x1 conv blocks
    (excitation) that shrink the channels by ``reduction_rate`` and then
    restore them.

    Args:
        in_channels (int): number of input channels; must be divisible by
            ``reduction_rate``.
        reduction_rate (int): channel reduction factor of the bottleneck.
    """

    def __init__(self, in_channels, reduction_rate=16):
        super().__init__()
        assert in_channels % reduction_rate == 0
        bottleneck = in_channels // reduction_rate
        self.conv1 = ConvBlock(in_channels, bottleneck, 1)
        self.conv2 = ConvBlock(bottleneck, in_channels, 1)

    def forward(self, x):
        """Return per-channel attention with 1x1 spatial extent."""
        # squeeze: pool over the full spatial extent
        spatial_size = x.size()[2:]
        pooled = F.avg_pool2d(x, spatial_size)
        # excitation: bottleneck then expansion
        return self.conv2(self.conv1(pooled))
class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I).

    Aim: Spatial Attention + Channel Attention.

    The spatial and channel attention maps are multiplied (broadcasting
    against each other), passed through a 1x1 conv block and squashed with a
    sigmoid.

    Output: attention maps with shape identical to input.

    Args:
        in_channels (int): number of input channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        """Return the sigmoid-activated joint attention for ``x``."""
        joint = self.spatial_attn(x) * self.channel_attn(x)
        return torch.sigmoid(self.conv(joint))
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II).

    Maps globally pooled features through one linear layer and ``tanh`` to
    four 2-element parameter vectors, returned with shape ``(batch, 4, 2)``.

    Args:
        in_channels (int): number of input channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)
        self.init_params()

    def init_params(self):
        """Zero the linear weights and load the fixed initial bias."""
        initial_bias = torch.tensor(
            [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
        )
        with torch.no_grad():
            self.fc.weight.zero_()
            self.fc.bias.copy_(initial_bias)

    def forward(self, x):
        """Return the ``(batch, 4, 2)`` transformation parameters for ``x``."""
        batch, channels = x.size(0), x.size(1)
        # squeeze operation (global average pooling)
        pooled = F.avg_pool2d(x, x.size()[2:]).view(batch, channels)
        # predict transformation parameters, bounded to (-1, 1) by tanh
        params = torch.tanh(self.fc(pooled))
        return params.view(-1, 4, 2)
class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1).

    Bundles a soft attention module and a hard attention module that are
    both applied to the same input.

    Args:
        in_channels (int): number of input channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        """Return ``(soft attention map, hard attention parameters)``."""
        soft_map = self.soft_attn(x)
        region_params = self.hard_attn(x)
        return soft_map, region_params
class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.

    Args:
        num_classes (int): number of classes to predict.
        loss (str): ``"softmax"`` or ``"triplet"``; selects what ``forward``
            returns in training mode.
        nchannels (sequence of int): number of channels AFTER concatenation
            for each of the three blocks.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (i.e. local branch).
        use_gpu (bool): kept for backward compatibility and stored on the
            instance; the region transforms follow the device of the tensors
            they are computed from, so this flag no longer forces CUDA.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        num_classes,
        loss="softmax",
        nchannels=(128, 256, 384),
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            # the four regional features are concatenated before this layer
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            # global and local features are concatenated at inference time
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        """Create the fixed 2x2 scale matrices (s_w, s_h) for the four regions."""
        # Stored as a plain list of tensors (not buffers) so the module's
        # state_dict keys stay exactly as they were.
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) for _ in range(4)
        ]

    def stn(self, x, theta):
        """Performs spatial transform

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        # align_corners=False is what PyTorch already applies by default;
        # passing it explicitly keeps the result and avoids the per-call
        # warning about the implicit default.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)

        The returned tensor lives on the same device and has the same dtype
        as ``theta_i``, so it matches the feature map it is applied to
        regardless of where the model has been moved.
        """
        scale_factors = self.scale_factors[region_idx].to(
            device=theta_i.device, dtype=theta_i.dtype
        )
        theta = theta_i.new_zeros(theta_i.size(0), 2, 3)
        # left 2x2 block: fixed scaling; last column: the predicted values
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        return theta

    def _local_branch(self, feat, theta, out_size, local_conv, prev_locals=None):
        """Sample the four regions from ``feat`` and run them through ``local_conv``.

        Args:
            feat: feature map the regions are sampled from.
            theta: (batch, 4, 2) hard-attention output for this block.
            out_size: (height, width) every sampled region is resized to.
            local_conv: module applied to each resized region.
            prev_locals: optional list with the previous block's per-region
                outputs; each is added to the matching region before
                ``local_conv``.

        Returns:
            list with one output tensor per region, in region order.
        """
        local_feats = []
        for region_idx in range(4):
            theta_i = self.transform_theta(theta[:, region_idx, :], region_idx)
            region = self.stn(feat, theta_i)
            # F.interpolate replaces the deprecated F.upsample (same arguments)
            region = F.interpolate(
                region, out_size, mode="bilinear", align_corners=True
            )
            if prev_locals is not None:
                region = region + prev_locals[region_idx]
            local_feats.append(local_conv(region))
        return local_feats

    def forward(self, x):
        """Run the network on a batch of (160, 64) images.

        Returns:
            Evaluation mode: the L2-normalised global and local features
            concatenated along dim 1 when ``learn_region`` is set, otherwise
            the (un-normalised) global feature.

            Training mode with ``loss == "softmax"``: the classifier outputs,
            as ``(global, local)`` when ``learn_region`` is set, else the
            global output alone.

            Training mode with ``loss == "triplet"``: classifier outputs and
            features, as ``((global, local), (x_global, x_local))`` when
            ``learn_region`` is set, else ``(global, x_global)``.

        Raises:
            AssertionError: if the input is not 160 high and 64 wide.
            KeyError: in training mode, if ``loss`` is not supported.
        """
        assert (
            x.size(2) == 160 and x.size(3) == 64
        ), "Input size does not match, expected (160, 64) but got ({}, {})".format(
            x.size(2), x.size(3)
        )
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = self._local_branch(
                x, x1_theta, (24, 28), self.local_conv1
            )

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = self._local_branch(
                x1_out, x2_theta, (12, 14), self.local_conv2, x1_local_list
            )

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = self._local_branch(
                x2_out, x3_theta, (6, 7), self.local_conv3, x2_local_list
            )

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(
            x3_out.size(0), x3_out.size(1)
        )
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local = torch.cat(
                [
                    F.avg_pool2d(feat, feat.size()[2:]).view(feat.size(0), -1)
                    for feat in x3_local_list
                ],
                1,
            )
            x_local = self.fc_local(x_local)

        if not self.training:
            if self.learn_region:
                # l2 normalization before concatenation; F.normalize clamps
                # the norm so an all-zero feature yields zeros, not NaN
                x_global = F.normalize(x_global, p=2, dim=1)
                x_local = F.normalize(x_local, p=2, dim=1)
                return torch.cat([x_global, x_local], 1)
            return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == "softmax":
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            return prelogits_global

        if self.loss == "triplet":
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            return prelogits_global, x_global

        raise KeyError("Unsupported loss: {}".format(self.loss))