Spaces:
Paused
Paused
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -76,6 +76,42 @@ patch = SAM3D_PATH / "patching" / "hydra"
|
|
| 76 |
if patch.exists():
|
| 77 |
subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
sys.path.insert(0, str(SAM3D_PATH))
|
| 80 |
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
|
| 81 |
|
|
@@ -291,15 +327,6 @@ def test_sam3d_only(image: np.ndarray):
|
|
| 291 |
print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
|
| 292 |
|
| 293 |
from inference import Inference
|
| 294 |
-
# SAM3D's inference_pipeline.py auto-detects H200 and sets ATTN_BACKEND=flash_attn
|
| 295 |
-
# We must override BACK to sdpa since flash_attn is not available
|
| 296 |
-
import sam3d_objects.model.backbone.tdfy_dit.modules.attention as _attn_mod
|
| 297 |
-
import sam3d_objects.model.backbone.tdfy_dit.modules.sparse as _sparse_mod
|
| 298 |
-
_attn_mod.BACKEND = "sdpa"
|
| 299 |
-
_sparse_mod.ATTN = "sdpa"
|
| 300 |
-
os.environ["ATTN_BACKEND"] = "sdpa"
|
| 301 |
-
os.environ["SPARSE_ATTN_BACKEND"] = "sdpa"
|
| 302 |
-
print(f" Attention backends forced to sdpa")
|
| 303 |
print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
|
| 304 |
sam3d = Inference(CONFIG_PATH, compile=False)
|
| 305 |
print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
|
|
|
|
| 76 |
if patch.exists():
|
| 77 |
subprocess.run(["bash", str(patch)], capture_output=True, cwd=str(SAM3D_PATH))
|
| 78 |
|
| 79 |
+
# CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
|
| 80 |
+
# inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
|
| 81 |
+
# but we don't have the real flash_attn package.
|
| 82 |
+
ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
|
| 83 |
+
if ip_file.exists():
|
| 84 |
+
ip_src = ip_file.read_text()
|
| 85 |
+
# Replace the set_attention_backend function to respect our env vars
|
| 86 |
+
old_fn = """def set_attention_backend():
|
| 87 |
+
if torch.cuda.is_available():
|
| 88 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 89 |
+
else:
|
| 90 |
+
gpu_name = "CPU"
|
| 91 |
+
|
| 92 |
+
logger.info(f"GPU name is {gpu_name}")
|
| 93 |
+
if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:
|
| 94 |
+
# logger.info("Use flash_attn")
|
| 95 |
+
os.environ["ATTN_BACKEND"] = "flash_attn"
|
| 96 |
+
os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn""""
|
| 97 |
+
new_fn = """def set_attention_backend():
|
| 98 |
+
if torch.cuda.is_available():
|
| 99 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 100 |
+
else:
|
| 101 |
+
gpu_name = "CPU"
|
| 102 |
+
|
| 103 |
+
logger.info(f"GPU name is {gpu_name}")
|
| 104 |
+
# PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)
|
| 105 |
+
logger.info("Using sdpa backend (patched for ZeroGPU)")
|
| 106 |
+
os.environ.setdefault("ATTN_BACKEND", "sdpa")
|
| 107 |
+
os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")""""
|
| 108 |
+
if old_fn in ip_src:
|
| 109 |
+
ip_src = ip_src.replace(old_fn, new_fn)
|
| 110 |
+
ip_file.write_text(ip_src)
|
| 111 |
+
print("PATCHED: inference_pipeline.py - forced sdpa backend")
|
| 112 |
+
else:
|
| 113 |
+
print("WARNING: Could not patch inference_pipeline.py")
|
| 114 |
+
|
| 115 |
sys.path.insert(0, str(SAM3D_PATH))
|
| 116 |
sys.path.insert(0, str(SAM3D_PATH / "notebook"))
|
| 117 |
|
|
|
|
| 327 |
print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
|
| 328 |
|
| 329 |
from inference import Inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
print(f" Loading SAM3D... VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")
|
| 331 |
sam3d = Inference(CONFIG_PATH, compile=False)
|
| 332 |
print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB)")
|