Spaces:
Build error
Build error
dvtiendat commited on
Commit ·
e55321c
1
Parent(s): 379a299
cpu fixed
Browse files- pipeline.py +2 -2
pipeline.py
CHANGED
|
@@ -23,11 +23,11 @@ class Pipeline:
|
|
| 23 |
|
| 24 |
def _load_models(self):
|
| 25 |
classification_model = resnet_model
|
| 26 |
-
classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt'))
|
| 27 |
classification_model.eval()
|
| 28 |
|
| 29 |
segmentation_model = ResNetUnet()
|
| 30 |
-
checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt')
|
| 31 |
segmentation_model.load_state_dict(checkpoint['model_state_dict'])
|
| 32 |
segmentation_model.eval()
|
| 33 |
|
|
|
|
| 23 |
|
| 24 |
def _load_models(self):
|
| 25 |
classification_model = resnet_model
|
| 26 |
+
classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt', map_location=torch.device('cpu')))
|
| 27 |
classification_model.eval()
|
| 28 |
|
| 29 |
segmentation_model = ResNetUnet()
|
| 30 |
+
checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt', map_location=torch.device('cpu'))
|
| 31 |
segmentation_model.load_state_dict(checkpoint['model_state_dict'])
|
| 32 |
segmentation_model.eval()
|
| 33 |
|