Spaces:
Runtime error
Runtime error
Update Experiments/Resnet50_classification.py
Browse files
Experiments/Resnet50_classification.py
CHANGED
|
@@ -91,7 +91,7 @@ def predict(features_path,image):
|
|
| 91 |
image_tensor = transform(pil_image).unsqueeze(0)
|
| 92 |
resnet = models.resnet50(pretrained=True)
|
| 93 |
model_check = HiddenLayer(resnet)
|
| 94 |
-
model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
|
| 95 |
model_check.eval()
|
| 96 |
|
| 97 |
with torch.no_grad():
|
|
|
|
| 91 |
image_tensor = transform(pil_image).unsqueeze(0)
|
| 92 |
resnet = models.resnet50(pretrained=True)
|
| 93 |
model_check = HiddenLayer(resnet)
|
| 94 |
+
model_check.load_state_dict(torch.load("CIFAR_end_hll.pt",map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
|
| 95 |
model_check.eval()
|
| 96 |
|
| 97 |
with torch.no_grad():
|