# cat-dog-classifier / model.py
# Uploaded by DyuHo ("Upload 9 files", commit 9c5c037, verified).
import torch
import torch.nn as nn
from torchvision import models
def get_classifier_in_features(classifier: nn.Module) -> int:
    """Return the number of input features expected by a classifier head.

    For an ``nn.Sequential`` head, the ``in_features`` of its first *direct*
    child that is an ``nn.Linear`` is returned (nested containers are not
    searched). Any other module -- e.g. a bare ``nn.Linear`` -- must expose an
    ``in_features`` attribute itself.

    Args:
        classifier: The classifier head module to inspect.

    Returns:
        The input feature dimension of the head's first linear layer.

    Raises:
        ValueError: If ``classifier`` is an ``nn.Sequential`` with no direct
            ``nn.Linear`` child.
        AttributeError: If a non-``Sequential`` module has no ``in_features``.
    """
    if isinstance(classifier, nn.Sequential):
        # Skip non-linear layers (e.g. Dropout) and use the first Linear's input size.
        for layer in classifier:
            if isinstance(layer, nn.Linear):
                return layer.in_features
        # Fail loudly here instead of implicitly returning None, which would only
        # surface later as a confusing error inside nn.Linear(None, ...).
        raise ValueError(
            f"No nn.Linear layer found in Sequential classifier: {classifier}"
        )
    # nn.Linear, or any other module that exposes ``in_features`` directly.
    return classifier.in_features


class EfficientNet(nn.Module):
    """torchvision EfficientNet backbone with a custom two-layer classification head.

    The backbone's original ``classifier`` is replaced by::

        Dropout(dropout_rate) -> Linear(in_features, 512) -> ReLU
        -> Dropout(dropout_rate / 2) -> Linear(512, num_classes)

    Args:
        model_name: Backbone variant, one of ``"efficient_b0"`` ... ``"efficient_b7"``
            (built with torchvision's ``efficientnet_b0`` ... ``efficientnet_b7``).
        num_classes: Number of outputs of the final ``Linear`` layer.
        pretrained: If True, load the backbone with the variant's ``DEFAULT``
            torchvision weights; otherwise build it with ``weights=None``.
            The replacement head is always newly initialized either way.
        dropout_rate: Dropout probability before the hidden layer; half of it
            is used before the output layer.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 2,
        pretrained: bool = True,
        dropout_rate: float = 0.5,
    ):
        super().__init__()
        # Validate the name before building anything.
        builder, weights_enum = self._resolve_variant(model_name)
        self.base_model = builder(weights=weights_enum.DEFAULT if pretrained else None)

        # Replace the classifier with a custom head sized to the backbone's
        # feature dimension (read from its existing classifier).
        num_features = get_classifier_in_features(self.base_model.classifier)
        self.base_model.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _resolve_variant(model_name: str):
        """Return ``(builder, weights_enum)`` from torchvision for ``model_name``."""
        # Built lazily (not at class-definition time) so merely importing this
        # module never touches the torchvision attributes below.
        variants = {
            "efficient_b0": (models.efficientnet_b0, models.EfficientNet_B0_Weights),
            "efficient_b1": (models.efficientnet_b1, models.EfficientNet_B1_Weights),
            "efficient_b2": (models.efficientnet_b2, models.EfficientNet_B2_Weights),
            "efficient_b3": (models.efficientnet_b3, models.EfficientNet_B3_Weights),
            "efficient_b4": (models.efficientnet_b4, models.EfficientNet_B4_Weights),
            "efficient_b5": (models.efficientnet_b5, models.EfficientNet_B5_Weights),
            "efficient_b6": (models.efficientnet_b6, models.EfficientNet_B6_Weights),
            "efficient_b7": (models.efficientnet_b7, models.EfficientNet_B7_Weights),
        }
        try:
            return variants[model_name]
        except KeyError:
            supported = ", ".join(variants)
            raise ValueError(
                f"Unknown model name: {model_name} (supported: {supported})"
            ) from None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and custom head on ``x``.

        Assumes ``x`` is an image batch in the layout torchvision EfficientNet
        expects (presumably ``(N, 3, H, W)`` -- confirm against the data
        pipeline). Returns the raw outputs of the final ``Linear`` layer
        (no softmax applied), one row of ``num_classes`` scores per sample.
        """
        return self.base_model(x)