import torch.nn as nn from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights class CatDogEfficientNetB0(nn.Module): def __init__(self): super().__init__() weights = EfficientNet_B0_Weights.DEFAULT self.base = efficientnet_b0(weights=weights) for param in self.base.parameters(): param.requires_grad = False in_features = self.base.classifier[1].in_features self.base.classifier[1] = nn.Linear(in_features, 2) def forward(self, x): return self.base(x)