File size: 4,081 Bytes
9ec3d0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import math
import torch
import torch.nn as nn

@torch.no_grad()
def sinusoidal_2d_pe(H: int, W: int, D: int, device=None,
                     temperature: float = 10000.0) -> torch.Tensor:
    """Build a fixed 2-D sinusoidal positional encoding for an H x W grid.

    The embedding dimension ``D`` is split into four equal quarters laid out
    as ``[sin(y*omega), cos(y*omega), sin(x*omega), cos(x*omega)]``, where
    ``y``/``x`` are the integer row/column indices and
    ``omega_k = exp(-ln(temperature) * k / (D // 4))`` for
    ``k = 0 .. D//4 - 1``.

    Args:
        H: Number of grid rows.
        W: Number of grid columns.
        D: Embedding dimension; must be divisible by 4.
        device: Device to build the tensor on (CPU when not given).
        temperature: Base of the geometric frequency progression. The default
            of 10000.0 reproduces the previously hard-coded value.

    Returns:
        A float32 tensor of shape ``[1, H*W, D]``, flattened row-major
        (flat index ``i*W + j`` is row ``i``, column ``j``).

    Raises:
        ValueError: If ``D`` is not divisible by 4.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if D % 4 != 0:
        raise ValueError(f"D must be divisible by 4; got D={D}")
    device = device or torch.device('cpu')

    y = torch.arange(H, device=device, dtype=torch.float32)
    x = torch.arange(W, device=device, dtype=torch.float32)
    yy, xx = torch.meshgrid(y, x, indexing='ij')  # [H,W]

    d = D // 4
    k = torch.arange(d, device=device, dtype=torch.float32)
    omega = torch.exp(-math.log(temperature) * k / d)  # [d]

    # Broadcast: [H,W,1] * [d] -> [H,W,d]
    y_ang = yy[..., None] * omega
    x_ang = xx[..., None] * omega

    pe = torch.cat([torch.sin(y_ang), torch.cos(y_ang),
                    torch.sin(x_ang), torch.cos(x_ang)], dim=-1)  # [H,W,D]
    # torch.cat returns a fresh contiguous tensor, so a plain view is enough.
    return pe.view(1, H * W, D)

class EncoderAttnBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is applied residually as ``x = x + f(LayerNorm(x))``.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``
            (a requirement of ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        dropout: Dropout probability, passed to ``nn.MultiheadAttention`` and
            used inside and after the MLP.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
    """

    def __init__(self,
                 dim,
                 num_heads: int = 8,
                 dropout: float = 0.2,
                 mlp_ratio: float = 4.0):

        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        """Apply the block to batch-first tokens ``x`` of shape ``[B, S, dim]``.

        Args:
            x: Input tokens, ``[B, S, dim]``.
            key_padding_mask: Optional ``[B, S]`` mask forwarded unchanged to
                ``nn.MultiheadAttention`` (for a boolean mask, ``True`` marks
                keys to ignore, per the PyTorch docs).
            attn_mask: Optional attention mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.norm1(x)
        # Invoke the module (not ``.forward``) so registered hooks still run.
        attn_out, _ = self.attn(h, h, h,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask,
                                need_weights=False)
        x = x + attn_out                 # [B,S,D]
        x = x + self.mlp(self.norm2(x))
        return x


class ViTEncoder(torch.nn.Module):
    """Project a CNN feature map to tokens and refine them with self-attention.

    Pipeline: 1x1 conv ``C -> dim`` -> flatten to ``[B, H*W, dim]`` ->
    LayerNorm -> add a fixed 2-D sinusoidal positional encoding ->
    ``num_blocks`` x ``EncoderAttnBlock`` -> final LayerNorm.

    Args:
        dim: Token embedding size. Must be divisible by 4 (positional
            encoding) and by ``num_heads`` (multi-head attention).
        in_shape: Reference input shape ``[B, C, H, W]``; ``C`` sizes the
            projection and ``H``/``W`` are stored as ``self.H``/``self.W``.
        num_blocks: Number of attention blocks.
        num_heads: Attention heads per block.
        dropout: Dropout probability passed to each block.
        device: Device the parameters are moved to.
    """

    def __init__(self,
                 dim: int,
                 in_shape: list[int],
                 num_blocks=2,
                 num_heads=8,
                 dropout=0.1,
                 device=torch.device('cpu')):

        super().__init__()
        # Reference grid size from ``in_shape``; kept for callers that read it.
        # forward() uses the actual grid of each input instead.
        self.H = in_shape[2]
        self.W = in_shape[3]
        self.blocks = torch.nn.ModuleList([
            EncoderAttnBlock(dim, num_heads=num_heads, mlp_ratio=4.0, dropout=dropout)
            for _ in range(num_blocks)
        ]).to(device=device)
        self.norm = torch.nn.LayerNorm(dim).to(device=device)
        self.proj = nn.Conv2d(in_shape[1], dim, 1).to(device=device)
        self.ln = nn.LayerNorm(dim).to(device=device)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Encode ``feats`` of shape ``[B, C, H, W]`` into ``[B, H*W, dim]`` tokens."""
        # Call submodules (not ``.forward``) so registered hooks still run.
        feats = self.proj(feats)                    # [B, dim, H, W]
        # Build the PE for the grid actually received, not the reference shape,
        # so any spatial size gets one correct encoding per token.
        H, W = feats.shape[-2:]
        vis_tokens = self.ln(feats.flatten(2).transpose(1, 2))  # [B, S, dim], S = H*W

        pe = sinusoidal_2d_pe(H=H,
                              W=W,
                              D=vis_tokens.shape[-1],
                              device=vis_tokens.device)
        # The PE is float32; match the token dtype so half/bfloat16 inputs are
        # not silently promoted before the attention blocks.
        x = vis_tokens + pe.to(dtype=vis_tokens.dtype)

        # Attention Blocks
        for block in self.blocks:
            x = block(x)

        return self.norm(x)                         # [B, S, dim]

# ------ OLD VERSION ------
class CNNEncoder(nn.Module):
    """Legacy encoder: 1x1 conv projection followed by LayerNorm (no attention).

    Args:
        dim: Output token size.
        in_shape: Reference input shape ``[B, C, H, W]``; ``C`` sizes the
            projection and ``H * W`` is stored as ``self.visual_patch``.
        device: Device the parameters are moved to.
    """

    def __init__(self,
                 dim: int,
                 in_shape: list[int],
                 device=torch.device('cpu')
                 ):

        super().__init__()

        self.conv = nn.Conv2d(in_shape[1], dim, 1).to(device=device)
        self.ln = nn.LayerNorm(dim).to(device=device)
        # Tokens per image for the reference shape (H * W).
        self.visual_patch = in_shape[-1] * in_shape[-2]

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Map ``[B, C, H, W]`` features to ``[B, H*W, dim]`` normalized tokens."""
        # Call submodules (not ``.forward``) so registered hooks still run.
        feats = self.conv(feats)            # [B, dim, H, W]
        feats = feats.flatten(2)            # [B, dim, H*W]
        feats = feats.transpose(1, 2)       # [B, H*W, dim]
        return self.ln(feats)               # [B, H*W, dim]