File size: 622 Bytes
e5461d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# utils/feature_extractor.py

import torch
import torch.nn as nn
from torchvision import models

class FeatureExtractor(nn.Module):
    """Feature extractor built from a truncated torchvision backbone.

    A torchvision ResNet-50 is instantiated and every child module except the
    last two is chained into ``self.features``; ``forward`` runs the input
    through that truncated stack instead of the full classifier.

    Args:
        backbone: Name of the backbone architecture. Only ``'resnet50'`` is
            supported.
        pretrained: When true (the default, matching the previous behaviour)
            torchvision's pretrained weights are loaded. When false the
            backbone keeps its random initialisation and no weights are
            fetched.

    Raises:
        NotImplementedError: If ``backbone`` is not a supported name.
    """

    def __init__(self, backbone='resnet50', pretrained=True):
        super().__init__()
        if backbone != 'resnet50':
            raise NotImplementedError(f"Backbone {backbone} is not implemented.")

        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: ``pretrained=`` is deprecated in favour of
            # ``weights=``. IMAGENET1K_V1 is the checkpoint the legacy
            # ``pretrained=True`` flag maps to, so the loaded weights are
            # unchanged.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.model = models.resnet50(weights=weights)
        else:
            # Older torchvision releases only understand the boolean flag.
            self.model = models.resnet50(pretrained=bool(pretrained))

        # Drop the last two children of the backbone. In torchvision's ResNet
        # these are the average-pooling layer and the fully connected
        # classifier, so ``features`` ends at the last residual stage.
        # ``self.model`` is kept as an attribute so existing attribute access
        # and state_dict keys are unaffected.
        self.features = nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, x):
        """Return ``self.features`` applied to the input ``x``."""
        return self.features(x)