File size: 4,818 Bytes
7d21266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7b1c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


import fileinput
import site

import gradio_client.utils as gradio_client_utils


def _patch_gradio_schema_bool_bug() -> None:
    """Make gradio_client's schema helpers tolerate bool schemas (gradio<5.33).

    Replaces ``gradio_client_utils.get_type`` and
    ``gradio_client_utils._json_schema_to_python_type`` with thin wrappers.
    When the schema argument is a ``bool`` the wrappers answer directly
    (``{}`` and ``"Any"`` respectively); every other schema is forwarded
    unchanged to the implementation that was installed at patch time.
    """
    # Capture the current implementations before rebinding the module
    # attributes, so the wrappers can still delegate to them.
    upstream_get_type = gradio_client_utils.get_type
    upstream_to_python_type = gradio_client_utils._json_schema_to_python_type

    def _bool_safe_get_type(schema):
        # A bare bool is not a mapping the upstream helper can inspect.
        return {} if isinstance(schema, bool) else upstream_get_type(schema)

    def _bool_safe_to_python_type(schema, defs):
        if not isinstance(schema, bool):
            return upstream_to_python_type(schema, defs)
        return "Any"

    gradio_client_utils.get_type = _bool_safe_get_type
    gradio_client_utils._json_schema_to_python_type = _bool_safe_to_python_type


def _patch_open3d_cuda_device_count_bug(init_file: "str | None" = None) -> None:
    """Patch open3d to avoid cuda device count bug.

    Rewrites the installed ``open3d/__init__.py`` on disk, replacing every
    occurrence of ``_pybind_cuda.open3d_core_cuda_device_count()`` with the
    literal ``1``. The file is only written when that call is actually
    present, so repeated invocations are no-ops and an already patched
    installation is never touched again.

    Args:
        init_file: Optional path of the ``open3d/__init__.py`` to patch.
            When omitted, the directories returned by
            ``site.getsitepackages()`` are searched in order and the first
            one that contains ``open3d/__init__.py`` is used.

    Raises:
        FileNotFoundError: If the file to patch cannot be found.
    """
    import os

    needle = '_pybind_cuda.open3d_core_cuda_device_count()'

    if init_file is None:
        candidates = [
            os.path.join(root, 'open3d', '__init__.py')
            for root in site.getsitepackages()
        ]
        init_file = next(
            (path for path in candidates if os.path.isfile(path)), None
        )
        if init_file is None:
            raise FileNotFoundError(
                f'open3d/__init__.py not found in any of: {candidates}'
            )

    # newline='' disables newline translation on both the read and the write
    # side, so the file's original line endings are preserved.
    with open(init_file, encoding='utf-8', newline='') as src:
        content = src.read()

    # Nothing to replace: leave the file (and its mtime) alone.
    if needle not in content:
        return

    with open(init_file, 'w', encoding='utf-8', newline='') as dst:
        dst.write(content.replace(needle, '1'))


def _neutralize_warp_in_parent() -> None:
    """Prevent NVIDIA Warp from calling cuInit() in the ZeroGPU parent.

    Root cause of @spaces.GPU silent hangs (spaces>=0.50): kaolin imports
    warp at module top-level. When any kaolin module triggers warp.init(),
    Warp's `init_cuda_driver` dlopens libcuda.so + calls cuInit() in the
    parent process. After spaces forks the worker, torch.init(nvidia_uuid)
    in the worker hangs forever because the inherited CUDA driver state is
    poisoned (parent never had a real GPU; ZeroGPU exposes one only post-fork).

    Fix: replace ``warp.init`` and ``warp.context.runtime_init`` (each only
    if it exists and is not already wrapped) with a pid-aware wrapper. In
    the process that called this function the wrapper logs to stderr and
    returns ``None`` without calling the original; in any other pid (the
    forked worker) it calls the original, so warp keeps working inside
    @spaces.GPU code paths.

    The function is best-effort: it returns silently when warp is not
    installed, and a failure while patching ``warp.context`` is reported on
    stderr instead of being raised.

    Must be called BEFORE any import that pulls kaolin (e.g. embodied_gen.data,
    thirdparty.TRELLIS).
    """
    import functools
    import os
    import sys

    try:
        import warp  # noqa: F401  -- pure python import, no cuInit
    except ImportError:
        return

    # The pid recorded here identifies the process that must never run the
    # real init; a forked child reports a different pid.
    parent_pid = os.getpid()

    def _make_pid_safe(orig):
        """Wrap *orig* so it is skipped while running in ``parent_pid``."""
        # Resolve the display name once, with a fallback, so the wrapper
        # cannot fail on callables that have no ``__name__``.
        orig_name = getattr(orig, "__name__", repr(orig))

        # functools.wraps copies the metadata and sets ``__wrapped__``, which
        # is the marker the hasattr checks below use to avoid double-wrapping.
        @functools.wraps(orig)
        def _wrapped(*args, **kwargs):
            if os.getpid() == parent_pid:
                sys.stderr.write(
                    f"[warp-neutralize] skip {orig_name} in parent pid={parent_pid}\n"
                )
                sys.stderr.flush()
                return None
            return orig(*args, **kwargs)

        return _wrapped

    if hasattr(warp, "init") and not hasattr(warp.init, "__wrapped__"):
        warp.init = _make_pid_safe(warp.init)

    try:
        from warp import context as _wctx

        if hasattr(_wctx, "runtime_init") and not hasattr(
            _wctx.runtime_init, "__wrapped__"
        ):
            _wctx.runtime_init = _make_pid_safe(_wctx.runtime_init)
    except Exception as exc:
        # Stay best-effort, but leave a trace so a missed patch is visible.
        sys.stderr.write(
            f"[warp-neutralize] could not patch warp.context: {exc!r}\n"
        )
        sys.stderr.flush()


def _disable_xformers_flash3() -> None:
    """Take Flash-Attention v3 (Hopper-only) out of xformers' dispatch.

    On sm_120 (Blackwell) there is no FA3 kernel binary, yet the dispatcher
    still selects flash3 and the launch aborts with:
      `CUDA error ... hopper/flash_fwd_launch_template.h:188: invalid argument`
    The `XFORMERS_FLASH3_ATTENTION_DISABLED=1` env var is silently ignored in
    xformers 0.0.32.post2, so `not_supported_reasons` is overridden directly
    on the flash3 ops to always report a reason. Cutlass and FA2 both work on
    sm_120, so dropping flash3 from the candidates is sufficient.

    Does nothing when `xformers.ops.fmha.flash3` cannot be imported.
    """
    try:
        from xformers.ops.fmha import flash3 as flash3_module
    except Exception:
        return

    reasons = ["disabled by EmbodiedGen: no FA3 kernel for sm_120"]

    def _always_unsupported(cls, d):  # noqa: ARG001
        # Hand out a fresh list on each call so callers cannot mutate the
        # shared template.
        return reasons.copy()

    # Forward and backward ops get the same override, when present.
    for op_name in ("FwOp", "BwOp"):
        if not hasattr(flash3_module, op_name):
            continue
        op_cls = getattr(flash3_module, op_name)
        op_cls.not_supported_reasons = classmethod(_always_unsupported)