Munzali commited on
Commit
a2744f3
·
verified ·
1 Parent(s): 6591c32

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +49 -0
model.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ import torch.nn as nn
5
+ from torchvision.models import mobilenet_v2
6
+
7
+ # Load MobileNetV2 with pre-trained weights
8
+
9
+
10
+ def create_effnetb2_model(num_classes:int=4,
11
+ seed:int=42):
12
+ """Creates an EfficientNetB2 feature extractor model and transforms.
13
+ Args:
14
+ num_classes (int, optional): number of classes in the classifier head.
15
+ Defaults to 3.
16
+ seed (int, optional): random seed value. Defaults to 42.
17
+ Returns:
18
+ model (torch.nn.Module): EffNetB2 feature extractor model.
19
+ transforms (torchvision.transforms): EffNetB2 image transforms.
20
+ """
21
+ # Create EffNetB2 pretrained weights, transforms and model
22
+
23
+ transforms = transforms.Compose([
24
+ transforms.Resize((224, 224)), # 1. Reshape all images to 224x224 (though some models may require different sizes)
25
+ transforms.ToTensor(), # 2. Turn image values to between 0 & 1
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], # 3. A mean of [0.485, 0.456, 0.406] (across each colour channel)
27
+ std=[0.229, 0.224, 0.225]) # 4. A standard deviation of [0.229, 0.224, 0.225] (across each colour channel),
28
+ ])
29
+ model = mobilenet_v2(pretrained=True)
30
+
31
+ # Freeze all layers in base model
32
+ # Freeze all base layers by setting requires_grad attribute to False
33
+ for param in model.parameters():
34
+ param.requires_grad = False
35
+
36
+ # Since we're creating a new layer with random weights (torch.nn.Linear),
37
+ # let's set the seeds
38
+ torch.manual_seed(42)
39
+
40
+ # Update the classifier head to suit our problem
41
+ model.classifier = nn.Sequential(
42
+ nn.Dropout(p=0.2, inplace=True),
43
+ nn.Linear(in_features=model.classifier[1].in_features, # Accessing the last layer of the classifier
44
+ out_features=num_classes,
45
+ bias=True)
46
+ )
47
+
48
+ return model, transforms
49
+