Spaces:
Sleeping
Sleeping
change map location to CPU
Browse files- inference_sam.py +12 -1
inference_sam.py
CHANGED
|
@@ -33,11 +33,22 @@ if not os.path.exists('model'):
|
|
| 33 |
print("warning! A read token in env variables is needed for authentication.")
|
| 34 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
|
| 37 |
-
sam = sam_model_registry["default"](model_path
|
| 38 |
sam.to(device) #sam.cuda()
|
| 39 |
predictor = SamPredictor(sam)
|
| 40 |
|
|
|
|
|
|
|
| 41 |
|
| 42 |
from torch.nn import functional as F
|
| 43 |
|
|
|
|
| 33 |
print("warning! A read token in env variables is needed for authentication.")
|
| 34 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
|
| 35 |
|
| 36 |
+
|
| 37 |
+
original_torch_load = torch.load
|
| 38 |
+
|
| 39 |
+
def patched_torch_load(*args, **kwargs):
|
| 40 |
+
kwargs['map_location'] = device
|
| 41 |
+
return original_torch_load(*args, **kwargs)
|
| 42 |
+
|
| 43 |
+
torch.load = patched_torch_load
|
| 44 |
+
|
| 45 |
model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
|
| 46 |
+
sam = sam_model_registry["default"](model_path)
|
| 47 |
sam.to(device) #sam.cuda()
|
| 48 |
predictor = SamPredictor(sam)
|
| 49 |
|
| 50 |
+
torch.load = original_torch_load
|
| 51 |
+
|
| 52 |
|
| 53 |
from torch.nn import functional as F
|
| 54 |
|