File size: 1,617 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
from torchvision import models


class EncoderViTB16(nn.Module):
    """Frozen torchvision ViT-B/16 backbone with two trainable heads.

    All backbone parameters are frozen; only ``classifier`` and
    ``projector`` receive gradients. ``forward`` works in one of two modes:

    * classification: the backbone's standard forward pass (its own head
      replaced by ``nn.Identity``) feeds ``classifier``;
    * feature extraction: the patch tokens are run through the encoder
      manually (without the class token) and every token is mapped to
      ``embed_size`` by ``projector``.

    Args:
        num_classes: Output size of the classification head.
        embed_size: Output size of the per-patch projection head.
        pretrained: If True (default), load ``ViT_B_16_Weights.DEFAULT``.
            If False, build the backbone with random weights (no download),
            e.g. for offline experiments or tests.
    """

    def __init__(self, num_classes=50, embed_size=512, pretrained=True):
        super().__init__()

        weights = models.ViT_B_16_Weights.DEFAULT if pretrained else None
        self.backbone = models.vit_b_16(weights=weights)

        # Freeze the backbone; only the heads defined below are trained.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Hidden size of the backbone; must be read before the head is replaced.
        in_features = self.backbone.heads.head.in_features

        # Drop torchvision's own classification head so that
        # backbone(images) yields the hidden representation instead of logits.
        self.backbone.heads = nn.Identity()

        self.classifier = nn.Linear(in_features, num_classes)
        self.projector = nn.Linear(in_features, embed_size)

    def forward(self, images, return_features=False):
        """Run the encoder in classification or feature-extraction mode.

        Only the requested branch is computed. Each branch needs its own
        backbone pass, so computing both and discarding one would double
        the cost of every call.

        Args:
            images: Image batch accepted by the torchvision ViT-B/16 backbone
                (presumably ``(B, 3, 224, 224)``; confirm with callers).
            return_features: If False (default), return classification
                logits of shape ``(B, num_classes)``. If True, return per-patch
                features of shape ``(B, num_patches, embed_size)``, e.g.
                ``(B, 196, embed_size)`` for 224x224 inputs.
        """
        if return_features:
            # Captioning path.
            return self._patch_features(images)
        # Classification path.
        return self._classify(images)

    def _classify(self, images):
        """Return class logits computed from the backbone's standard forward."""
        features = self.backbone(images)
        if isinstance(features, tuple):
            features = features[0]
        # Same result as view(B, -1), but also works on non-contiguous tensors.
        features = torch.flatten(features, 1)
        return self.classifier(features)

    def _patch_features(self, images):
        """Return per-patch encoder outputs projected to ``embed_size``."""
        encoder = self.backbone.encoder
        # NOTE(review): _process_input is a private torchvision API and may
        # change between releases. It yields (B, num_patches, hidden), e.g.
        # (B, 196, 768) for 224x224 inputs.
        tokens = self.backbone._process_input(images)
        # Positional embeddings for the patch positions only: index 0 belongs
        # to the class token, which is deliberately not prepended here.
        # NOTE(review): the pretrained layers were trained with a class token
        # present, and this path also skips encoder.dropout. Confirm intended.
        tokens = tokens + encoder.pos_embedding[:, 1:, :]
        for layer in encoder.layers:
            tokens = layer(tokens)
        tokens = encoder.ln(tokens)  # final LayerNorm of the encoder
        return self.projector(tokens)  # (B, num_patches, embed_size)