xinjjj commited on
Commit
96809f7
·
1 Parent(s): 6f48d96
common.py CHANGED
@@ -14,15 +14,23 @@
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
17
- import spaces
18
- from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
 
 
 
19
 
20
- monkey_path_trellis()
21
  from embodied_gen.utils.monkey_patch.gradio import (
22
- _patch_open3d_cuda_device_count_bug,
 
 
23
  )
 
24
 
25
- _patch_open3d_cuda_device_count_bug()
 
 
 
26
 
27
  import gc
28
  import logging
@@ -161,7 +169,7 @@ def end_session(req: gr.Request) -> None:
161
  shutil.rmtree(user_dir)
162
 
163
 
164
- @spaces.GPU
165
  def preprocess_image_fn(
166
  image: str | np.ndarray | Image.Image,
167
  rmbg_tag: str = "rembg",
@@ -280,7 +288,7 @@ def select_point(
280
  return (image, masks), seg_image
281
 
282
 
283
- @spaces.GPU
284
  def image_to_3d(
285
  image: Image.Image,
286
  seed: int,
@@ -581,7 +589,7 @@ def extract_urdf(
581
  )
582
 
583
 
584
- @spaces.GPU
585
  def text2image_fn(
586
  prompt: str,
587
  guidance_scale: float,
@@ -637,7 +645,7 @@ def text2image_fn(
637
  return save_paths + save_paths
638
 
639
 
640
- @spaces.GPU
641
  def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
642
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
643
 
@@ -653,7 +661,7 @@ def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
653
  return None, None, None
654
 
655
 
656
- @spaces.GPU
657
  def generate_texture_mvimages(
658
  prompt: str,
659
  controlnet_cond_scale: float = 0.55,
@@ -740,7 +748,7 @@ def backproject_texture(
740
  return output_glb_mesh, output_obj_mesh, zip_file
741
 
742
 
743
- @spaces.GPU
744
  def backproject_texture_v2(
745
  mesh_path: str,
746
  input_image: str,
@@ -787,7 +795,7 @@ def backproject_texture_v2(
787
  return output_glb_mesh, output_obj_mesh, zip_file
788
 
789
 
790
- @spaces.GPU
791
  def render_result_video(
792
  mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
793
  ) -> str:
 
14
  # implied. See the License for the specific language governing
15
  # permissions and limitations under the License.
16
 
17
+ # from embodied_gen.utils.monkey_patch.gradio import _patch_spaces_zerogpu_logs
18
+
19
+ # _patch_spaces_zerogpu_logs()
20
+
21
+ import spaces # noqa: E402
22
 
 
23
  from embodied_gen.utils.monkey_patch.gradio import (
24
+ _disable_xformers_flash3,
25
+ # _neutralize_warp_in_parent,
26
+ # _patch_open3d_cuda_device_count_bug,
27
  )
28
+ from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
29
 
30
+ # _neutralize_warp_in_parent()
31
+ # _patch_open3d_cuda_device_count_bug()
32
+ _disable_xformers_flash3()
33
+ monkey_path_trellis()
34
 
35
  import gc
36
  import logging
 
169
  shutil.rmtree(user_dir)
170
 
171
 
172
+ @spaces.GPU(duration=180)
173
  def preprocess_image_fn(
174
  image: str | np.ndarray | Image.Image,
175
  rmbg_tag: str = "rembg",
 
288
  return (image, masks), seg_image
289
 
290
 
291
+ @spaces.GPU(duration=180)
292
  def image_to_3d(
293
  image: Image.Image,
294
  seed: int,
 
589
  )
590
 
591
 
592
+ @spaces.GPU(duration=180)
593
  def text2image_fn(
594
  prompt: str,
595
  guidance_scale: float,
 
645
  return save_paths + save_paths
646
 
647
 
648
+ @spaces.GPU(duration=180)
649
  def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
650
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
651
 
 
661
  return None, None, None
662
 
663
 
664
+ @spaces.GPU(duration=180)
665
  def generate_texture_mvimages(
666
  prompt: str,
667
  controlnet_cond_scale: float = 0.55,
 
748
  return output_glb_mesh, output_obj_mesh, zip_file
749
 
750
 
751
+ @spaces.GPU(duration=180)
752
  def backproject_texture_v2(
753
  mesh_path: str,
754
  input_image: str,
 
795
  return output_glb_mesh, output_obj_mesh, zip_file
796
 
797
 
798
+ @spaces.GPU(duration=180)
799
  def render_result_video(
800
  mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
801
  ) -> str:
embodied_gen/utils/monkey_patch/gradio.py CHANGED
@@ -56,3 +56,83 @@ def _patch_open3d_cuda_device_count_bug() -> None:
56
  ),
57
  end='',
58
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ),
57
  end='',
58
  )
59
+
60
+
61
+ def _neutralize_warp_in_parent() -> None:
62
+ """Prevent NVIDIA Warp from calling cuInit() in the ZeroGPU parent.
63
+
64
+ Root cause of @spaces.GPU silent hangs (spaces>=0.50): kaolin imports
65
+ warp at module top-level. When any kaolin module triggers warp.init(),
66
+ Warp's `init_cuda_driver` dlopens libcuda.so + calls cuInit() in the
67
+ parent process. After spaces forks the worker, torch.init(nvidia_uuid)
68
+ in the worker hangs forever because the inherited CUDA driver state is
69
+ poisoned (parent never had a real GPU; ZeroGPU exposes one only post-fork).
70
+
71
+ Fix: stub warp.init / warp.context.runtime_init with a pid-aware no-op.
72
+ The parent-resident pid skips init; the forked worker (different pid)
73
+ runs the real init so warp keeps working inside @spaces.GPU code paths.
74
+
75
+ Must be called BEFORE any import that pulls kaolin (e.g. embodied_gen.data,
76
+ thirdparty.TRELLIS).
77
+ """
78
+ import os
79
+ import sys
80
+
81
+ try:
82
+ import warp # noqa: F401 -- pure python import, no cuInit
83
+ except ImportError:
84
+ return
85
+
86
+ parent_pid = os.getpid()
87
+
88
+ def _make_pid_safe(orig):
89
+ def _wrapped(*args, **kwargs):
90
+ if os.getpid() == parent_pid:
91
+ sys.stderr.write(
92
+ f"[warp-neutralize] skip {orig.__name__} in parent pid={parent_pid}\n"
93
+ )
94
+ sys.stderr.flush()
95
+ return None
96
+ return orig(*args, **kwargs)
97
+ _wrapped.__wrapped__ = orig
98
+ _wrapped.__name__ = getattr(orig, "__name__", "wrapped")
99
+ return _wrapped
100
+
101
+ if hasattr(warp, "init") and not hasattr(warp.init, "__wrapped__"):
102
+ warp.init = _make_pid_safe(warp.init)
103
+
104
+ try:
105
+ from warp import context as _wctx
106
+ if hasattr(_wctx, "runtime_init") and not hasattr(
107
+ _wctx.runtime_init, "__wrapped__"
108
+ ):
109
+ _wctx.runtime_init = _make_pid_safe(_wctx.runtime_init)
110
+ except Exception:
111
+ pass
112
+
113
+
114
+ def _disable_xformers_flash3() -> None:
115
+ """Force xformers dispatcher to skip Flash-Attention v3 (Hopper-only).
116
+
117
+ sm_120 (Blackwell) has no FA3 kernel binary; the dispatcher still picks
118
+ flash3 and the launch aborts with:
119
+ `CUDA error ... hopper/flash_fwd_launch_template.h:188: invalid argument`
120
+ Env vars `XFORMERS_FLASH3_ATTENTION_DISABLED=1` are silently ignored in
121
+ xformers 0.0.32.post2, so we patch `not_supported_reasons` directly.
122
+ Cutlass and FA2 both work on sm_120, so removing flash3 from candidates
123
+ is enough.
124
+ """
125
+ try:
126
+ from xformers.ops.fmha import flash3 as _f3
127
+ except Exception:
128
+ return
129
+
130
+ _disabled = ["disabled by EmbodiedGen: no FA3 kernel for sm_120"]
131
+
132
+ def _ns(cls, d): # noqa: ARG001
133
+ return list(_disabled)
134
+
135
+ if hasattr(_f3, "FwOp"):
136
+ _f3.FwOp.not_supported_reasons = classmethod(_ns)
137
+ if hasattr(_f3, "BwOp"):
138
+ _f3.BwOp.not_supported_reasons = classmethod(_ns)