File size: 3,823 Bytes
11aa70b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE/)
Copyright (c) 2024 D-FINE Authors. All Rights Reserved.
"""

import copy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from .utils import get_activation

from ..core import register
from .hybrid_encoder import ConvNormLayer_fuse
from .hybrid_encoder import RepNCSPELAN4

__all__ = ['LiteEncoder']


# Copied from https://github.com/meituan/YOLOv6/blob/main/yolov6/layers/common.py#L695
class GAP_Fusion(nn.Module):
    """Add global-average-pooled context to a feature map, then project it.

    The input is summed with its own spatial mean (an ``(N, C, 1, 1)`` tensor
    broadcast over the spatial dimensions) and the sum is passed through a
    ``ConvNormLayer_fuse`` built with positional arguments ``1, 1``
    (presumably kernel size 1 and stride 1 -- confirm against
    ``hybrid_encoder.ConvNormLayer_fuse``).

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the projection layer.
        act: Activation spec forwarded unchanged to ``ConvNormLayer_fuse``.
    """

    def __init__(self, in_channels, out_channels, act=None):
        super().__init__()
        # The projection consumes ``x + gap``, which has ``in_channels``
        # channels. The previous code passed ``out_channels`` for both widths
        # and ignored ``in_channels``, which only worked when the two were
        # equal. Assumes the first positional argument of ConvNormLayer_fuse
        # is the input channel count.
        self.cv = ConvNormLayer_fuse(in_channels, out_channels, 1, 1, act=act)

    def forward(self, x):
        """Return ``cv(x + global_avg_pool(x))``."""
        # Output size 1 collapses each channel to its spatial mean.
        pooled = F.adaptive_avg_pool2d(x, 1)
        return self.cv(x + pooled)
        
# Two-scale encoder
@register()
class LiteEncoder(nn.Module):
    """Lightweight encoder that emits two feature maps from the backbone input.

    Pipeline implemented by :meth:`forward`:

    1. Every input level is projected to ``hidden_dim`` channels
       (1x1 conv + batch norm).
    2. The last projected level is downsampled (avg-pool, stride 2) and passed
       through :class:`GAP_Fusion`; the result is appended as an extra,
       smaller level.
    3. Top-down step: level 1 is upsampled to the size of level 0, added to
       it, and refined by ``fpn_block``.
    4. Bottom-up step: the refined map is downsampled, added to level 1, and
       refined by ``pan_block``.

    Args:
        in_channels: Channel count of each input feature map.
        feat_strides: Stride of each input feature map; stored and re-exposed
            as ``out_strides`` without modification.
        hidden_dim: Channel width used throughout the encoder.
        expansion: Multiplier used to derive the ``c4`` width handed to
            ``RepNCSPELAN4``.
        depth_mult: Multiplier for the block count handed to ``RepNCSPELAN4``
            (``round(3 * depth_mult)``).
        act: Activation spec forwarded to ``get_activation``, ``GAP_Fusion``
            and ``RepNCSPELAN4``.
        eval_spatial_size: Stored on the instance; not used in this class.
        csp_type: Forwarded to ``RepNCSPELAN4``.

    NOTE(review): ``out_channels`` has one entry per *input* level, while
    :meth:`forward` always returns exactly two maps. With the default
    single-level input these differ in length -- confirm what consumers of
    ``out_channels`` / ``out_strides`` expect.
    """

    __share__ = ['eval_spatial_size', ]

    # NOTE: the list defaults are kept for interface compatibility; they are
    # only read (never mutated) inside this class.
    def __init__(self,
                 in_channels=[512],
                 feat_strides=[16],
                 hidden_dim=256,
                 expansion=1.0,
                 depth_mult=1.0,
                 act='silu',
                 eval_spatial_size=None,
                 csp_type='csp2',
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.feat_strides = feat_strides
        self.hidden_dim = hidden_dim
        self.eval_spatial_size = eval_spatial_size
        self.out_channels = [hidden_dim for _ in range(len(in_channels))]
        self.out_strides = feat_strides

        # Channel projection: bring every input level to ``hidden_dim`` channels.
        self.input_proj = nn.ModuleList()
        for in_channel in in_channels:
            proj = nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)),
                    ('norm', nn.BatchNorm2d(hidden_dim))
                ]))

            self.input_proj.append(proj)

        # Stride-2 average pooling followed by a 1x1 conv / BN / activation.
        # The template is deep-copied, so both copies begin with identical
        # initial weights; building them independently would change the
        # initialisation and the RNG consumption.
        down_sample = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            get_activation(act)
        )
        self.down_sample1 = copy.deepcopy(down_sample)
        self.down_sample2 = copy.deepcopy(down_sample)

        # Global-context fusion applied to the newly created small-scale level.
        self.bi_fusion = GAP_Fusion(hidden_dim, hidden_dim, act=act)

        # Fuse blocks. ``fpn_block`` and ``pan_block`` are deep copies of one
        # template for the same reason as the down-samplers above.
        c1 = hidden_dim
        c2 = hidden_dim
        c3 = hidden_dim * 2
        c4 = round(expansion * hidden_dim // 2)
        num_blocks = round(3 * depth_mult)
        fuse_block = RepNCSPELAN4(c1=c1, c2=c2, c3=c3, c4=c4, n=num_blocks, act=act, csp_type=csp_type)
        self.fpn_block = copy.deepcopy(fuse_block)
        self.pan_block = copy.deepcopy(fuse_block)

    def forward(self, feats):
        """Encode ``feats`` and return ``[fpn_output, pan_output]``.

        Args:
            feats: Sequence of feature maps, one per entry of ``in_channels``.

        Returns:
            A two-element list: the top-down fused map (same spatial size as
            the first projected level) followed by the bottom-up fused map
            (same spatial size as the second level).
        """
        assert len(feats) == len(self.in_channels)
        proj_feats = [proj(feat) for proj, feat in zip(self.input_proj, feats)]

        # Create the small-scale level from the last projected level and mix
        # in its global-average-pooled context.
        small = self.down_sample1(proj_feats[-1])
        proj_feats.append(self.bi_fusion(small))

        high_res, low_res = proj_feats[0], proj_feats[1]

        # Top-down (FPN) step. Upsample to the exact size of ``high_res``
        # rather than by a fixed factor of 2: the stride-2 avg-pool yields
        # ceil(H / 2), so a fixed 2x upsample overshoots by one pixel when H
        # or W is odd and the addition would fail. For even sizes the result
        # is identical to ``scale_factor=2.``.
        upsampled = F.interpolate(low_res, size=high_res.shape[-2:], mode='nearest')
        fpn_out = self.fpn_block(high_res + upsampled)

        # Bottom-up (PAN) step.
        pan_out = self.pan_block(low_res + self.down_sample2(fpn_out))

        return [fpn_out, pan_out]