import torch
from torch.nn import Dropout, Linear
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2


def create_model(num_classes: int = 3, dropout: float = 0.3):
    """Build an EfficientNet-B2 feature extractor with a fresh classifier head.

    The backbone is created with ``EfficientNet_B2_Weights.DEFAULT``, every
    parameter under ``model.features`` is frozen, and ``model.classifier`` is
    replaced by a new ``Dropout -> Linear`` head whose parameters remain
    trainable.

    Args:
        num_classes: Number of output units of the new linear layer.
            Defaults to 3, the value that was previously hard-coded.
        dropout: Drop probability of the dropout layer placed before the
            linear layer. Defaults to 0.3, the value that was previously
            hard-coded.

    Returns:
        A ``(model, transform)`` tuple, where ``model`` is the modified
        EfficientNet-B2 and ``transform`` is the object returned by
        ``weights.transforms()`` for the selected weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = EfficientNet_B2_Weights.DEFAULT
    # NOTE(review): torchvision presumably fetches the weight file on first
    # use if it is not cached locally -- confirm for offline deployments.
    model = efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the backbone so that only the new head receives gradient updates.
    model.features.requires_grad_(False)

    # Read the feature width from the stock head rather than hard-coding it
    # (1408 for B2), so the new head always matches the backbone's output.
    in_features = model.classifier[1].in_features
    model.classifier = torch.nn.Sequential(
        Dropout(p=dropout, inplace=True),
        Linear(in_features=in_features, out_features=num_classes),
    )

    return model, transform