i4ata commited on
Commit
21ed674
·
1 Parent(s): 5feebb1

small fix

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -19,6 +19,7 @@ class GradioApp:
19
  custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
20
 
21
  pretrained = models.vit_b_16().to(device).eval()
 
22
  pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
23
 
24
  self.models: Dict[str, Union[str, nn.Module]] = {
 
19
  custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
20
 
21
  pretrained = models.vit_b_16().to(device).eval()
22
+ pretrained.heads = nn.Linear(768, 3)
23
  pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
24
 
25
  self.models: Dict[str, Union[str, nn.Module]] = {