# vit-image-captioning / vit_encoder.py
# Uploaded by mostafahagali (commit 601cad6, verified) — "Upload 9 files"
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
class ViTEncoder(nn.Module):
    """Image encoder built on torchvision's ViT-B/16 backbone.

    Emits a sequence of patch-token embeddings (the CLS token is dropped)
    projected down to ``d_model`` channels, e.g. as encoder memory for a
    captioning decoder.
    """

    def __init__(self, d_model=512, freeze_backbone=True, pretrained=True):
        """
        Args:
            d_model: channel width of the output projection.
            freeze_backbone: if True, disable gradients on all ViT parameters
                (the projection layer stays trainable).
            pretrained: if True, load IMAGENET1K_V1 weights for the backbone.
        """
        super().__init__()
        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = vit_b_16(weights=weights)
        # The classification head is never used here — replace it with a
        # pass-through so it contributes no parameters or computation.
        self.vit.heads = nn.Identity()
        self.hidden_dim = self.vit.hidden_dim
        # Maps the ViT hidden size (768 for ViT-B/16) to the decoder width.
        self.proj = nn.Linear(self.hidden_dim, d_model)
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self, unfreeze=True):
        """Toggle gradient tracking on every backbone parameter."""
        for param in self.vit.parameters():
            param.requires_grad = unfreeze
        print(f"ViT Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """Encode a batch of images into per-patch features.

        Args:
            images: float tensor of shape (B, 3, 224, 224).

        Returns:
            Tensor of shape (B, 196, d_model), one token per 16x16 patch.
        """
        # Patchify: conv_proj maps (B, 3, 224, 224) -> (B, hidden, 14, 14).
        patches = self.vit.conv_proj(images)
        # Flatten the 14x14 grid into a token sequence: (B, 196, hidden).
        tokens = patches.flatten(2).transpose(1, 2)
        # Add the backbone's learned positional embedding, skipping index 0
        # which is reserved for the CLS token (pos_embedding is (1, 197, hidden)).
        # Reading the parameter directly keeps device placement and gradient
        # flow consistent with the rest of the backbone.
        tokens = tokens + self.vit.encoder.pos_embedding[:, 1:]
        # Run the full transformer stack (dropout -> layers -> final norm);
        # omitting it would reduce the encoder to a conv + linear projection.
        tokens = self.vit.encoder.dropout(tokens)
        tokens = self.vit.encoder.layers(tokens)
        tokens = self.vit.encoder.ln(tokens)
        # Project to the decoder's model dimension: (B, 196, d_model).
        return self.proj(tokens)