Anshu13 commited on
Commit
9b828a2
·
verified ·
1 Parent(s): f7ff479

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +16 -16
model.py CHANGED
@@ -1,16 +1,16 @@
1
- import torchvision
2
- import torch
3
- from torchvision import transforms
4
- from torch import nn
5
- def create_model(num_of_classes:int=3):
6
- weights=torchvision.models.MobileNet_V3_Large_Weights.DEFAULT
7
- transform=weights.transforms()
8
- model=torchvision.models.mobilenet_v3_large(weights=weights)
9
- for parameter in model.parameters():
10
- parameter.requires_grad=False
11
- for parameter in model.classifier[-4:].parameters():
12
- parameter.requires_grad=True
13
- for parameter in model.features[-6:].parameters():
14
- parameter.requires_grad=True
15
- model.classifier[3]=nn.Sequential(nn.Linear(1280,1000),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(1000,num_of_classes))
16
- return model,transform
 
1
+ import torchvision
2
+ import torch
3
+ from torchvision import transforms
4
+ from torch import nn
5
+ def create_model(num_of_classes:int=3):
6
+ weights=torchvision.models.MobileNet_V3_Large_Weights.DEFAULT
7
+ transform=weights.transforms()
8
+ model=torchvision.models.mobilenet_v3_large(weights=weights)
9
+ for parameter in model.parameters():
10
+ parameter.requires_grad=False
11
+ for parameter in model.classifier[-4:].parameters():
12
+ parameter.requires_grad=True
13
+ for parameter in model.features[-2:].parameters():
14
+ parameter.requires_grad=True
15
+ model.classifier[3]=nn.Sequential(nn.Linear(1280,1000),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(1000,num_of_classes))
16
+ return model,transform