Spaces:
Sleeping
Sleeping
File size: 537 Bytes
423bf0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
from torch.nn import Dropout, Linear
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
def create_model(num_classes: int = 101, dropout: float = 0.3):
    """Build an EfficientNet-B2 with a frozen backbone and a new classifier head.

    The model is created with ``EfficientNet_B2_Weights.DEFAULT`` pretrained
    weights. All parameters in its ``features`` stage are frozen, so only the
    replacement classifier head receives gradients.

    Args:
        num_classes: Number of output logits of the new head. Defaults to 101,
            the value this function previously hard-coded.
        dropout: Dropout probability applied before the final linear layer.
            Defaults to 0.3, the value this function previously hard-coded.

    Returns:
        A ``(model, transform)`` tuple: the modified ``efficientnet_b2`` model
        and the preprocessing transform that matches the pretrained weights.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    model = efficientnet_b2(weights=weights)
    transform = weights.transforms()

    # Freeze the whole backbone in one call; Module.requires_grad_ applies to
    # every parameter of every submodule, same as looping over the layers.
    model.features.requires_grad_(False)

    # Take the head's input width from the last layer of the stock classifier
    # instead of hard-coding 1408, so it always matches the loaded backbone.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        Dropout(p=dropout, inplace=True),
        Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform
|