jboth commited on
Commit
7b8ab13
·
verified ·
1 Parent(s): 8707d60

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -82,35 +82,24 @@ if patch.exists():
82
  ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
83
  if ip_file.exists():
84
  ip_src = ip_file.read_text()
85
- # Replace the set_attention_backend function to respect our env vars
86
- old_fn = """def set_attention_backend():
87
- if torch.cuda.is_available():
88
- gpu_name = torch.cuda.get_device_name(0)
89
- else:
90
- gpu_name = "CPU"
91
-
92
- logger.info(f"GPU name is {gpu_name}")
93
- if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:
94
- # logger.info("Use flash_attn")
95
- os.environ["ATTN_BACKEND"] = "flash_attn"
96
- os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn""""
97
- new_fn = """def set_attention_backend():
98
- if torch.cuda.is_available():
99
- gpu_name = torch.cuda.get_device_name(0)
100
- else:
101
- gpu_name = "CPU"
102
-
103
- logger.info(f"GPU name is {gpu_name}")
104
- # PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)
105
- logger.info("Using sdpa backend (patched for ZeroGPU)")
106
- os.environ.setdefault("ATTN_BACKEND", "sdpa")
107
- os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")""""
108
- if old_fn in ip_src:
109
- ip_src = ip_src.replace(old_fn, new_fn)
110
  ip_file.write_text(ip_src)
111
  print("PATCHED: inference_pipeline.py - forced sdpa backend")
112
  else:
113
- print("WARNING: Could not patch inference_pipeline.py")
114
 
115
  sys.path.insert(0, str(SAM3D_PATH))
116
  sys.path.insert(0, str(SAM3D_PATH / "notebook"))
 
82
  ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
83
  if ip_file.exists():
84
  ip_src = ip_file.read_text()
85
+ # Find and replace the set_attention_backend function
86
+ old_marker = 'os.environ["ATTN_BACKEND"] = "flash_attn"'
87
+ if old_marker in ip_src:
88
+ # Replace the entire if-block that forces flash_attn
89
+ ip_src = ip_src.replace(
90
+ 'if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name:\n'
91
+ ' # logger.info("Use flash_attn")\n'
92
+ ' os.environ["ATTN_BACKEND"] = "flash_attn"\n'
93
+ ' os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"',
94
+ '# PATCHED: Always use sdpa backend (flash_attn not available on ZeroGPU)\n'
95
+ ' logger.info("Using sdpa backend (patched for ZeroGPU)")\n'
96
+ ' os.environ.setdefault("ATTN_BACKEND", "sdpa")\n'
97
+ ' os.environ.setdefault("SPARSE_ATTN_BACKEND", "sdpa")'
98
+ )
 
 
 
 
 
 
 
 
 
 
 
99
  ip_file.write_text(ip_src)
100
  print("PATCHED: inference_pipeline.py - forced sdpa backend")
101
  else:
102
+ print("INFO: inference_pipeline.py already patched or different version")
103
 
104
  sys.path.insert(0, str(SAM3D_PATH))
105
  sys.path.insert(0, str(SAM3D_PATH / "notebook"))