import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights


class ViTEncoder(nn.Module):
    """Patch-token encoder built on torchvision's ViT-B/16 backbone.

    Drops the classification head and the CLS token, runs the 196 patch
    tokens through the full transformer encoder, and projects each token
    to ``d_model`` dimensions.
    """

    def __init__(self, d_model=512, freeze_backbone=True, pretrained=True):
        """
        Args:
            d_model: Output embedding dimension of the final projection.
            freeze_backbone: If True, disable gradients on every ViT
                parameter (the projection layer stays trainable).
            pretrained: If True, load ImageNet-1k pretrained weights.
        """
        super().__init__()
        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = vit_b_16(weights=weights)
        # Remove the classifier head; only encoder features are needed.
        self.vit.heads = nn.Identity()
        self.hidden_dim = self.vit.hidden_dim
        # Project backbone features to the requested model dimension.
        self.proj = nn.Linear(self.hidden_dim, d_model)
        # Freeze the backbone (self.proj is not part of self.vit, so it
        # remains trainable).
        if freeze_backbone:
            for p in self.vit.parameters():
                p.requires_grad = False

    def unfreeze_backbone(self, unfreeze=True):
        """Enable (or, with ``unfreeze=False``, disable) gradients for
        every ViT backbone parameter."""
        for p in self.vit.parameters():
            p.requires_grad = unfreeze
        print(f"ViT Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """Encode a batch of images into a sequence of patch embeddings.

        Args:
            images: (B, 3, 224, 224) image batch.

        Returns:
            (B, 196, d_model) patch-token embeddings (no CLS token).

        Raises:
            ValueError: if the patch grid produced by ``conv_proj`` does
                not match the backbone's positional-embedding table
                (i.e. the input is not 224x224).
        """
        # 1. Patch embedding: (B, hidden, 14, 14) -> (B, 196, hidden)
        x = self.vit.conv_proj(images)
        x = x.flatten(2).transpose(1, 2)

        # 2. Add positional embeddings, slicing off the CLS slot at
        #    index 0. Using the model's own parameter directly keeps
        #    gradient flow and device placement correct.
        #    self.vit.encoder.pos_embedding is (1, 197, hidden).
        pos = self.vit.encoder.pos_embedding[:, 1:]
        # Fail early with a clear message instead of a cryptic
        # broadcasting error from the addition below.
        if x.shape[1] != pos.shape[1]:
            raise ValueError(
                f"Got {x.shape[1]} patch tokens but the backbone expects "
                f"{pos.shape[1]}; input images must be 224x224."
            )
        x = x + pos

        # 3. Full transformer encoder (dropout -> layers -> final
        #    LayerNorm). Skipping these would reduce the model to a
        #    plain Conv + Linear projection.
        x = self.vit.encoder.dropout(x)
        x = self.vit.encoder.layers(x)
        x = self.vit.encoder.ln(x)

        # 4. Project to d_model: (B, 196, d_model)
        return self.proj(x)