File size: 3,281 Bytes
502c880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
# ABI mismatch with the base image's libtorch.
#
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
# all compiled-CUDA imports here and let Mamba3 load directly.

# Package version string.
__version__ = "2.3.1+feather-graft"

# Placeholders for the compiled-CUDA entry points. Both names are bound to
# None so that `from mamba_ssm import selective_scan_fn` (or mamba_inner_fn)
# still succeeds at import time. They are NOT used by the Feather training
# path (which is Mamba3-only). Code that does reach them fails at the point
# of use: calling one raises TypeError ('NoneType' object is not callable)
# and attribute access raises AttributeError, rather than an obscure
# ImportError while the package is loading.
selective_scan_fn = mamba_inner_fn = None

# --- triton API compatibility shims -----------------------------------------
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
# imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
# Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
# tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
# satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
# APIs) and shim AttrsDescriptor as a stub class for torch._inductor. The
# stub is never actually invoked at runtime because the codebase does not use
# torch.compile — but importing torch._inductor.* still requires the symbol to
# exist at module load time.
import triton as _triton  # noqa: E402

# Only patch triton when it does not already expose set_allocator; a real
# implementation is never overwritten.
if not hasattr(_triton, "set_allocator"):
    def _noop_set_allocator(_fn):  # pragma: no cover
        """Accept an allocator callback and discard it.

        Stands in for ``triton.set_allocator`` on triton builds that lack
        it, so code calling that name does not fail with AttributeError.
        The argument is ignored and ``None`` is returned.
        """
        return None

    setattr(_triton, "set_allocator", _noop_set_allocator)

import triton.compiler.compiler as _tcc  # noqa: E402

# Install the placeholder only when triton.compiler.compiler has no
# AttrsDescriptor of its own.
if not hasattr(_tcc, "AttrsDescriptor"):
    class _AttrsDescriptorShim:
        """Placeholder ``AttrsDescriptor`` for triton >= 3.4.

        torch._inductor.runtime.hints imports the name at module load,
        while the constructor is only called inside torch.compile paths.
        The class therefore accepts any positional and keyword arguments
        and simply stores them, so the import itself succeeds.
        """

        def __init__(self, *args, **kwargs):
            # Keep whatever was passed, unmodified, as plain attributes.
            self.args, self.kwargs = args, kwargs

        @classmethod
        def from_hints(cls, *args, **kwargs):
            """Build an instance by forwarding every argument to ``__init__``."""
            return cls(*args, **kwargs)

    setattr(_tcc, "AttrsDescriptor", _AttrsDescriptorShim)

# triton_key was removed in triton 3.5 but torch._inductor.codecache still
# uses it for FxGraphCache key derivation. When it is missing, install a
# replacement that returns a stable string so caching still works.
if not hasattr(_tcc, "triton_key"):
    def _triton_key_shim():
        """Return a cache-key string derived from the installed triton version."""
        import triton as _t

        key = f"triton-{_t.__version__}-shim"
        return key

    setattr(_tcc, "triton_key", _triton_key_shim)

# Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
# for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
# so fall back to eager on any dynamo failure rather than crashing. This is
# defense-in-depth against further triton API drift.
# Best-effort: ask dynamo to fall back to eager instead of raising. Both the
# import and the config write stay inside the try so that any failure here
# (missing module, renamed/removed config key, import-time error) never
# prevents the package from loading.
try:
    import torch._dynamo  # noqa: F401 — triggers dynamo module init
    torch._dynamo.config.suppress_errors = True
except Exception as _dynamo_exc:  # pragma: no cover
    # Do not crash, but do not hide the failure either: without this flag a
    # later dynamo error would surface as an exception rather than an eager
    # fallback, so leave a trace for whoever debugs it.
    import logging as _logging

    _logging.getLogger(__name__).warning(
        "mamba_ssm: could not set torch._dynamo.config.suppress_errors "
        "(%s: %s); torch.compile failures will not fall back to eager.",
        type(_dynamo_exc).__name__,
        _dynamo_exc,
    )

# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
from mamba_ssm.modules.mamba3 import Mamba3  # noqa: E402