File size: 2,022 Bytes
601cad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights


class ViTEncoder(nn.Module):
    """Encode images into a sequence of projected ViT-B/16 patch tokens.

    Wraps torchvision's ``vit_b_16`` and runs its patch embedding and
    transformer encoder WITHOUT the CLS token. The output is one feature
    vector per 16x16 patch, linearly projected to ``d_model`` channels.

    Args:
        d_model: Output feature size of the final linear projection.
        freeze_backbone: If True, every ViT parameter gets
            ``requires_grad=False`` so only ``proj`` is trainable.
        pretrained: If True, load the ``IMAGENET1K_V1`` weights; otherwise
            the backbone is randomly initialised.
    """

    def __init__(
        self,
        d_model=512,
        freeze_backbone=True,
        pretrained=True,
    ):
        super().__init__()

        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = vit_b_16(weights=weights)

        # forward() never calls the classification head; replace it so it
        # holds no parameters.
        self.vit.heads = nn.Identity()

        self.hidden_dim = self.vit.hidden_dim

        # Maps backbone features (hidden_dim) to the caller's model width.
        self.proj = nn.Linear(self.hidden_dim, d_model)

        if freeze_backbone:
            for p in self.vit.parameters():
                p.requires_grad = False

        # forward() drops the CLS token, so class_token can never receive a
        # gradient. Keeping it trainable would make DistributedDataParallel
        # error out unless find_unused_parameters=True.
        self.vit.class_token.requires_grad = False

    def unfreeze_backbone(self, unfreeze=True):
        """Set ``requires_grad`` on the backbone parameters.

        Args:
            unfreeze: True makes the backbone trainable; False freezes it.

        The CLS token stays frozen in both cases because forward() never
        uses it (see ``__init__``).
        """
        for p in self.vit.parameters():
            p.requires_grad = unfreeze
        self.vit.class_token.requires_grad = False

        print(f"ViT Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """Return projected patch tokens for a batch of images.

        Args:
            images: Tensor of shape (B, 3, image_size, image_size), where
                ``image_size`` is the backbone's native resolution (224 for
                vit_b_16). NOTE(review): with pretrained weights the inputs
                presumably need the ``IMAGENET1K_V1`` preprocessing
                (resize / crop / normalise) — confirm with the data pipeline.

        Returns:
            Tensor of shape (B, num_patches, d_model); num_patches is 196
            (a 14x14 grid) for 224x224 inputs.

        Raises:
            ValueError: If ``images`` is not 4-D or its spatial size differs
                from the backbone's ``image_size``.
        """
        if images.dim() != 4:
            raise ValueError(
                f"Expected images of shape (B, C, H, W), got {tuple(images.shape)}"
            )
        size = self.vit.image_size
        h, w = images.shape[-2:]
        # The positional embeddings are learned for one fixed square grid. Any
        # other H/W either fails to broadcast or, when the patch count happens
        # to match (e.g. 112x448), silently pairs patches with wrong positions.
        if h != size or w != size:
            raise ValueError(f"Expected {size}x{size} images, got {h}x{w}")

        # 1. Patch embedding: (B, 3, H, W) -> (B, hidden, H/16, W/16)
        x = self.vit.conv_proj(images)
        # -> (B, num_patches, hidden)
        x = x.flatten(2).transpose(1, 2)

        # 2. Positional embedding. pos_embedding is (1, 1 + num_patches, hidden)
        # with the CLS position at index 0; no CLS token is used here, so that
        # slot is skipped. Indexing the module's own parameter keeps it on the
        # right device and lets gradients reach it when the backbone is trainable.
        x = x + self.vit.encoder.pos_embedding[:, 1:]

        # 3. Transformer encoder: the same dropout -> layers -> ln sequence as
        # torchvision's Encoder.forward, minus its built-in pos-embedding add
        # (which expects the extra CLS position).
        x = self.vit.encoder.dropout(x)
        x = self.vit.encoder.layers(x)
        x = self.vit.encoder.ln(x)

        # 4. Project to d_model: (B, num_patches, d_model)
        return self.proj(x)