import torch
import torchvision
from torch import nn


def create_model(num_classes=6, seed=1):
    """Create an EfficientNet-B2 with a frozen backbone and a new classifier head.

    Loads torchvision's default pretrained EfficientNet-B2 weights and freezes
    every existing parameter. It then replaces ``model.classifier`` with a fresh
    dropout + linear head sized for ``num_classes``.

    Args:
        num_classes: Number of output units of the new linear layer.
        seed: Value passed to ``torch.manual_seed`` right before the new head
            is built, so its random initialisation is reproducible. Note that
            this reseeds torch's global RNG as a side effect.

    Returns:
        A ``(model, transform)`` tuple: the modified model and the
        preprocessing transform that belongs to the pretrained weights.
    """
    # Pretrained weights and the preprocessing pipeline that matches them.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    model = torchvision.models.efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the pretrained parameters. Only the new head created below
    # (whose parameters default to requires_grad=True) will be trained.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer
    # instead of hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features

    # Seed immediately before creating the head so its weight initialisation
    # is reproducible; previously `seed` was accepted but ignored.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform