icarus112 committed on
Commit
502c880
·
verified ·
1 Parent(s): 861dd6c

Upload mamba_ssm_init.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mamba_ssm_init.py +69 -0
mamba_ssm_init.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mamba_ssm package init — minimal override to avoid broken selective_scan_cuda.so
2
+ # ABI mismatch with the base image's libtorch.
3
+ #
4
+ # The upstream __init__.py eagerly imports selective_scan_cuda which fails on
5
+ # pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel (undefined c10::Warning ctor
6
+ # symbol). We only need Mamba3 (grafted from main, pure-Triton), so we skip
7
+ # all compiled-CUDA imports here and let Mamba3 load directly.
8
+
9
+ __version__ = "2.3.1+feather-graft"
10
+
11
+ # selective_scan_fn / mamba_inner_fn are shimmed to None — they are NOT used
12
+ # by the Feather training path (which is Mamba3-only). If any import path
13
+ # hits this, it will get a TypeError ('NoneType' object is not callable) on first call instead of an obscure ImportError.
14
# Both names remain importable from the package but are bound to None rather
# than to callables.
selective_scan_fn = mamba_inner_fn = None
16
+
17
+ # --- triton API compatibility shims -----------------------------------------
18
+ # Version matrix is hostile: torch 2.6 pins triton==3.2.0 because torch._inductor
19
+ # imports AttrsDescriptor from triton.compiler.compiler — removed in triton 3.4+.
20
+ # Grafted Mamba3 (from mamba-ssm main) needs triton.set_allocator and
21
+ # tl.make_tensor_descriptor, both added in triton 3.3+. No single triton version
22
+ # satisfies both simultaneously. We run on triton 3.5.1 (latest, has both mamba3
23
+ # APIs) and shim AttrsDescriptor as a stub class for torch._inductor. The
24
+ # stub is never actually invoked at runtime because the codebase does not use
25
+ # torch.compile — but importing torch._inductor.* still requires the symbol to
26
+ # exist at module load time.
27
import triton as _triton  # noqa: E402

# Guarantee that triton.set_allocator exists: when the installed triton does
# not provide it, install a stand-in that accepts the allocator callback and
# discards it.
if not hasattr(_triton, "set_allocator"):

    def _noop_set_allocator(_fn):  # pragma: no cover
        """Ignore the supplied allocator callback and return None."""

    setattr(_triton, "set_allocator", _noop_set_allocator)
32
+
33
import triton.compiler.compiler as _tcc  # noqa: E402

# Install the placeholder only when triton.compiler.compiler does not already
# expose an AttrsDescriptor of its own.
if not hasattr(_tcc, "AttrsDescriptor"):

    class _AttrsDescriptorShim:
        """Import-time placeholder for ``AttrsDescriptor`` on triton >= 3.4.

        It exists so that code importing the name from
        ``triton.compiler.compiler`` (torch._inductor, per the notes in this
        file) can load. Construction accepts arbitrary positional and keyword
        arguments and simply stores them; nothing is validated.
        """

        def __init__(self, *args, **kwargs):
            # Keep the raw arguments around; they are never interpreted here.
            self.args, self.kwargs = args, kwargs

        @classmethod
        def from_hints(cls, *args, **kwargs):
            """Create an instance by forwarding every argument to ``__init__``."""
            return cls(*args, **kwargs)

    setattr(_tcc, "AttrsDescriptor", _AttrsDescriptorShim)
49
+
50
+ # triton_key: removed in triton 3.5, used by torch._inductor.codecache for
51
+ # FxGraphCache key derivation. Return a stable string so caching still works.
52
# Provide triton_key on triton.compiler.compiler when the installed triton
# does not define it there.
if not hasattr(_tcc, "triton_key"):

    def _triton_key_shim():
        """Return a key string derived from the installed triton version."""
        from triton import __version__ as _installed_version

        return f"triton-{_installed_version}-shim"

    setattr(_tcc, "triton_key", _triton_key_shim)
57
+
58
+ # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
59
+ # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
60
+ # so fall back to eager on any dynamo failure rather than crashing. This is
61
+ # defense-in-depth against further triton API drift.
62
import warnings as _warnings  # noqa: E402

# Best-effort: turn on dynamo's suppress_errors flag. Any failure while
# importing torch._dynamo or setting the flag is tolerated so that importing
# this package never crashes here, but it is reported as a RuntimeWarning
# instead of being swallowed silently.
try:
    import torch._dynamo  # noqa: F401 — triggers dynamo module init

    torch._dynamo.config.suppress_errors = True
except Exception as _dynamo_exc:  # pragma: no cover
    _warnings.warn(
        "mamba_ssm: could not set torch._dynamo.config.suppress_errors "
        f"({_dynamo_exc!r}); dynamo error suppression is not enabled.",
        RuntimeWarning,
    )
67
+
68
+ # Expose Mamba3 at top level to match `from mamba_ssm import Mamba3`.
69
+ from mamba_ssm.modules.mamba3 import Mamba3 # noqa: E402