File size: 4,818 Bytes
7d21266 d7b1c6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | # Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import fileinput
import site
import gradio_client.utils as gradio_client_utils
def _patch_gradio_schema_bool_bug() -> None:
    """Make gradio_client's JSON-schema helpers accept boolean schemas.

    Intended for gradio<5.33. JSON Schema allows a bare ``true``/``false``
    in place of a schema object; this wraps two helpers in
    ``gradio_client.utils`` so that a boolean ``schema`` short-circuits:

    * ``get_type(schema)`` returns ``{}``;
    * ``_json_schema_to_python_type(schema, defs)`` returns ``"Any"``.

    Non-boolean schemas are forwarded to the original functions unchanged.
    The patch is idempotent: calling this again leaves already-wrapped
    helpers in place instead of stacking another wrapper layer.
    """
    import functools

    # Attribute set on our wrappers so repeated calls can detect them.
    marker = "_embodiedgen_bool_safe"

    # NOTE(review): the wrappers replace the *module attributes*; presumably
    # gradio_client's own recursive calls look these helpers up through the
    # module namespace and are therefore covered too — verify against the
    # installed gradio_client version.
    original_get_type = gradio_client_utils.get_type
    if not getattr(original_get_type, marker, False):

        @functools.wraps(original_get_type)
        def _safe_get_type(schema):
            if isinstance(schema, bool):
                return {}
            return original_get_type(schema)

        setattr(_safe_get_type, marker, True)
        gradio_client_utils.get_type = _safe_get_type

    original_json_schema_to_python_type = (
        gradio_client_utils._json_schema_to_python_type
    )
    if not getattr(original_json_schema_to_python_type, marker, False):

        @functools.wraps(original_json_schema_to_python_type)
        def _safe_json_schema_to_python_type(schema, defs):
            if isinstance(schema, bool):
                return "Any"
            return original_json_schema_to_python_type(schema, defs)

        setattr(_safe_json_schema_to_python_type, marker, True)
        gradio_client_utils._json_schema_to_python_type = (
            _safe_json_schema_to_python_type
        )
def _patch_open3d_cuda_device_count_bug() -> None:
    """Patch open3d on disk so it does not query the CUDA device count.

    Rewrites the installed ``open3d/__init__.py``, replacing every
    occurrence of ``_pybind_cuda.open3d_core_cuda_device_count()`` with the
    literal ``1``. The edit only affects imports of ``open3d`` that happen
    after this call.

    The file is located via ``importlib.util.find_spec`` (the location
    Python would actually import it from, without executing it). If that
    fails, it falls back to the first ``site.getsitepackages()`` entry.
    The file is written only when its content changes, so repeated calls
    are no-ops and do not require write access once already patched.

    Raises:
        OSError: e.g. ``FileNotFoundError`` if ``open3d/__init__.py`` does
            not exist at the resolved location, or if writing fails.
    """
    import importlib.util
    from pathlib import Path

    target = '_pybind_cuda.open3d_core_cuda_device_count()'

    # find_spec on a top-level name only searches the import path; it does
    # not run open3d/__init__.py, which is what we must avoid here.
    spec = importlib.util.find_spec('open3d')
    if (
        spec is not None
        and spec.origin
        and spec.origin.endswith('__init__.py')
    ):
        init_path = Path(spec.origin)
    else:
        init_path = Path(site.getsitepackages()[0]) / 'open3d' / '__init__.py'

    source = init_path.read_text(encoding='utf-8')
    patched = source.replace(target, '1')
    if patched != source:
        init_path.write_text(patched, encoding='utf-8')
def _neutralize_warp_in_parent() -> None:
    """Prevent NVIDIA Warp from calling cuInit() in the ZeroGPU parent.

    Root cause of @spaces.GPU silent hangs (spaces>=0.50): kaolin imports
    warp at module top-level. When any kaolin module triggers warp.init(),
    Warp's `init_cuda_driver` dlopens libcuda.so + calls cuInit() in the
    parent process. After spaces forks the worker, torch.init(nvidia_uuid)
    in the worker hangs forever because the inherited CUDA driver state is
    poisoned (parent never had a real GPU; ZeroGPU exposes one only post-fork).

    Fix: stub warp.init / warp.context.runtime_init with a pid-aware no-op.
    The parent-resident pid skips init (logging to stderr and returning
    None); the forked worker (different pid) runs the real init so warp keeps
    working inside @spaces.GPU code paths. Already-wrapped functions (those
    carrying ``__wrapped__``) are left alone, so repeated calls are safe.

    NOTE(review): only the ``warp.init`` attribute and, if present,
    ``warp.context.runtime_init`` are wrapped; code that reaches an init
    routine through another reference is not intercepted — confirm against
    the installed warp version.

    Must be called BEFORE any import that pulls kaolin (e.g. embodied_gen.data,
    thirdparty.TRELLIS). Does nothing if warp is not installed.
    """
    import functools
    import os
    import sys

    try:
        import warp  # noqa: F401 -- pure python import, no cuInit
    except ImportError:
        return

    # Captured once: the process that installs the stubs is the parent.
    parent_pid = os.getpid()

    def _make_pid_safe(orig):
        # Resolve the display name once so the log line cannot raise for
        # callables without __name__ (e.g. functools.partial objects).
        name = getattr(orig, "__name__", "wrapped")

        @functools.wraps(orig)  # also sets __wrapped__ (the idempotence mark)
        def _wrapped(*args, **kwargs):
            if os.getpid() == parent_pid:
                sys.stderr.write(
                    f"[warp-neutralize] skip {name} in parent pid={parent_pid}\n"
                )
                sys.stderr.flush()
                return None
            return orig(*args, **kwargs)

        _wrapped.__name__ = name
        return _wrapped

    if hasattr(warp, "init") and not hasattr(warp.init, "__wrapped__"):
        warp.init = _make_pid_safe(warp.init)

    # Best effort: warp.context / runtime_init may not exist in every warp
    # version, so failures here are deliberately ignored.
    try:
        from warp import context as _wctx

        if hasattr(_wctx, "runtime_init") and not hasattr(
            _wctx.runtime_init, "__wrapped__"
        ):
            _wctx.runtime_init = _make_pid_safe(_wctx.runtime_init)
    except Exception:
        pass
def _disable_xformers_flash3() -> None:
    """Stop the xformers dispatcher from choosing Flash-Attention v3.

    FA3 kernels target Hopper only; on sm_120 (Blackwell) the dispatcher
    still selects flash3 and the launch aborts with:
    `CUDA error ... hopper/flash_fwd_launch_template.h:188: invalid argument`

    Setting `XFORMERS_FLASH3_ATTENTION_DISABLED=1` is silently ignored in
    xformers 0.0.32.post2, so instead `not_supported_reasons` is overridden
    on the flash3 `FwOp` / `BwOp` classes (whichever exist) to always report
    a reason, which makes them ineligible. Cutlass and FA2 both work on
    sm_120, so dropping flash3 from the candidates is sufficient.

    The override is applied unconditionally (no GPU-architecture check), and
    the function is a no-op if the flash3 module cannot be imported.
    """
    try:
        from xformers.ops.fmha import flash3 as fa3_module
    except Exception:
        return

    reason = "disabled by EmbodiedGen: no FA3 kernel for sm_120"

    def _always_unsupported(cls, d):  # noqa: ARG001
        # Fresh list per call so callers may mutate the result safely.
        return [reason]

    for op_name in ("FwOp", "BwOp"):
        if not hasattr(fa3_module, op_name):
            continue
        op_cls = getattr(fa3_module, op_name)
        op_cls.not_supported_reasons = classmethod(_always_unsupported)
|