File size: 1,026 Bytes
eb55711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn


class EncoderCNN(nn.Module):
    """ResNet-50 image encoder that projects pooled features to ``embed_size``.

    The backbone is torchvision's ResNet-50 with its final classifier
    (``fc``) removed, so it ends at the global average pool. All backbone
    parameters are frozen. When ``fine_tune`` is True, the last residual
    stage (``layer4``) is left trainable. A new ``nn.Linear`` followed by
    ``nn.BatchNorm1d`` maps the pooled features to the embedding size.

    Args:
        embed_size: Dimensionality of the output embedding.
        fine_tune: If True, unfreeze ``layer4`` of the backbone so it trains
            together with the projection head.
        pretrained: If True (default), initialise the backbone with
            ``ResNet50_Weights.DEFAULT``. Set to False only when no weights
            are wanted (e.g. offline tests). A frozen, randomly initialised
            backbone produces features that carry no useful image information.
    """

    def __init__(self, embed_size: int, fine_tune: bool = False, *, pretrained: bool = True):
        super().__init__()
        # Load pretrained weights regardless of ``fine_tune``: the backbone is
        # frozen below, so random weights would never be trained.
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        resnet = resnet50(weights=weights)

        # Freeze the whole backbone, then optionally re-enable the last stage.
        for param in resnet.parameters():
            param.requires_grad = False
        if fine_tune:
            for param in resnet.layer4.parameters():
                param.requires_grad = True
        # NOTE(review): requires_grad=False does not stop the backbone's
        # BatchNorm layers from updating their running statistics while this
        # module is in train() mode. Confirm that is intended.

        # Drop the last child (the ImageNet classifier ``fc``); keep the conv
        # stages and the global average pool.
        backbone = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*backbone)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Encode a batch of images into embeddings.

        Args:
            images: Tensor of shape ``(B, C, H, W)``. Presumably 3-channel
                and normalised the way the backbone weights expect; confirm
                against the data pipeline.

        Returns:
            Tensor of shape ``(B, embed_size)``.

        Note:
            In training mode ``BatchNorm1d`` raises an error for a batch of a
            single sample.
        """
        features = self.resnet(images)  # (B, 2048, 1, 1) for ResNet-50
        # Flattening is required: nn.Linear acts on the LAST dimension, which
        # is 1 here rather than the feature dimension.
        features = features.flatten(1)  # (B, 2048)
        return self.bn(self.fc(features))  # (B, embed_size)