Shekarss commited on
Commit
6eb52e0
·
verified ·
1 Parent(s): 5d3d2ca

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -7,7 +7,7 @@ def create_effnetb2_model(num_classes: int = 3,
7
  seed: int = 42):
8
  weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
9
  transforms = weights.transforms()
10
- model = torchvision.models.efficientnet_b2(weights=weights).to(device)
11
 
12
  for param in model.parameters():
13
  param.requires_grad = False
@@ -17,5 +17,5 @@ def create_effnetb2_model(num_classes: int = 3,
17
  model.classifier = nn.Sequential(
18
  nn.Dropout(p=0.3, inplace=True),
19
  nn.Linear(in_features=1408, out_features=num_classes, bias=True)
20
- ).to(device)
21
  return model, transforms
 
7
  seed: int = 42):
8
  weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
9
  transforms = weights.transforms()
10
+ model = torchvision.models.efficientnet_b2(weights=weights)
11
 
12
  for param in model.parameters():
13
  param.requires_grad = False
 
17
  model.classifier = nn.Sequential(
18
  nn.Dropout(p=0.3, inplace=True),
19
  nn.Linear(in_features=1408, out_features=num_classes, bias=True)
20
+ )
21
  return model, transforms