import torch import torch.nn as nn from torchvision import models def get_classifier_in_features(classifier): """Get the input features of the classifier layer""" if isinstance(classifier, nn.Sequential): for layer in classifier: if isinstance(layer, nn.Linear): return layer.in_features elif isinstance(classifier, nn.Linear): return classifier.in_features else: # For EfficientNet, the classifier is usually just a Linear layer return classifier.in_features class EfficientNet(nn.Module): def __init__(self, model_name: str, num_classes: int = 2, pretrained: bool = True, dropout_rate: float = 0.5): super(EfficientNet, self).__init__() # Load base EfficientNet model from torchvision if model_name == 'efficient_b1': if pretrained: self.base_model = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.DEFAULT) else: self.base_model = models.efficientnet_b1(weights=None) else: raise ValueError(f"Unknown model name: {model_name}") # Replace the classifier with a custom one num_features = get_classifier_in_features(self.base_model.classifier) self.base_model.classifier = nn.Sequential( nn.Dropout(dropout_rate), nn.Linear(num_features, 512), nn.ReLU(inplace=True), nn.Dropout(dropout_rate / 2), nn.Linear(512, num_classes) ) def forward(self, x): return self.base_model(x)