AI Agent commited on
Commit
09cf416
·
1 Parent(s): f6409a9

Monkey-patch torch to prevent hardcoded CUDA crashes on CPU instances

Browse files
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -10,10 +10,28 @@ import io
10
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
11
  from huggingface_hub import hf_hub_download
12
 
 
13
  print("Downloading SAM 3 model...")
14
  hf_token = os.environ.get("HF_TOKEN")
15
  ckpt_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", token=hf_token)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ── SAM 3 Imports ────────────────────────────────────────────────
18
  try:
19
  from sam3.model_builder import build_sam3_image_model
 
10
  # (HuggingFace Spaces can use the hf_hub_download mechanism)
11
  from huggingface_hub import hf_hub_download
12
 
13
+ # ── HF Token Authentication ────────────────────────────────────────
14
  print("Downloading SAM 3 model...")
15
  hf_token = os.environ.get("HF_TOKEN")
16
  ckpt_path = hf_hub_download(repo_id="facebook/sam3", filename="sam3.pt", token=hf_token)
17
 
18
+ # ── Monkey Patch SAM 3 CUDA Hardcoding Bug ───────────────────────
19
+ # Meta's SAM 3 repo hardcodes `device="cuda"` in position_encoding.py
20
+ # This intercepts torch.zeros to force "cpu" if no GPU is available.
21
+ original_zeros = torch.zeros
22
+ def patched_zeros(*args, **kwargs):
23
+ if kwargs.get('device') == 'cuda' and not torch.cuda.is_available():
24
+ kwargs['device'] = 'cpu'
25
+ return original_zeros(*args, **kwargs)
26
+ torch.zeros = patched_zeros
27
+
28
+ original_arange = torch.arange
29
+ def patched_arange(*args, **kwargs):
30
+ if kwargs.get('device') == 'cuda' and not torch.cuda.is_available():
31
+ kwargs['device'] = 'cpu'
32
+ return original_arange(*args, **kwargs)
33
+ torch.arange = patched_arange
34
+
35
  # ── SAM 3 Imports ────────────────────────────────────────────────
36
  try:
37
  from sam3.model_builder import build_sam3_image_model