File size: 1,578 Bytes
9c5c037
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
from torchvision import models

def get_classifier_in_features(classifier):
    """Return the input feature size of a classification head.

    Args:
        classifier: The head module. Either an ``nn.Linear``, an
            ``nn.Sequential`` whose top-level layers include an ``nn.Linear``,
            or any other module exposing an ``in_features`` attribute.

    Returns:
        ``in_features`` of the ``nn.Linear`` itself, of the first top-level
        ``nn.Linear`` inside an ``nn.Sequential``, or of the module directly.

    Raises:
        ValueError: If *classifier* is an ``nn.Sequential`` with no top-level
            ``nn.Linear`` (previously this silently returned ``None``).
        AttributeError: If *classifier* is some other module without an
            ``in_features`` attribute.
    """
    if isinstance(classifier, nn.Linear):
        return classifier.in_features

    if isinstance(classifier, nn.Sequential):
        # The first Linear layer is the one that consumes the head's input.
        first_linear = next(
            (layer for layer in classifier if isinstance(layer, nn.Linear)),
            None,
        )
        if first_linear is None:
            raise ValueError(
                f"No nn.Linear layer found in classifier: {classifier!r}"
            )
        return first_linear.in_features

    # Fallback for any other module type: rely on it exposing `in_features`.
    return classifier.in_features

class EfficientNet(nn.Module):
    """torchvision EfficientNet backbone with a custom two-layer classifier head.

    The backbone's ``classifier`` is replaced by::

        Dropout(dropout_rate) -> Linear(in_features, 512) -> ReLU
        -> Dropout(dropout_rate / 2) -> Linear(512, num_classes)

    Args:
        model_name: Backbone variant, one of ``'efficient_b0'`` ...
            ``'efficient_b7'``.
        num_classes: Number of output units of the final linear layer.
        pretrained: If truthy, build the backbone with torchvision's
            ``DEFAULT`` weights for the variant; otherwise ``weights=None``.
        dropout_rate: Dropout probability before the hidden layer; half of it
            is applied before the output layer.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """

    # Names accepted by this class are "efficient_" + one of these suffixes.
    # Kept as plain strings so torchvision is only touched inside __init__.
    _PREFIX = "efficient_"
    _VARIANTS = ("b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7")

    def __init__(self, model_name: str, num_classes: int = 2, pretrained: bool = True, dropout_rate: float = 0.5):
        super().__init__()

        # Load base EfficientNet model from torchvision.
        self.base_model = self._build_backbone(model_name, pretrained)

        # Replace the classifier with a custom one sized to the backbone.
        num_features = get_classifier_in_features(self.base_model.classifier)
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(512, num_classes),
        )

    @classmethod
    def _build_backbone(cls, model_name, pretrained):
        """Instantiate the torchvision EfficientNet variant named by *model_name*."""
        variant = None
        if isinstance(model_name, str) and model_name.startswith(cls._PREFIX):
            variant = model_name[len(cls._PREFIX):]
        if variant not in cls._VARIANTS:
            raise ValueError(f"Unknown model name: {model_name}")

        # e.g. "b1" -> models.efficientnet_b1 / models.EfficientNet_B1_Weights
        builder = getattr(models, f"efficientnet_{variant}")
        if pretrained:
            weights = getattr(models, f"EfficientNet_{variant.upper()}_Weights").DEFAULT
        else:
            weights = None
        return builder(weights=weights)

    def forward(self, x):
        """Run *x* through the backbone and custom head.

        Returns the output of the final ``nn.Linear(512, num_classes)``
        (no softmax is applied). ``x`` is presumably an image batch of shape
        (N, 3, H, W) — TODO confirm against callers.
        """
        return self.base_model(x)