dvtiendat commited on
Commit
e55321c
·
1 Parent(s): 379a299

cpu fixed

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -23,11 +23,11 @@ class Pipeline:
23
 
24
  def _load_models(self):
25
  classification_model = resnet_model
26
- classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt'))
27
  classification_model.eval()
28
 
29
  segmentation_model = ResNetUnet()
30
- checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt')
31
  segmentation_model.load_state_dict(checkpoint['model_state_dict'])
32
  segmentation_model.eval()
33
 
 
23
 
24
  def _load_models(self):
25
  classification_model = resnet_model
26
+ classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt', map_location=torch.device('cpu')))
27
  classification_model.eval()
28
 
29
  segmentation_model = ResNetUnet()
30
+ checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt', map_location=torch.device('cpu'))
31
  segmentation_model.load_state_dict(checkpoint['model_state_dict'])
32
  segmentation_model.eval()
33