File size: 1,464 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch.nn as nn
from torchvision import models


class EncoderResnet18(nn.Module):
    """Frozen pretrained ResNet-18 with a classification head and a captioning projector.

    Two views of the same pretrained network are kept:

    * ``backbone`` -- every ResNet-18 child except ``fc`` (so it ends with the
      pooling layer); its flattened output feeds ``classifier``.
    * ``cap_backbone`` -- every child except the last two (pooling and ``fc``),
      so it keeps the spatial feature map; that map is flattened into a
      sequence of per-location vectors and fed to ``projector``.

    Both ``nn.Sequential`` wrappers hold the *same* child module instances, so
    the backbone weights are shared (they appear under both prefixes in
    ``state_dict``). Only ``classifier`` and ``projector`` have trainable
    parameters.

    Args:
        num_classes: Output size of the classification head.
        embed_size: Output size of the per-location captioning projection.
    """

    def __init__(self, num_classes=50, embed_size=512):
        super().__init__()
        model = models.resnet18(
            weights=models.ResNet18_Weights.DEFAULT
        )

        # Everything except the final ``fc`` layer (pooled features).
        modules = list(model.children())[:-1]
        self.backbone = nn.Sequential(*modules)
        # Freeze the pretrained weights. ``cap_backbone`` below reuses these
        # very module instances, so this freezes its parameters as well.
        # NOTE(review): requires_grad=False does not stop BatchNorm layers
        # from updating running statistics while the module is in train()
        # mode; call .eval() on the backbone if fully frozen stats are wanted.
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.classifier = nn.Linear(
            model.fc.in_features,
            num_classes
        )

        # Everything except pooling and ``fc``: keeps the spatial feature map.
        cap_modules = list(model.children())[:-2]
        self.cap_backbone = nn.Sequential(*cap_modules)

        self.projector = nn.Linear(
            model.fc.in_features,
            embed_size
        )

    def forward(
        self,
        images,
        return_features=False
    ):
        """Run either the classification branch or the captioning branch.

        Only the requested branch is evaluated, so the shared backbone runs
        once per call.

        Args:
            images: Image batch fed to the ResNet-18 stem; presumably
                ``(batch, 3, H, W)`` -- TODO confirm against the data pipeline.
            return_features: If False (default), return classification logits
                of shape ``(batch, num_classes)``. If True, return captioning
                features of shape ``(batch, h * w, embed_size)``: one projected
                vector per spatial location of the ``cap_backbone`` output.

        Returns:
            The logits or the projected feature sequence, as selected above.
        """
        if not return_features:
            # Classification: flatten the pooled output to (batch, features).
            features = self.backbone(images)
            features = features.view(features.size(0), -1)
            return self.classifier(features)

        # Captioning: (batch, C, h, w) -> (batch, C, h*w) -> (batch, h*w, C),
        # then project each location's C-dim vector to embed_size.
        cap_features = self.cap_backbone(images)
        cap_features = cap_features.flatten(2).permute(0, 2, 1)
        return self.projector(cap_features)