Alessandro Goller commited on
Commit
375095a
·
1 Parent(s): 58041f2

Update model

Browse files
Files changed (1) hide show
  1. model.py +7 -0
model.py CHANGED
@@ -3,6 +3,13 @@ import torchvision
3
 
4
  from torch import nn
5
 
 
 
 
 
 
 
 
6
 
7
  def create_model(num_classes:int=3,
8
  seed:int=42):
 
3
 
4
  from torch import nn
5
 
6
+ from torchvision.models._api import WeightsEnum
7
+ from torch.hub import load_state_dict_from_url
8
+
9
+ def get_state_dict(self, *args, **kwargs):
10
+ kwargs.pop("check_hash")
11
+ return load_state_dict_from_url(self.url, *args, **kwargs)
12
+ WeightsEnum.get_state_dict = get_state_dict
13
 
14
  def create_model(num_classes:int=3,
15
  seed:int=42):