NizamuddinMandekar commited on
Commit
51dd661
·
verified ·
1 Parent(s): 56e5bb5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -16
model.py CHANGED
@@ -2,24 +2,26 @@ import torch
2
  import torchvision
3
  from torch import nn
4
 
5
- def create_model(num_classes: int = 2, seed: int = 42):
6
- weights = torchvision.models.ResNet50_Weights.DEFAULT
7
- transforms = weights.transforms()
8
- model = torchvision.models.resnet50(weights=weights)
 
9
 
10
- # Freeze all base model parameters
11
- for param in model.parameters():
12
- param.requires_grad = False
13
 
14
- torch.manual_seed(seed)
 
 
 
 
 
 
 
 
15
 
16
- # Replace the final classifier layer with a custom head
17
- model.fc = nn.Sequential(
18
- nn.Linear(2048, 1000),
19
- nn.ReLU(),
20
- nn.Linear(1000, 500),
21
- nn.Dropout(),
22
- nn.Linear(in_features=500, out_features=num_classes, bias=True)
23
  )
24
 
25
- return model, transforms
 
 
2
  import torchvision
3
  from torch import nn
4
 
5
+ def create_model(num_classes:int=2,
6
+ seed:int=42):
7
+ weights=torchvision.models.ResNet50_Weights.DEFAULT
8
+ transforms=weights.transforms()
9
+ model=torchvision.models.resnet50(weights=weights)
10
 
11
+ for param in model.parameters():
12
+ param.requires_grad=False
 
13
 
14
+ torch.manual_seed(42)
15
+ model.fc= torch.nn.Sequential(
16
+ torch.nn.Linear(2048,1000),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Linear(1000,500),
19
+ torch.nn.Dropout(),
20
+ torch.nn.Linear(in_features=500,
21
+ out_features=num_classes,
22
+ bias=True)
23
 
 
 
 
 
 
 
 
24
  )
25
 
26
+ return model,transforms
27
+