File size: 2,441 Bytes
d426cb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# feature_extractor.py
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import inception


class FeatureExtractor(nn.Module):
    """Wrap a pretrained torchvision backbone as a fixed-size feature extractor.

    The classification head is removed; ``forward`` returns a flat
    ``(batch, out_features_dim)`` tensor of pooled features.

    Args:
        model_name: one of ``"resnet50"``, ``"resnet18"``, ``"vgg16"``,
            ``"densenet121"``, ``"efficientnet_b0"``, ``"inception_v3"``.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """

    def __init__(self, model_name="resnet50"):
        super().__init__()

        if model_name == "resnet50":
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        elif model_name == "resnet18":
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        elif model_name == "vgg16":
            model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        elif model_name == "densenet121":
            model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        elif model_name == "efficientnet_b0":
            model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        elif model_name == "inception_v3":
            model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=True)
        else:
            raise ValueError(f"Unsupported feature extractor model: {model_name}")

        if isinstance(model, models.ResNet):
            # children() ends with [avgpool, fc]; dropping only fc keeps the
            # global pooling, so the output is already (B, C, 1, 1).
            self.features = nn.Sequential(*list(model.children())[:-1])
            self.out_features_dim = model.fc.in_features
        elif isinstance(model, models.DenseNet):
            # BUG FIX: DenseNet's head is `classifier`, not `fc` — the old
            # `model.fc.in_features` raised AttributeError. Also, DenseNet's
            # own forward applies ReLU + global average pooling after
            # `features`; children()[:-1] omits both, which would leave an
            # unpooled (B, C, H, W) map. Replicate that tail explicitly.
            self.features = nn.Sequential(
                model.features,
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((1, 1)),
            )
            self.out_features_dim = model.classifier.in_features
        elif isinstance(model, models.VGG):
            # Keep the adaptive avgpool so the flattened size matches
            # classifier[0].in_features (25088) regardless of input size;
            # identical result for the standard 224x224 case.
            self.features = nn.Sequential(model.features, model.avgpool)
            self.out_features_dim = model.classifier[0].in_features
        elif isinstance(model, models.EfficientNet):
            # BUG FIX: efficientnet_b0 was constructible above but had no
            # branch here, so it fell through to the "Unknown model type"
            # error. classifier is [Dropout, Linear]; the Linear holds the
            # input width.
            self.features = nn.Sequential(model.features, model.avgpool)
            self.out_features_dim = model.classifier[1].in_features
        elif isinstance(model, inception.Inception3):
            # Inception3's submodules are not registered in forward order,
            # so enumerate the stem/mixed blocks explicitly (aux head and
            # fc excluded; avgpool kept so output is (B, 2048, 1, 1)).
            self.features = nn.Sequential(
                model.Conv2d_1a_3x3, model.Conv2d_2a_3x3, model.Conv2d_2b_3x3,
                model.maxpool1, model.Conv2d_3b_1x1, model.Conv2d_4a_3x3,
                model.maxpool2, model.Mixed_5b, model.Mixed_5c, model.Mixed_5d,
                model.Mixed_6a, model.Mixed_6b, model.Mixed_6c, model.Mixed_6d, model.Mixed_6e,
                model.Mixed_7a, model.Mixed_7b, model.Mixed_7c, model.avgpool
            )
            self.out_features_dim = model.fc.in_features
        else:
            raise ValueError(f"Unknown model type for feature extraction: {model_name}")

    def forward(self, x):
        """Extract features from a batch of images.

        Args:
            x: input image batch, ``(B, 3, H, W)``.

        Returns:
            ``(B, out_features_dim)`` feature tensor.
        """
        features = self.features(x)
        # Backbones emit (B, C, 1, 1) or (B, C, H, W); flatten to (B, D).
        if features.dim() > 2:
            features = torch.flatten(features, 1)
        return features

    def get_output_dim(self):
        """Return the dimensionality of the flattened feature vector."""
        return self.out_features_dim