# utils/feature_extractor.py import torch import torch.nn as nn from torchvision import models class FeatureExtractor(nn.Module): def __init__(self, backbone='resnet50'): super(FeatureExtractor, self).__init__() if backbone == 'resnet50': self.model = models.resnet50(pretrained=True) # Remove the final fully connected layer self.features = nn.Sequential(*list(self.model.children())[:-2]) else: raise NotImplementedError(f"Backbone {backbone} is not implemented.") def forward(self, x): return self.features(x)