# mamba_ssm package init — minimal override to avoid the broken
# selective_scan_cuda.so ABI mismatch with the base image's libtorch.
#
# The upstream __init__.py eagerly imports selective_scan_cuda, which fails on
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
# all compiled-CUDA imports here and let Mamba3 load directly.

import logging as _logging

__version__ = "2.3.1+feather-graft"

# selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
# by the Feather training path (which is Mamba3-only). Importing these names
# from the package therefore succeeds; any code that actually tries to *call*
# them fails with "TypeError: 'NoneType' object is not callable" at the call
# site instead of an obscure ImportError while the package is loading.
selective_scan_fn = None
mamba_inner_fn = None

# --- triton API compatibility shims -----------------------------------------
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because
# torch._inductor imports AttrsDescriptor from triton.compiler.compiler —
# removed in triton 3.4+. Grafted Mamba3 (from mamba-ssm main) needs
# triton.set_allocator and tl.make_tensor_descriptor, both added in
# triton 3.3+. No single triton version satisfies both simultaneously.
# We run on triton 3.5.1 (latest, has both mamba3 APIs) and shim
# AttrsDescriptor as a stub class for torch._inductor. The stub is never
# actually invoked at runtime because the codebase does not use
# torch.compile — but importing torch._inductor.* still requires the symbol
# to exist at module load time.
import triton as _triton  # noqa: E402

if not hasattr(_triton, "set_allocator"):

    def _noop_set_allocator(_fn):  # pragma: no cover
        """Accept an allocator callback and ignore it.

        Installed only when the running triton has no ``set_allocator``, so
        that code calling ``triton.set_allocator(fn)`` does not fail with an
        AttributeError.
        """
        return None

    _triton.set_allocator = _noop_set_allocator

import triton.compiler.compiler as _tcc  # noqa: E402

if not hasattr(_tcc, "AttrsDescriptor"):

    class _AttrsDescriptorShim:
        """Stub for torch._inductor compatibility on triton >= 3.4.

        torch._inductor.runtime.hints imports this at module load but the
        constructor is only called inside torch.compile paths.
        Accept any args/kwargs so the import itself succeeds.
        """

        def __init__(self, *args, **kwargs):
            """Record whatever was passed; no validation is performed."""
            self.args = args
            self.kwargs = kwargs

        @classmethod
        def from_hints(cls, *args, **kwargs):
            """Alternate constructor; forwards every argument to ``cls``."""
            return cls(*args, **kwargs)

    _tcc.AttrsDescriptor = _AttrsDescriptorShim

# triton_key: removed in triton 3.5, used by torch._inductor.codecache for
# FxGraphCache key derivation. Return a stable string so caching still works.
if not hasattr(_tcc, "triton_key"):

    def _triton_key_shim():
        """Return a cache-key string derived from the triton version."""
        # Reuse the module already imported above rather than re-importing
        # triton on every call; the produced string is unchanged.
        return f"triton-{_triton.__version__}-shim"

    _tcc.triton_key = _triton_key_shim

# Suppress torch.compile/_dynamo errors globally — we don't rely on
# torch.compile for performance in this codebase (Muon + Mamba3 kernels are
# already fused), so fall back to eager on any dynamo failure rather than
# crashing. This is defense-in-depth against further triton API drift.
try:
    import torch._dynamo  # noqa: F401 — triggers dynamo module init

    torch._dynamo.config.suppress_errors = True
except Exception:  # pragma: no cover
    # Still best-effort: never let this block break package import. Leave a
    # DEBUG-level trace so the failure is discoverable instead of vanishing.
    _logging.getLogger(__name__).debug(
        "mamba_ssm: could not set torch._dynamo.config.suppress_errors; "
        "continuing without it",
        exc_info=True,
    )

# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
from mamba_ssm.modules.mamba3 import Mamba3  # noqa: E402