AJain1234 commited on
Commit
d26706f
·
verified ·
1 Parent(s): bee0f52

Update Experiments/Resnet50_classification.py

Browse files
Experiments/Resnet50_classification.py CHANGED
@@ -91,7 +91,7 @@ def predict(features_path,image):
91
  image_tensor = transform(pil_image).unsqueeze(0)
92
  resnet = models.resnet50(pretrained=True)
93
  model_check = HiddenLayer(resnet)
94
- model_check.load_state_dict(torch.load("CIFAR_end_hll.pt"))
95
  model_check.eval()
96
 
97
  with torch.no_grad():
 
91
  image_tensor = transform(pil_image).unsqueeze(0)
92
  resnet = models.resnet50(pretrained=True)
93
  model_check = HiddenLayer(resnet)
94
+ model_check.load_state_dict(torch.load("CIFAR_end_hll.pt",map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
95
  model_check.eval()
96
 
97
  with torch.no_grad():