Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision | |
| from torch import nn | |
def create_model(num_classes: int = 6, seed: int = 1):
    """Create an EfficientNet-B2 feature extractor with a fresh classifier head.

    Loads torchvision's default pretrained EfficientNet-B2 weights, freezes
    every existing parameter, and replaces the classifier with a new
    dropout + linear head. The new head's parameters are created after the
    freeze, so they remain trainable.

    Args:
        num_classes: Number of output features (logits) of the new head.
        seed: Seed for initialising the new head's weights, so repeated
            calls with the same seed build an identical head.

    Returns:
        A ``(model, transform)`` tuple, where ``transform`` is the
        preprocessing preset bundled with the pretrained weights.
    """
    # Pretrained weights and the matching preprocessing preset.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    model = torchvision.models.efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the whole pretrained network; only the head built below trains.
    for param in model.parameters():
        param.requires_grad = False

    # Take the feature width from the stock head's final Linear layer rather
    # than hard-coding 1408 (EfficientNet-B2's value).
    in_features = model.classifier[-1].in_features

    # Seed inside a forked CPU RNG scope so `seed` actually controls the new
    # head's initialisation without changing the caller's global RNG state.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=in_features, out_features=num_classes),
        )

    return model, transform