Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
def get_classifier_in_features(classifier: nn.Module) -> int:
    """Return the number of input features expected by a classifier head.

    Args:
        classifier: Either an ``nn.Sequential`` whose direct children include
            at least one ``nn.Linear``, or a single module exposing an
            ``in_features`` attribute (typically ``nn.Linear``).

    Returns:
        For an ``nn.Sequential``: ``in_features`` of its first direct
        ``nn.Linear`` child. Otherwise: ``classifier.in_features``.

    Raises:
        ValueError: If ``classifier`` is an ``nn.Sequential`` with no direct
            ``nn.Linear`` child.
        AttributeError: If a non-Sequential ``classifier`` has no
            ``in_features`` attribute.
    """
    if isinstance(classifier, nn.Sequential):
        # Scan direct children in forward order. The first Linear is the layer
        # that consumes the incoming feature vector.
        for layer in classifier:
            if isinstance(layer, nn.Linear):
                return layer.in_features
        # Fail loudly here instead of implicitly returning None, which would
        # only surface later as a confusing error in nn.Linear(None, ...).
        raise ValueError(
            "Cannot infer classifier input features: "
            "nn.Sequential contains no nn.Linear layer"
        )
    # nn.Linear, and any other head that exposes `in_features`.
    return classifier.in_features
class EfficientNet(nn.Module):
    """Torchvision EfficientNet backbone with a custom two-layer MLP head.

    The backbone's ``classifier`` is replaced by::

        Dropout(p) -> Linear(in, hidden_dim) -> ReLU -> Dropout(p / 2)
        -> Linear(hidden_dim, num_classes)

    Args:
        model_name: Backbone variant. ``'efficient_b1'`` is accepted, as is any
            torchvision EfficientNet variant spelled ``'efficient_<variant>'``
            or ``'efficientnet_<variant>'`` (e.g. ``'efficient_b0'``,
            ``'efficientnet_b3'``, ``'efficientnet_v2_s'``).
        num_classes: Size of the final output dimension.
        pretrained: If True, load the variant's ``DEFAULT`` torchvision
            weights. If False, build the backbone with ``weights=None``.
        dropout_rate: Dropout probability before the first head layer. Half of
            it is applied before the final layer.
        hidden_dim: Width of the hidden layer in the head (keyword-only,
            default 512).

    Raises:
        ValueError: If ``model_name`` does not name a torchvision EfficientNet
            variant.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 2,
        pretrained: bool = True,
        dropout_rate: float = 0.5,
        *,
        hidden_dim: int = 512,
    ):
        super().__init__()
        self.base_model = self._build_backbone(model_name, pretrained)

        # Replace the stock head with a custom MLP sized to the backbone output.
        num_features = get_classifier_in_features(self.base_model.classifier)
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(hidden_dim, num_classes),
        )

    @staticmethod
    def _build_backbone(model_name: str, pretrained: bool) -> nn.Module:
        """Instantiate the torchvision EfficientNet variant named by ``model_name``."""
        variant = None
        if isinstance(model_name, str):
            # Check the longer prefix first so 'efficientnet_b1' is not
            # misparsed. 'efficient_b1' maps to the same variant 'b1'.
            for prefix in ("efficientnet_", "efficient_"):
                if model_name.startswith(prefix):
                    variant = model_name[len(prefix):]
                    break

        # Both the builder (e.g. models.efficientnet_b1) and its weights enum
        # (e.g. models.EfficientNet_B1_Weights) must exist for a valid variant.
        builder = getattr(models, f"efficientnet_{variant}", None) if variant else None
        weights_enum = (
            getattr(models, f"EfficientNet_{variant.upper()}_Weights", None)
            if variant
            else None
        )
        if builder is None or weights_enum is None:
            raise ValueError(f"Unknown model name: {model_name}")
        return builder(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and custom head.

        Presumably ``x`` is an image batch shaped (N, 3, H, W); confirm
        against the data pipeline. The output comes from the final
        ``nn.Linear``, so it is unnormalized scores (no softmax) whose last
        dimension is ``num_classes``.
        """
        return self.base_model(x)