Anigor66 commited on
Commit
3ba0b94
·
1 Parent(s): 58f0767

Fix: Load model with CPU mapping for HuggingFace Space

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -22,11 +22,25 @@ MODEL_CHECKPOINT = "medsam_vit_b.pth"
22
  MODEL_TYPE = "vit_b"
23
 
24
  print("Loading MedSAM model...")
25
- sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
26
- sam.to(device=device)
27
- sam.eval()
28
- predictor = SamPredictor(sam)
29
- print("✓ MedSAM model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
  def segment_with_points(image, points_json):
 
22
  MODEL_TYPE = "vit_b"
23
 
24
  print("Loading MedSAM model...")
25
+
26
+ # Monkey-patch torch.load to use CPU mapping when needed
27
+ original_torch_load = torch.load
28
+ def patched_torch_load(f, *args, **kwargs):
29
+ if 'map_location' not in kwargs and device == 'cpu':
30
+ kwargs['map_location'] = 'cpu'
31
+ return original_torch_load(f, *args, **kwargs)
32
+
33
+ torch.load = patched_torch_load
34
+
35
+ try:
36
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=MODEL_CHECKPOINT)
37
+ sam.to(device=device)
38
+ sam.eval()
39
+ predictor = SamPredictor(sam)
40
+ print("✓ MedSAM model loaded successfully!")
41
+ finally:
42
+ # Restore original torch.load
43
+ torch.load = original_torch_load
44
 
45
 
46
  def segment_with_points(image, points_json):