import math import torch import torch.nn as nn @torch.no_grad() def sinusoidal_2d_pe(H: int, W: int, D: int, device=None) -> torch.Tensor: assert D % 4 == 0, f"D % 4 == 0 must; D={D}" device = device or torch.device('cpu') y = torch.arange(H, device=device, dtype=torch.float32) x = torch.arange(W, device=device, dtype=torch.float32) yy, xx = torch.meshgrid(y, x, indexing='ij') # [H,W] d = D // 4 k = torch.arange(d, device=device, dtype=torch.float32) omega = torch.exp(-math.log(10000.0) * k / d) # [d] # Broadcast: [H,W,1]*[d] -> [H,W,d] y_sin = torch.sin(yy[..., None] * omega) # [H,W,d] y_cos = torch.cos(yy[..., None] * omega) # [H,W,d] x_sin = torch.sin(xx[..., None] * omega) # [H,W,d] x_cos = torch.cos(xx[..., None] * omega) # [H,W,d] pe = torch.cat([y_sin, y_cos, x_sin, x_cos], dim=-1) # [H,W,D] pe = pe.view(1, H*W, D).contiguous() return pe class EncoderAttnBlock(nn.Module): def __init__(self, dim, num_heads:int=8, dropout:float=0.2, mlp_ratio:float=4.0): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True, dropout=dropout) self.norm2 = nn.LayerNorm(dim) self.mlp = nn.Sequential( nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(), nn.Dropout(dropout), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(dropout) ) def forward(self, x): # x: [B,S,D] h = self.norm1(x) x = x + self.attn.forward(h,h,h, need_weights=False)[0] # [B,S,D] h = self.norm2(x) x = x + self.mlp(h) return x class ViTEncoder(torch.nn.Module): def __init__(self, dim:int, in_shape:list[int], num_blocks=2, num_heads=8, dropout=0.1, device=torch.device('cpu')): super().__init__() self.H = in_shape[2] self.W = in_shape[3] self.blocks = torch.nn.ModuleList([ EncoderAttnBlock(dim, num_heads=num_heads, mlp_ratio=4.0, dropout=dropout) for _ in range(num_blocks) ]).to(device=device) self.norm = torch.nn.LayerNorm(dim).to(device=device) self.proj = nn.Conv2d(in_shape[1], dim, 1).to(device=device) self.ln = nn.LayerNorm(dim).to(device=device) def forward(self, feats:torch.Tensor): #feats: [B,C,H,W] feats = self.proj.forward(feats) # [B, dim, H, W] feats = feats.flatten(2) # [B, dim, S] feats = feats.transpose(1, 2) # [B, S, dim] vis_tokens = self.ln.forward(feats) # [B, S, dim] B,S,D = vis_tokens.shape pe = sinusoidal_2d_pe(H=self.H, W=self.W, D=D, device=vis_tokens.device) x = vis_tokens + pe # Positional Encoding # Attention Blocks for block in self.blocks: x = block(x) return self.norm(x) # [B,S,D] # ------ OLD VERSION ------ class CNNEncoder(nn.Module): def __init__(self, dim:int, in_shape:list[int], device=torch.device('cpu') ): super().__init__() self.conv = nn.Conv2d(in_shape[1], dim, 1).to(device=device) self.ln = nn.LayerNorm(dim).to(device=device) self.visual_patch = in_shape[-1] * in_shape[-2] def forward(self, feats: torch.Tensor) -> torch.Tensor: # feats: [B, C, H, W] feats = self.conv.forward(feats) # [B, dim, H, W] feats = feats.flatten(2) # [B, dim, visual_patch] feats = feats.transpose(1, 2) # [B, visual_patch, dim] feats = self.ln.forward(feats) # [B, visual_patch, dim] return feats