from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
class EncoderCNN(nn.Module):
    """ResNet-50 image encoder mapping a batch of images to embedding vectors.

    Intended as the visual front-end of an image-captioning model: the
    ResNet classification head is dropped and replaced by a linear
    projection to ``embed_size`` followed by batch normalization.

    Args:
        embed_size: dimensionality of the output embedding.
        fine_tune: if True, unfreeze the last ResNet stage (``layer4``) so
            it is trained along with the projection head; all other
            backbone parameters stay frozen.
        pretrained: if True, initialize the backbone with ImageNet
            (``ResNet50_Weights.DEFAULT``) weights. Defaults to ``None``,
            which follows ``fine_tune`` — preserving the original behavior.
            NOTE(review): the original tied pretrained loading to
            ``fine_tune``, so the default frozen backbone is randomly
            initialized; pass ``pretrained=True`` for the usual
            transfer-learning setup (pretrained and frozen).
    """

    def __init__(self, embed_size, fine_tune=False, pretrained=None):
        super().__init__()
        # Backward-compatible default: load pretrained weights exactly when
        # fine-tuning, as the original code did.
        if pretrained is None:
            pretrained = fine_tune
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)

        # Freeze the entire backbone; only layer4 is (optionally) unfrozen.
        for param in resnet.parameters():
            param.requires_grad = False
        if fine_tune:
            for param in resnet.layer4.parameters():
                param.requires_grad = True

        # Drop the final fc layer, keeping conv stages + global average pool.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Project the pooled 2048-d feature vector to the embedding size.
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        # NOTE: BatchNorm1d needs batch size > 1 in training mode.
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: float tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, embed_size).
        """
        features = self.resnet(images)  # (B, 2048, 1, 1) after global avg pool
        # Flatten is required: nn.Linear acts on the LAST dimension, so the
        # trailing 1x1 spatial dims must be folded into the feature axis.
        features = features.flatten(1)  # (B, 2048)
        return self.bn(self.fc(features))  # (B, embed_size)