import torch
import torchvision
from torch import nn


def create_effnetb4_model(num_classes: int = 101):
    """Build a pretrained EfficientNet-B4 with a new classification head.

    The backbone is loaded with torchvision's default ImageNet weights
    (``EfficientNet_B4_Weights.DEFAULT``). The final ``nn.Linear`` in
    ``model.classifier`` is replaced by a freshly initialised layer that
    outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.

    Returns:
        A ``(model, transforms)`` tuple:
        ``model`` is the EfficientNet-B4 network with the replaced head, and
        ``transforms`` is the preprocessing pipeline bundled with the
        pretrained weights (``weights.transforms()``).

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B4_Weights.DEFAULT
    # Use the preprocessing that matches the pretrained weights, so inference
    # inputs are resized/normalised the same way as during pretraining.
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b4(weights=weights)

    # Derive the head's input width from the existing layer rather than
    # hard-coding it, so the replacement stays shape-compatible.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model, transforms