vit-image-captioning / cnn_encoder.py
mostafahagali's picture
Upload 9 files
601cad6 verified
import torch
import torch.nn as nn
from torchvision import models
class CNNEncoder(nn.Module):
    """ResNet-50 feature extractor that emits a sequence of projected grid features.

    The classification head of ResNet-50 (its last two children: global
    average pool and fully connected layer) is dropped, so the backbone
    returns a spatial feature map. Each spatial location is then projected
    from ``encoded_dim`` channels to ``d_model`` with a linear layer.

    Args:
        encoded_dim: Channel count of the backbone feature map. Must match
            what the backbone actually produces (2048 for ResNet-50),
            otherwise the projection fails at forward time.
        d_model: Size of each projected output feature vector.
        freeze_backbone: If True, backbone parameters do not receive
            gradients and the backbone is held in eval mode so its
            BatchNorm running statistics are not modified during training.
    """

    def __init__(self, encoded_dim: int = 2048, d_model: int = 512,
                 freeze_backbone: bool = True) -> None:
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13; the weights
        # enum selects the same ImageNet checkpoint. Fall back to the legacy
        # keyword on older torchvision releases that lack the enum.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Keep everything except the trailing avgpool + fc layers.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.proj = nn.Linear(encoded_dim, d_model)
        self.freeze_backbone = freeze_backbone

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            # A freshly built module is in training mode; a frozen backbone
            # must not update its BatchNorm statistics, so switch it to eval.
            self.backbone.eval()

    def train(self, mode: bool = True) -> "CNNEncoder":
        """Set the training mode, keeping a frozen backbone in eval mode.

        Without this override, ``model.train()`` would put the backbone's
        BatchNorm layers back into training mode, where they normalise with
        batch statistics and overwrite their running mean/variance even
        though their parameters have ``requires_grad=False``.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.backbone.eval()
        return self

    def unfreeze_backbone(self, unfreeze: bool = True) -> None:
        """Enable (or, with ``unfreeze=False``, disable) backbone training.

        Toggles ``requires_grad`` on every backbone parameter, records the
        new state in ``self.freeze_backbone`` and aligns the backbone's
        train/eval mode with it. Prints the resulting state.
        """
        self.freeze_backbone = not unfreeze
        self.backbone.requires_grad_(unfreeze)
        # An unfrozen backbone follows the encoder's current mode; a frozen
        # one always stays in eval mode.
        self.backbone.train(self.training and bool(unfreeze))
        print(f"ResNet Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into a sequence of feature vectors.

        Args:
            images: Tensor of shape (B, 3, H, W); (B, 3, 224, 224) yields a
                7x7 feature map.

        Returns:
            Tensor of shape (B, H' * W', d_model), where H' x W' is the
            backbone's output grid, i.e. (B, 49, d_model) for 224x224 input.
        """
        features = self.backbone(images)  # (B, C, H', W')
        # Merge the spatial grid into one axis and move channels last so the
        # linear projection acts on each location: (B, H' * W', C).
        features = features.flatten(2).transpose(1, 2)
        return self.proj(features)  # (B, H' * W', d_model)