Spaces:
Running on Zero
Running on Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import fileinput | |
| import site | |
| import gradio_client.utils as gradio_client_utils | |
def _patch_gradio_schema_bool_bug() -> None:
    """Make gradio_client's schema parser tolerate bool-style schemas.

    Intended for gradio<5.33, whose schema parser does not handle
    bool-style schemas (JSON Schema allows a bare ``true``/``false`` where a
    schema object is expected). The patch is applied unconditionally; no
    version check is performed here.

    ``gradio_client.utils.get_type`` and
    ``gradio_client.utils._json_schema_to_python_type`` are replaced by
    wrappers:

    * a bool schema short-circuits: ``get_type`` returns ``{}`` and
      ``_json_schema_to_python_type`` returns ``"Any"``;
    * every other schema is forwarded unchanged to the original function.

    Each wrapper carries a marker attribute, so calling this function more
    than once does not stack additional wrappers.
    """
    import functools

    marker = "_embodiedgen_bool_schema_safe"

    original_get_type = gradio_client_utils.get_type
    if not getattr(original_get_type, marker, False):

        @functools.wraps(original_get_type)
        def _safe_get_type(schema):
            if isinstance(schema, bool):
                return {}
            return original_get_type(schema)

        setattr(_safe_get_type, marker, True)
        gradio_client_utils.get_type = _safe_get_type

    original_json_schema_to_python_type = (
        gradio_client_utils._json_schema_to_python_type
    )
    if not getattr(original_json_schema_to_python_type, marker, False):

        @functools.wraps(original_json_schema_to_python_type)
        def _safe_json_schema_to_python_type(schema, defs):
            if isinstance(schema, bool):
                return "Any"
            return original_json_schema_to_python_type(schema, defs)

        setattr(_safe_json_schema_to_python_type, marker, True)
        gradio_client_utils._json_schema_to_python_type = (
            _safe_json_schema_to_python_type
        )
| def _patch_open3d_cuda_device_count_bug() -> None: | |
| """Patch open3d to avoid cuda device count bug.""" | |
| with fileinput.FileInput( | |
| f'{site.getsitepackages()[0]}/open3d/__init__.py', inplace=True | |
| ) as file: | |
| for line in file: | |
| print( | |
| line.replace( | |
| '_pybind_cuda.open3d_core_cuda_device_count()', '1' | |
| ), | |
| end='', | |
| ) | |
| def _neutralize_warp_in_parent() -> None: | |
| """Prevent NVIDIA Warp from calling cuInit() in the ZeroGPU parent. | |
| Root cause of @spaces.GPU silent hangs (spaces>=0.50): kaolin imports | |
| warp at module top-level. When any kaolin module triggers warp.init(), | |
| Warp's `init_cuda_driver` dlopens libcuda.so + calls cuInit() in the | |
| parent process. After spaces forks the worker, torch.init(nvidia_uuid) | |
| in the worker hangs forever because the inherited CUDA driver state is | |
| poisoned (parent never had a real GPU; ZeroGPU exposes one only post-fork). | |
| Fix: stub warp.init / warp.context.runtime_init with a pid-aware no-op. | |
| The parent-resident pid skips init; the forked worker (different pid) | |
| runs the real init so warp keeps working inside @spaces.GPU code paths. | |
| Must be called BEFORE any import that pulls kaolin (e.g. embodied_gen.data, | |
| thirdparty.TRELLIS). | |
| """ | |
| import os | |
| import sys | |
| try: | |
| import warp # noqa: F401 -- pure python import, no cuInit | |
| except ImportError: | |
| return | |
| parent_pid = os.getpid() | |
| def _make_pid_safe(orig): | |
| def _wrapped(*args, **kwargs): | |
| if os.getpid() == parent_pid: | |
| sys.stderr.write( | |
| f"[warp-neutralize] skip {orig.__name__} in parent pid={parent_pid}\n" | |
| ) | |
| sys.stderr.flush() | |
| return None | |
| return orig(*args, **kwargs) | |
| _wrapped.__wrapped__ = orig | |
| _wrapped.__name__ = getattr(orig, "__name__", "wrapped") | |
| return _wrapped | |
| if hasattr(warp, "init") and not hasattr(warp.init, "__wrapped__"): | |
| warp.init = _make_pid_safe(warp.init) | |
| try: | |
| from warp import context as _wctx | |
| if hasattr(_wctx, "runtime_init") and not hasattr( | |
| _wctx.runtime_init, "__wrapped__" | |
| ): | |
| _wctx.runtime_init = _make_pid_safe(_wctx.runtime_init) | |
| except Exception: | |
| pass | |
| def _disable_xformers_flash3() -> None: | |
| """Force xformers dispatcher to skip Flash-Attention v3 (Hopper-only). | |
| sm_120 (Blackwell) has no FA3 kernel binary; the dispatcher still picks | |
| flash3 and the launch aborts with: | |
| `CUDA error ... hopper/flash_fwd_launch_template.h:188: invalid argument` | |
| Env vars `XFORMERS_FLASH3_ATTENTION_DISABLED=1` are silently ignored in | |
| xformers 0.0.32.post2, so we patch `not_supported_reasons` directly. | |
| Cutlass and FA2 both work on sm_120, so removing flash3 from candidates | |
| is enough. | |
| """ | |
| try: | |
| from xformers.ops.fmha import flash3 as _f3 | |
| except Exception: | |
| return | |
| _disabled = ["disabled by EmbodiedGen: no FA3 kernel for sm_120"] | |
| def _ns(cls, d): # noqa: ARG001 | |
| return list(_disabled) | |
| if hasattr(_f3, "FwOp"): | |
| _f3.FwOp.not_supported_reasons = classmethod(_ns) | |
| if hasattr(_f3, "BwOp"): | |
| _f3.BwOp.not_supported_reasons = classmethod(_ns) | |