| import torch
|
| import torch.nn as nn
|
| from torchvision.models import vit_b_16, ViT_B_16_Weights
|
|
|
|
|
class ViTEncoder(nn.Module):
    """Patch-token image encoder built on torchvision's ViT-B/16.

    The classification head is discarded and no CLS token is used: the
    encoder returns one feature vector per image patch (a 14x14 grid for
    224x224 inputs), each linearly projected to ``d_model``.
    """

    def __init__(
        self,
        d_model=512,
        freeze_backbone=True,
        pretrained=True
    ):
        """
        Args:
            d_model: width of each output patch token.
            freeze_backbone: when True, gradients are disabled for every
                ViT parameter (the final projection stays trainable).
            pretrained: when True, load ImageNet-1k weights for the backbone.
        """
        super().__init__()

        backbone_weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = vit_b_16(weights=backbone_weights)

        # The classifier head is never used; swap in a pass-through so the
        # module holds no dead parameters.
        self.vit.heads = nn.Identity()

        self.hidden_dim = self.vit.hidden_dim
        self.proj = nn.Linear(self.hidden_dim, d_model)

        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self, unfreeze=True):
        """Toggle gradient computation for all backbone parameters."""
        for param in self.vit.parameters():
            param.requires_grad = unfreeze

        print(f"ViT Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """
        images: (B, 3, 224, 224)

        return: (B, 196, d_model)
        """
        # Patchify: conv stem yields (B, hidden_dim, 14, 14); flatten the
        # spatial grid into a token sequence of shape (B, 196, hidden_dim).
        tokens = self.vit.conv_proj(images)
        tokens = tokens.flatten(2).transpose(1, 2)

        # Positional table index 0 is the CLS slot in torchvision's layout;
        # no class token is prepended here, so only slots 1: apply.
        tokens = tokens + self.vit.encoder.pos_embedding[:, 1:]

        tokens = self.vit.encoder.dropout(tokens)
        tokens = self.vit.encoder.layers(tokens)
        tokens = self.vit.encoder.ln(tokens)

        # Project each patch embedding to the requested output width.
        return self.proj(tokens)
|
|
|