Spaces:
Runtime error
Runtime error
| import torchvision | |
| from torch import nn | |
| import torch | |
def create_effnet(effnet_number, class_count, device="cpu"):
    """Build a pretrained EfficientNet-B<n> with a frozen backbone and a new head.

    The model is loaded with torchvision's ``DEFAULT`` weights for the requested
    variant. Every pretrained parameter is frozen, and the original classifier
    is replaced by ``Dropout -> Linear`` producing ``class_count`` outputs. Only
    the parameters of that new head remain trainable.

    Args:
        effnet_number: Variant suffix interpolated into the torchvision names
            ``efficientnet_b{n}`` / ``EfficientNet_B{n}_Weights``, e.g. ``0``
            for B0. NOTE(review): torchvision provides B0-B7; other values are
            expected to fail inside torchvision's name lookup — confirm.
        class_count: Number of output features of the replacement ``nn.Linear``.
        device: Device the returned model is moved to (default ``"cpu"``).

    Returns:
        A ``(model, transform)`` tuple: the adapted model, and the preprocessing
        transform obtained from ``weights.transforms()`` for the loaded weights.
    """
    weights_name = f"EfficientNet_B{effnet_number}_Weights.DEFAULT"
    weights = torchvision.models.get_weight(weights_name)
    model = torchvision.models.get_model(f"efficientnet_b{effnet_number}", weights=weights)

    # Size the replacement head from the pretrained one. The indexing below
    # expects classifier == Sequential(Dropout, Linear), which is how torchvision
    # builds EfficientNet.
    dropout_p = model.classifier[0].p
    head_inputs = model.classifier[1].in_features

    # Freeze every pretrained parameter. The new head is created *after* this,
    # so its freshly initialised parameters keep the default requires_grad=True
    # and are the only ones an optimizer will update.
    # NOTE: this only disables gradients; BatchNorm running statistics in the
    # backbone still change whenever the model runs in train() mode.
    model.requires_grad_(False)

    model.classifier = nn.Sequential(
        nn.Dropout(dropout_p, inplace=True),
        nn.Linear(head_inputs, class_count),
    )

    # Move once, after the head swap, so backbone and new head share a device.
    model.to(device)
    return model, weights.transforms()