import torch.nn as nn
from torchvision.models import ResNet50_Weights, resnet50


class EncoderCNN(nn.Module):
    """Image encoder: a (mostly) frozen ResNet-50 backbone followed by a
    linear projection to a fixed-size embedding.

    The ResNet-50 is always initialized with ImageNet-pretrained weights.
    All backbone parameters are frozen; when ``fine_tune=True`` the last
    residual stage (``layer4``) is unfrozen so it can adapt to the task.

    Args:
        embed_size: dimensionality of the output image embedding.
        fine_tune: if True, make ``layer4`` trainable; otherwise the whole
            backbone stays frozen and only ``fc``/``bn`` train.
    """

    def __init__(self, embed_size, fine_tune=False):
        super().__init__()
        # BUGFIX: pretrained weights must be loaded unconditionally. The
        # original loaded them only when fine_tune=True, so the default
        # configuration froze a *randomly initialized* backbone — a fixed
        # random feature extractor that could never produce useful features.
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Freeze the entire backbone first...
        for param in resnet.parameters():
            param.requires_grad = False
        # ...then optionally unfreeze only the last residual stage.
        if fine_tune:
            for param in resnet.layer4.parameters():
                param.requires_grad = True

        # Drop the final classification layer; keep conv stages + global
        # average pooling, which emit (B, 2048, 1, 1).
        backbone = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*backbone)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: float tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, embed_size).
        """
        features = self.resnet(images)            # (B, 2048, 1, 1)
        # The flatten IS required: nn.Linear acts on the last dimension,
        # so the trailing 1x1 spatial dims must be collapsed into the
        # 2048-feature dim before the projection. (The original comment
        # claiming it was unnecessary was incorrect.)
        features = features.flatten(start_dim=1)  # (B, 2048)
        return self.bn(self.fc(features))         # (B, embed_size)