xinjjj's picture
chore(space): sync EmbodiedGen text space updates
d7b1c6f
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import fileinput
import site
import gradio_client.utils as gradio_client_utils
def _patch_gradio_schema_bool_bug() -> None:
"""Patch schema parser for bool-style for gradio<5.33."""
original_get_type = gradio_client_utils.get_type
original_json_schema_to_python_type = (
gradio_client_utils._json_schema_to_python_type
)
def _safe_get_type(schema):
if isinstance(schema, bool):
return {}
return original_get_type(schema)
def _safe_json_schema_to_python_type(schema, defs):
if isinstance(schema, bool):
return "Any"
return original_json_schema_to_python_type(schema, defs)
gradio_client_utils.get_type = _safe_get_type
gradio_client_utils._json_schema_to_python_type = (
_safe_json_schema_to_python_type
)
def _patch_open3d_cuda_device_count_bug() -> None:
"""Patch open3d to avoid cuda device count bug."""
with fileinput.FileInput(
f'{site.getsitepackages()[0]}/open3d/__init__.py', inplace=True
) as file:
for line in file:
print(
line.replace(
'_pybind_cuda.open3d_core_cuda_device_count()', '1'
),
end='',
)
def _neutralize_warp_in_parent() -> None:
"""Prevent NVIDIA Warp from calling cuInit() in the ZeroGPU parent.
Root cause of @spaces.GPU silent hangs (spaces>=0.50): kaolin imports
warp at module top-level. When any kaolin module triggers warp.init(),
Warp's `init_cuda_driver` dlopens libcuda.so + calls cuInit() in the
parent process. After spaces forks the worker, torch.init(nvidia_uuid)
in the worker hangs forever because the inherited CUDA driver state is
poisoned (parent never had a real GPU; ZeroGPU exposes one only post-fork).
Fix: stub warp.init / warp.context.runtime_init with a pid-aware no-op.
The parent-resident pid skips init; the forked worker (different pid)
runs the real init so warp keeps working inside @spaces.GPU code paths.
Must be called BEFORE any import that pulls kaolin (e.g. embodied_gen.data,
thirdparty.TRELLIS).
"""
import os
import sys
try:
import warp # noqa: F401 -- pure python import, no cuInit
except ImportError:
return
parent_pid = os.getpid()
def _make_pid_safe(orig):
def _wrapped(*args, **kwargs):
if os.getpid() == parent_pid:
sys.stderr.write(
f"[warp-neutralize] skip {orig.__name__} in parent pid={parent_pid}\n"
)
sys.stderr.flush()
return None
return orig(*args, **kwargs)
_wrapped.__wrapped__ = orig
_wrapped.__name__ = getattr(orig, "__name__", "wrapped")
return _wrapped
if hasattr(warp, "init") and not hasattr(warp.init, "__wrapped__"):
warp.init = _make_pid_safe(warp.init)
try:
from warp import context as _wctx
if hasattr(_wctx, "runtime_init") and not hasattr(
_wctx.runtime_init, "__wrapped__"
):
_wctx.runtime_init = _make_pid_safe(_wctx.runtime_init)
except Exception:
pass
def _disable_xformers_flash3() -> None:
"""Force xformers dispatcher to skip Flash-Attention v3 (Hopper-only).
sm_120 (Blackwell) has no FA3 kernel binary; the dispatcher still picks
flash3 and the launch aborts with:
`CUDA error ... hopper/flash_fwd_launch_template.h:188: invalid argument`
Env vars `XFORMERS_FLASH3_ATTENTION_DISABLED=1` are silently ignored in
xformers 0.0.32.post2, so we patch `not_supported_reasons` directly.
Cutlass and FA2 both work on sm_120, so removing flash3 from candidates
is enough.
"""
try:
from xformers.ops.fmha import flash3 as _f3
except Exception:
return
_disabled = ["disabled by EmbodiedGen: no FA3 kernel for sm_120"]
def _ns(cls, d): # noqa: ARG001
return list(_disabled)
if hasattr(_f3, "FwOp"):
_f3.FwOp.not_supported_reasons = classmethod(_ns)
if hasattr(_f3, "BwOp"):
_f3.BwOp.not_supported_reasons = classmethod(_ns)