jboth commited on
Commit
ccebd17
·
verified ·
1 Parent(s): 37b8f65

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -76,6 +76,42 @@ patch = SAM3D_PATH / "patching" / "hydra"
76
  if patch.exists():
77
  subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  sys.path.insert(0, str(SAM3D_PATH))
80
  sys.path.insert(0, str(SAM3D_PATH / "notebook"))
81
 
@@ -291,15 +327,6 @@ def test_sam3d_only(image: np.ndarray):
291
  print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
292
 
293
  from inference import Inference
294
- # SAM3D's inference_pipeline.py auto-detects H200 and sets ATTN_BACKEND=flash_attn
295
- # We must override BACK to sdpa since flash_attn is not available
296
- import sam3d_objects.model.backbone.tdfy_dit.modules.attention as _attn_mod
297
- import sam3d_objects.model.backbone.tdfy_dit.modules.sparse as _sparse_mod
298
- _attn_mod.BACKEND = "sdpa"
299
- _sparse_mod.ATTN = "sdpa"
300
- os.environ["ATTN_BACKEND"] = "sdpa"
301
- os.environ["SPARSE_ATTN_BACKEND"] = "sdpa"
302
- print(f" Attention backends forced to sdpa")
303
  print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
304
  sam3d = Inference(CONFIG_PATH, compile=False)
305
  print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
 
76
  if patch.exists():
77
  subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
78
 
79
+ # CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
80
+ # inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
81
+ # but we don't have the real flash_attn package.
82
+ ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
83
+ if ip_file.exists():
84
+ ip_src = ip_file.read_text()
85
+ # Replace the set_attention_backend function to respect our env vars
86
+ old_fn = """def set_attention_backend():
87
+ if torch.cuda.is_available():
88
+ gpu_name = torch.cuda.get_device_name(0)
89
+ else:
90
+ gpu_name = "CPU"
91
+
92
+ logger.info(f"GPU name is {gpu_name}")
93
+ if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:
94
+ # logger.info("Use flash_attn")
95
+ os.environ["ATTN_BACKEND"] = "flash_attn"
96
+ os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn""""
97
+ new_fn = """def set_attention_backend():
98
+ if torch.cuda.is_available():
99
+ gpu_name = torch.cuda.get_device_name(0)
100
+ else:
101
+ gpu_name = "CPU"
102
+
103
+ logger.info(f"GPU name is {gpu_name}")
104
+ # PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)
105
+ logger.info("Using sdpa backend (patched for ZeroGPU)")
106
+ os.environ.setdefault("ATTN_BACKEND", "sdpa")
107
+ os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")""""
108
+ if old_fn in ip_src:
109
+ ip_src = ip_src.replace(old_fn, new_fn)
110
+ ip_file.write_text(ip_src)
111
+ print("PATCHED: inference_pipeline.py - forced sdpa backend")
112
+ else:
113
+ print("WARNING: Could not patch inference_pipeline.py")
114
+
115
  sys.path.insert(0, str(SAM3D_PATH))
116
  sys.path.insert(0, str(SAM3D_PATH / "notebook"))
117
 
 
327
  print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
328
 
329
  from inference import Inference
 
 
 
 
 
 
 
 
 
330
  print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
331
  sam3d = Inference(CONFIG_PATH, compile=False)
332
  print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")