geitta commited on
Commit
0434edb
·
verified ·
1 Parent(s): 22bb2bd
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -11,7 +11,7 @@ def create_vit_model(num_classes: int = 5):
11
  for param in model.parameters():
12
  param.requires_grad = False
13
  # change the classifier head to suit our problem
14
- vit.heads = nn.Sequential(nn.Linear(in_features=768,
15
  out_features=5,
16
  bias=True))
17
  return model, transforms
 
11
  for param in model.parameters():
12
  param.requires_grad = False
13
  # change the classifier head to suit our problem
14
+ model.heads = nn.Sequential(nn.Linear(in_features=768,
15
  out_features=5,
16
  bias=True))
17
  return model, transforms