Upload mamba_ssm_init.py with huggingface_hub
Browse files- mamba_ssm_init.py +69 -0
mamba_ssm_init.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
|
| 2 |
+
# ABI mismatch with the base image's libtorch.
|
| 3 |
+
#
|
| 4 |
+
# The upstream __init__.py eagerly imports selective_scan_cuda which fails on
|
| 5 |
+
# pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
|
| 6 |
+
# symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
|
| 7 |
+
# all compiled-CUDA imports here and let Mamba3 load directly.
|
| 8 |
+
|
| 9 |
+
__version__ = "2.3.1+feather-graft"
|
| 10 |
+
|
| 11 |
+
# selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
|
| 12 |
+
# by the Feather training path (which is Mamba3-only). If any import path
|
| 13 |
+
# hits this, calling the None shim raises TypeError ('NoneType' object is not callable) instead of an obscure ImportError.
|
| 14 |
+
selective_scan_fn = None
|
| 15 |
+
mamba_inner_fn = None
|
| 16 |
+
|
| 17 |
+
# --- triton API compatibility shims -----------------------------------------
|
| 18 |
+
# Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
|
| 19 |
+
# imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
|
| 20 |
+
# Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
|
| 21 |
+
# tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
|
| 22 |
+
# satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
|
| 23 |
+
# APIs) and shim AttrsDescriptor as a plain stub class for torch._inductor. The
|
| 24 |
+
# stub is never actually invoked at runtime because the codebase does not use
|
| 25 |
+
# torch.compile — but importing torch._inductor.* still requires the symbol to
|
| 26 |
+
# exist at module load time.
|
| 27 |
+
import triton as _triton # noqa: E402
|
| 28 |
+
if not hasattr(_triton, "set_allocator"):

    def _noop_set_allocator(_fn):  # pragma: no cover
        """Accept an allocator callback and ignore it.

        Installed as ``triton.set_allocator`` only when the imported triton
        does not provide that attribute, so callers that register an
        allocator do not fail on the missing name. Always returns ``None``.
        """

    # Patch the module object in place so later attribute lookups see the stub.
    _triton.set_allocator = _noop_set_allocator
|
| 32 |
+
|
| 33 |
+
import triton.compiler.compiler as _tcc # noqa: E402
|
| 34 |
+
if not hasattr(_tcc, "AttrsDescriptor"):

    class _AttrsDescriptorShim:
        """Placeholder for ``AttrsDescriptor`` on triton >= 3.4.

        Per this file's notes, torch._inductor.runtime.hints needs the symbol
        to exist at import time, while construction only happens inside
        torch.compile paths. The class therefore accepts any arguments and
        simply records them, so that the import itself succeeds.
        """

        def __init__(self, *positional, **named):
            # Keep whatever was passed; nothing here interprets the values.
            self.args = positional
            self.kwargs = named

        @classmethod
        def from_hints(cls, *positional, **named):
            """Alternate constructor; forwards every argument to ``__init__``."""
            instance = cls(*positional, **named)
            return instance

    # Publish the stub under the name that importers look up.
    _tcc.AttrsDescriptor = _AttrsDescriptorShim
|
| 49 |
+
|
| 50 |
+
# triton_key: removed in triton 3.5, used by torch._inductor.codecache for
|
| 51 |
+
# FxGraphCache key derivation. Return a stable string so caching still works.
|
| 52 |
+
if not hasattr(_tcc, "triton_key"):

    def _triton_key_shim():
        """Return a version-derived string standing in for ``triton_key()``.

        The value depends only on the installed triton version, so it is
        stable across calls within one environment.
        """
        # Imported locally so the function does not rely on module globals.
        import triton as _t

        cache_key = f"triton-{_t.__version__}-shim"
        return cache_key

    # Install the replacement only when triton no longer ships its own.
    _tcc.triton_key = _triton_key_shim
|
| 57 |
+
|
| 58 |
+
# Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
|
| 59 |
+
# for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
|
| 60 |
+
# so fall back to eager on any dynamo failure rather than crashing. This is
|
| 61 |
+
# defense-in-depth against further triton API drift.
|
| 62 |
+
try:
    # Importing the submodule runs dynamo's module init and binds ``torch``
    # in this namespace, which the assignment below relies on.
    import torch._dynamo  # noqa: F401 — triggers dynamo module init

    # Ask dynamo to fall back to eager rather than raise on its own failures.
    torch._dynamo.config.suppress_errors = True
except Exception as _dynamo_exc:  # pragma: no cover
    # Still best-effort: a failure here must not break `import mamba_ssm`.
    # It is reported instead of silently dropped, because without the flag
    # a later dynamo failure raises instead of falling back to eager.
    import logging as _logging

    _logging.getLogger(__name__).warning(
        "mamba_ssm: could not set torch._dynamo.config.suppress_errors (%r); "
        "dynamo failures will raise instead of falling back to eager.",
        _dynamo_exc,
    )
|
| 67 |
+
|
| 68 |
+
# Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
|
| 69 |
+
from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402
|