Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision | |
| from torchvision import transforms | |
| import torch.nn as nn | |
| from torchvision.models import mobilenet_v2 | |
# Load MobileNetV2 (optionally with pre-trained ImageNet weights).
def create_mobilenet_model(num_classes: int = 4,
                           seed: int = 42,
                           pretrained: bool = True):
    """Create a MobileNetV2 feature-extractor model and its image transforms.

    Every parameter of the loaded MobileNetV2 is frozen
    (``requires_grad = False``), and its classifier head is replaced by a
    freshly initialised ``Dropout -> Linear`` head with ``num_classes``
    outputs. Only that new head is trainable.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 4.
        seed (int, optional): value passed to ``torch.manual_seed`` right before
            the new head is created, so its initial weights are reproducible.
            Note this reseeds torch's global RNG. Defaults to 42.
        pretrained (bool, optional): if True, load torchvision's
            ``IMAGENET1K_V1`` weights (the weights the legacy
            ``pretrained=True`` flag selected). Pass False to skip the
            download, e.g. when a trained ``state_dict`` is loaded afterwards.
            Defaults to True.

    Returns:
        tuple: ``(model, transform)``.
            ``model`` (torch.nn.Module) is the MobileNetV2 feature extractor
            with the new head. ``transform``
            (torchvision.transforms.Compose) resizes images to 224x224,
            converts them to tensors in [0, 1], and normalises each channel
            with the standard ImageNet mean/std values.
    """
    # Preprocessing pipeline applied to input images.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # fixed 224x224 input size
        transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # per-channel mean
                             std=[0.229, 0.224, 0.225]),  # per-channel std
    ])

    # `weights=` replaces the `pretrained=` flag deprecated in torchvision 0.13;
    # "IMAGENET1K_V1" is exactly what `pretrained=True` used to load.
    model = mobilenet_v2(weights="IMAGENET1K_V1" if pretrained else None)

    # Freeze all existing parameters; the head created below is added
    # afterwards, so it keeps the default requires_grad=True.
    for param in model.parameters():
        param.requires_grad = False

    # Seed with the caller-supplied value (not a hard-coded constant) so the
    # random initialisation of the new Linear layer is reproducible.
    torch.manual_seed(seed)

    # Read the input width from the existing Linear layer (classifier[1])
    # before the classifier is overwritten.
    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model, transform