| # utils/feature_extractor.py | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
class FeatureExtractor(nn.Module):
    """Convolutional trunk of a torchvision backbone used as a feature extractor.

    The wrapped network is built from ``torchvision.models`` and every child
    module except the last two is chained into ``self.features``. For
    torchvision's ResNet-50 the two dropped children are the global average
    pool and the fully connected classifier, so ``forward`` yields the
    spatial feature map of the last residual stage rather than class logits.

    Only ``'resnet50'`` is currently supported.
    """

    def __init__(self, backbone: str = 'resnet50', *, pretrained: bool = True) -> None:
        """Build the backbone and its truncated feature stack.

        Args:
            backbone: Name of the backbone architecture. Only ``'resnet50'``
                is implemented.
            pretrained: When true (the default, matching the previous
                behaviour) ImageNet weights are loaded, which may trigger a
                download. Pass ``False`` to get a randomly initialised
                network, e.g. before loading a checkpoint.

        Raises:
            NotImplementedError: If ``backbone`` is not a supported name.
        """
        super().__init__()

        # Reject unknown names before doing any expensive model construction.
        if backbone != 'resnet50':
            raise NotImplementedError(f"Backbone {backbone} is not implemented.")

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum API; ``IMAGENET1K_V1`` is what ``pretrained=True``
        # maps to. Fall back to the legacy keyword on older releases.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            self.model = models.resnet50(pretrained=pretrained)
        else:
            self.model = models.resnet50(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )

        # Drop the last two children (average pool + fully connected layer for
        # ResNet) so the output keeps its spatial dimensions.
        # NOTE: ``self.model`` is retained as an attribute so existing
        # attribute access and state-dict keys stay unchanged; its layers are
        # shared with ``self.features``, not copied.
        trunk_layers = list(self.model.children())[:-2]
        self.features = nn.Sequential(*trunk_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the truncated backbone and return its feature map."""
        return self.features(x)