| import torch
|
| import torch.nn as nn
|
| from torchvision import models
|
|
|
class CNNEncoder(nn.Module):
    """Encode images into a sequence of projected ResNet-50 spatial features.

    The last two children of torchvision's ResNet-50 (its global average pool
    and ``fc`` classifier) are dropped, so the backbone yields a feature map of
    shape ``(B, C, H', W')``. Every spatial position becomes one token whose
    ``C = encoded_dim`` channels are linearly projected to ``d_model``.

    Args:
        encoded_dim: Channel width of the backbone feature map. Must equal
            ResNet-50's feature width (``fc.in_features``, 2048).
        d_model: Size of each output token.
        freeze_backbone: If true, backbone parameters do not receive gradients
            and the backbone is kept in eval mode, so its BatchNorm running
            statistics are not updated during training either.
        pretrained: Load ImageNet (``IMAGENET1K_V1``) weights for the backbone.
            Pass ``False`` to build a randomly initialised backbone without
            downloading anything.

    Raises:
        ValueError: If ``encoded_dim`` does not match the backbone's width.
    """

    def __init__(self, encoded_dim=2048, d_model=512, freeze_backbone=True, *, pretrained=True):
        super().__init__()

        resnet = self._build_resnet50(pretrained)
        backbone_dim = resnet.fc.in_features
        if encoded_dim != backbone_dim:
            # Fail here rather than with an opaque matmul shape error in forward().
            raise ValueError(
                f"encoded_dim={encoded_dim} does not match the ResNet-50 "
                f"feature width ({backbone_dim})"
            )

        # Keep everything up to the last conv stage; drop avgpool and fc.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.proj = nn.Linear(encoded_dim, d_model)

        self.freeze_backbone = bool(freeze_backbone)
        self._sync_backbone_state()

    @staticmethod
    def _build_resnet50(pretrained):
        """Build torchvision's ResNet-50, preferring the ``weights=`` API."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` resolved to.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def _sync_backbone_state(self):
        """Apply ``self.freeze_backbone`` to the backbone's grad flags and mode."""
        trainable = not self.freeze_backbone
        for p in self.backbone.parameters():
            p.requires_grad = trainable
        # requires_grad=False alone does not stop BatchNorm from updating its
        # running statistics in train mode, so a frozen backbone stays in eval.
        self.backbone.train(self.training and trainable)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        ``nn.Module.eval()`` delegates to ``train(False)``, so it is covered too.
        """
        super().train(mode)
        if self.freeze_backbone:
            self.backbone.eval()
        return self

    def unfreeze_backbone(self, unfreeze=True):
        """Unfreeze (``unfreeze=True``) or re-freeze (``False``) the backbone.

        Updates the parameters' ``requires_grad`` flags and the backbone's
        train/eval mode: an unfrozen backbone follows this module's current
        mode, and a frozen one is put in eval mode. Prints the new state.
        """
        self.freeze_backbone = not unfreeze
        self._sync_backbone_state()
        print(f"ResNet Backbone {'Unfrozen' if unfreeze else 'Frozen'}")

    def forward(self, images):
        """Encode a batch of images into projected spatial tokens.

        Args:
            images: Image batch of shape ``(B, 3, H, W)``; 224x224 inputs give
                a 7x7 feature map, i.e. 49 tokens.

        Returns:
            Tensor of shape ``(B, H' * W', d_model)``, where ``H'``/``W'`` are
            the backbone's output spatial size (``(B, 49, d_model)`` for
            224x224 input).
        """
        features = self.backbone(images)                # (B, C, H', W')
        tokens = features.flatten(2).transpose(1, 2)    # (B, H' * W', C)
        return self.proj(tokens)                        # (B, H' * W', d_model)