Spaces:
Sleeping
Sleeping
File size: 1,578 Bytes
9c5c037 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import torch
import torch.nn as nn
from torchvision import models
def get_classifier_in_features(classifier):
    """Return the input feature size of a classifier head.

    Args:
        classifier: The head module, e.g. ``model.classifier`` of a torchvision
            model. Accepted forms:

            * ``nn.Linear`` -- its own ``in_features`` is returned.
            * ``nn.Sequential`` -- the ``in_features`` of its first direct
              ``nn.Linear`` child is returned.
            * any other module -- its ``in_features`` attribute is returned.

    Returns:
        int: The number of input features the head consumes.

    Raises:
        ValueError: If ``classifier`` is an ``nn.Sequential`` with no direct
            ``nn.Linear`` child.
        TypeError: If ``classifier`` is some other module that has no
            ``in_features`` attribute.
    """
    if isinstance(classifier, nn.Linear):
        return classifier.in_features

    if isinstance(classifier, nn.Sequential):
        # Assumes layers before the first Linear (e.g. Dropout) preserve the
        # feature size, so the first Linear's in_features is the head's input.
        for layer in classifier:
            if isinstance(layer, nn.Linear):
                return layer.in_features
        raise ValueError(
            "nn.Sequential classifier contains no nn.Linear layer; "
            "cannot infer its input feature size"
        )

    # Fallback for other head types that expose `in_features` directly.
    in_features = getattr(classifier, "in_features", None)
    if in_features is None:
        raise TypeError(
            "Cannot infer input features from classifier of type "
            f"{type(classifier).__name__}: it has no in_features attribute"
        )
    return in_features
class EfficientNet(nn.Module):
    """torchvision EfficientNet backbone with a custom two-layer classifier head.

    The backbone's original ``classifier`` is replaced by
    ``Dropout(p) -> Linear(in, hidden_dim) -> ReLU -> Dropout(p / 2) ->
    Linear(hidden_dim, num_classes)``.

    Args:
        model_name: Backbone variant. Accepts ``"efficient_bN"`` (the original
            spelling) or ``"efficientnet_bN"`` for N in 0..7, mapped to
            torchvision's ``efficientnet_b0`` ... ``efficientnet_b7`` builders.
        num_classes: Output size of the final Linear layer.
        pretrained: If True, build the backbone with torchvision's ``DEFAULT``
            weights for the variant; otherwise build it with ``weights=None``.
        dropout_rate: Dropout probability before the hidden layer; half of it
            is used before the output layer.
        hidden_dim: Width of the hidden Linear layer in the new head.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """

    # Variant suffixes for which torchvision provides both an
    # ``efficientnet_<v>`` builder and an ``EfficientNet_<V>_Weights`` enum.
    _SUPPORTED_VARIANTS = ("b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7")
    # Accepted name prefixes; "efficient_" keeps the original spelling working.
    _NAME_PREFIXES = ("efficientnet_", "efficient_")

    def __init__(
        self,
        model_name: str,
        num_classes: int = 2,
        pretrained: bool = True,
        dropout_rate: float = 0.5,
        hidden_dim: int = 512,
    ):
        super().__init__()
        self.base_model = self._build_backbone(model_name, pretrained)

        # Replace the classifier with a custom head. The layer order (and thus
        # the state_dict keys base_model.classifier.0 ... .4) is kept stable so
        # checkpoints saved from this class still load.
        num_features = get_classifier_in_features(self.base_model.classifier)
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(hidden_dim, num_classes),
        )

    @classmethod
    def _build_backbone(cls, model_name: str, pretrained: bool) -> nn.Module:
        """Instantiate the torchvision EfficientNet variant named by ``model_name``."""
        variant = None
        if isinstance(model_name, str):
            for prefix in cls._NAME_PREFIXES:
                if model_name.startswith(prefix):
                    variant = model_name[len(prefix):]
                    break
        if variant not in cls._SUPPORTED_VARIANTS:
            supported = ", ".join(f"efficientnet_{v}" for v in cls._SUPPORTED_VARIANTS)
            raise ValueError(f"Unknown model name: {model_name} (supported: {supported})")

        builder = getattr(models, f"efficientnet_{variant}")
        weights = (
            getattr(models, f"EfficientNet_{variant.upper()}_Weights").DEFAULT
            if pretrained
            else None
        )
        return builder(weights=weights)

    def forward(self, x):
        """Run the backbone and the custom head on ``x``; return the logits.

        NOTE(review): ``x`` is presumably an image batch of shape (N, 3, H, W)
        as expected by torchvision EfficientNet models -- confirm with callers.
        """
        return self.base_model(x)
|