File size: 2,705 Bytes
322161a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from functools import reduce
from operator import add
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet

class Backbone(nn.Module):
    r""" Pretrained ResNet wrapper that exposes per-bottleneck feature maps.

    Only ``'resnet50'`` is supported. The wrapped torchvision network is stored
    in ``feature_extractor`` and switched to eval mode at construction time.

    Attributes set by ``__init__``:
        backbone: the ``typestr`` the instance was built with.
        feature_extractor: the torchvision ResNet-50 with default weights.
        feat_channels: ``[256, 512, 1024, 2048]`` (presumably the output
            channel count of each residual stage -- not used in this class).
        nlayers: number of bottleneck blocks in ``layer1`` .. ``layer4``.
        feat_ids: 1-based bottleneck indices whose outputs are collected.
        lids: for every bottleneck, in order, the 1-based stage it lives in.
        stack_ids: cumulative bottleneck count per stage (``[3, 7, 13, 16]``
            for resnet50).
        cross_entropy_loss: an ``nn.CrossEntropyLoss`` instance (not used by
            the methods visible in this class).
    """

    def __init__(self, typestr):
        r""" Build the feature extractor.

        Args:
            typestr: backbone name; must be ``'resnet50'``.

        Raises:
            ValueError: if ``typestr`` names an unsupported backbone.
        """
        super().__init__()

        self.backbone = typestr

        # feature extractor initialization
        if typestr == 'resnet50':
            self.feature_extractor = resnet.resnet50(weights=resnet.ResNet50_Weights.DEFAULT)
            self.feat_channels = [256, 512, 1024, 2048]
            self.nlayers = [3, 4, 6, 3]
            self.feat_ids = list(range(0, 17))
        else:
            # ValueError is a subclass of Exception, so callers catching the
            # previously raised bare Exception keep working.
            raise ValueError('Unavailable backbone: %s' % typestr)
        # NOTE(review): a later ``.train()`` on this module propagates to
        # children and will put the extractor back in training mode -- confirm
        # callers re-apply ``eval()`` if frozen batch-norm statistics matter.
        self.feature_extractor.eval()

        # define model
        # Stage id (1..4) repeated once per bottleneck of that stage.
        self.lids = [stage + 1 for stage, count in enumerate(self.nlayers) for _ in range(count)]
        # bincount()[0] counts the (absent) id 0; keep only the four stage counts.
        self.stack_ids = torch.tensor(self.lids).bincount()[-4:].cumsum(dim=0)

        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def extract_feats(self, img):
        r""" Extract input image features.

        Runs the ResNet stem, then every bottleneck block in order, collecting
        the block output taken after the residual addition and before the
        block's final ReLU.

        Args:
            img: input tensor accepted by the extractor's first convolution.

        Returns:
            A list of tensors, one for every bottleneck whose 1-based index is
            contained in ``self.feat_ids``, in network order.
        """
        extractor = self.feature_extractor
        # Position of every bottleneck inside its own stage: 0..n-1 per stage.
        bottleneck_ids = [bid for count in self.nlayers for bid in range(count)]

        # Layer 0 (stem). Modules are called directly rather than through
        # ``.forward`` so that registered forward hooks are honoured.
        feat = extractor.conv1(img)
        feat = extractor.bn1(feat)
        feat = extractor.relu(feat)
        feat = extractor.maxpool(feat)

        # Layer 1-4
        feats = []
        for hid, (bid, lid) in enumerate(zip(bottleneck_ids, self.lids)):
            block = getattr(extractor, 'layer%d' % lid)[bid]

            res = feat
            feat = block.conv1(feat)
            feat = block.bn1(feat)
            feat = block.relu(feat)
            feat = block.conv2(feat)
            feat = block.bn2(feat)
            feat = block.relu(feat)
            feat = block.conv3(feat)
            feat = block.bn3(feat)

            # Project the shortcut whenever the block defines a projection
            # (for resnet50 that is exactly the first block of each stage).
            if block.downsample is not None:
                res = block.downsample(res)

            feat += res

            if hid + 1 in self.feat_ids:
                # Clone: torchvision's bottleneck ReLU is in-place and would
                # otherwise overwrite the stored pre-activation feature.
                feats.append(feat.clone())

            feat = block.relu(feat)

        return feats