import torch import torch.nn as nn from torchvision import models class CNNEncoder(nn.Module): def __init__(self, encoded_dim=2048, d_model=512, freeze_backbone=True): super().__init__() resnet = models.resnet50(pretrained=True) self.backbone = nn.Sequential(*list(resnet.children())[:-2]) self.proj = nn.Linear(encoded_dim, d_model) self.freeze_backbone = freeze_backbone if freeze_backbone: for p in self.backbone.parameters(): p.requires_grad = False def unfreeze_backbone(self, unfreeze=True): self.freeze_backbone = not unfreeze for p in self.backbone.parameters(): p.requires_grad = unfreeze print(f"ResNet Backbone {'Unfrozen' if unfreeze else 'Frozen'}") def forward(self, images): """ images: (B, 3, 224, 224) return: (B, 49, d_model) """ features = self.backbone(images) # (B, 2048, 7, 7) B, C, H, W = features.shape features = features.view(B, C, H * W).permute(0, 2, 1) # (B, 49, 2048) features = self.proj(features) # (B, 49, d_model) return features