File size: 7,622 Bytes
951f760
20f5256
83c8f8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec4aea
 
 
 
 
83c8f8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec4aea
 
 
 
 
83c8f8c
 
 
3c48dfc
 
1ec4aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c48dfc
 
 
 
 
 
 
 
 
 
83c8f8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec4aea
 
 
 
20f5256
 
 
 
 
83c8f8c
951f760
 
83c8f8c
1ec4aea
 
 
951f760
 
83c8f8c
20f5256
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Base image: PyTorch 2.6.0 devel tag with CUDA 12.4 and cuDNN 9.
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel

# Non-interactive apt, no pip download cache, unbuffered Python output.
ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_NO_CACHE_DIR=1
ENV PYTHONUNBUFFERED=1

# Rust toolchain locations; cargo's bin directory is prepended to PATH so the
# tools installed under CARGO_HOME are found by later build steps.
ENV CARGO_HOME=/root/.cargo
ENV RUSTUP_HOME=/root/.rustup
ENV PATH=/root/.cargo/bin:${PATH}

# OS-level packages for the later build steps (curl is used by the rustup and
# raw-file downloads further down). The apt lists are removed in the same
# layer so they do not persist in the image.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
      build-essential \
      ca-certificates \
      curl \
      git \
      libssl-dev \
      pkg-config \
 && rm -rf /var/lib/apt/lists/*

# Install the stable Rust toolchain (minimal profile) via rustup.
#
# The installer is downloaded to a file instead of being piped into bash:
# in a `curl | bash` pipeline the exit status is bash's, so a failed download
# would feed bash an empty script and the step would "succeed" without
# installing anything. `--proto '=https' --tlsv1.2` keeps the download from
# being downgraded to plain HTTP or an old TLS version. The final
# `cargo --version` fails the build if the toolchain is not on PATH.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs -o /tmp/rustup-init.sh && \
    bash /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \
    rm -f /tmp/rustup-init.sh && \
    cargo --version

# Refresh the packaging tools first, then install the Python dependencies
# (listed alphabetically, one per line, to keep diffs small).
RUN pip install --upgrade pip setuptools wheel \
 && pip install \
      cuda-python \
      datasets \
      einops \
      huggingface_hub \
      maturin \
      ninja \
      packaging \
      pandas \
      pyarrow \
      pydantic \
      requests \
      rustbpe \
      tiktoken

# Mamba-3 fused CUDA kernel stack (mandatory - NO fallback allowed).
#
# We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
# on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1:
# nvcc on the templated selective-scan/chunk-scan kernels needs 8-12GB per TU.
#
# Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
#   - Python 3.11 (cp311)                       - matches PyTorch 2.6.0 image
#   - CUDA 12.x wheels (cu12)                   - matches host CUDA 12.4
#   - PyTorch 2.6 ABI (torch2.6)                - exact torch match
#   - cxx11abiFALSE                             - standard PyTorch pip build
#
# Versions: mamba_ssm 2.3.1 (first stable with Mamba3 class) + causal_conv1d
# 1.6.1 from release tag v1.6.1.post4 (matching ABI). Both are CUDA-compiled,
# so no build toolchain is needed on the Space builder.
#
# Step A: install the published v2.3.1 prebuilt wheel (compiled CUDA ops
# for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc), then print
# the installed versions of both packages into the build log.
RUN pip install \
      'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
      'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"

#
# Step B: graft the Mamba3 class + its pure-Python/Triton helper tree from the
# Mamba-3 release commit. Do NOT graft from main: main now requires Triton APIs
# such as tl.make_tensor_descriptor that force Triton 3.5.x, and Triton 3.5.x
# fails driver discovery on HF A10 Jobs with "0 active drivers". The release
# commit works with torch 2.6's matching Triton 3.2 runtime.
#
# This avoids the source-build OOM on the cpu-basic HF Space builder and the
# missing-file error the smoke hit on the last attempt.
#
# Every download goes through the `fetch` shell helper, which maps
# $BASE/mamba_ssm/<path> to $SITE/<path> and aborts the whole RUN on the first
# failure. Without that, a `for` loop only reports the status of its LAST
# iteration, so a failed download of any earlier file would be silently
# ignored and leave a missing module in the image. `--retry 3` retries
# transient network errors.
#
# The trailing Python script patches the grafted mamba3.py so that the optional
# CuTe step function is imported lazily (set to None on import failure) and
# raises a RuntimeError at the patched call site when it is unavailable. If an
# anchor string is not found in the file, a warning is printed to stderr
# instead of skipping the patch silently.
COPY mamba3_siso_combined_torch_fallback.py /tmp/mamba3_siso_combined_torch_fallback.py
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE=https://raw.githubusercontent.com/state-spaces/mamba/5235bdcd3fca41e336f17322acbfe8d8abb6c93f && \
    fetch() { curl -fsSL --retry 3 "$BASE/mamba_ssm/$1" -o "$SITE/$1" || exit 1; } && \
    mkdir -p "$SITE/ops/triton/mamba3" "$SITE/ops/tilelang/mamba3" "$SITE/ops/cute/mamba3" && \
    fetch modules/mamba3.py && \
    for f in angle_cumsum.py k_activations.py layer_norm.py layernorm_gated.py selective_state_update.py softplus.py ssd_bmm.py ssd_chunk_scan.py ssd_chunk_state.py ssd_combined.py ssd_state_passing.py; do \
        fetch "ops/triton/$f"; \
    done && \
    for f in angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        fetch "ops/triton/mamba3/$f"; \
    done && \
    for f in mamba3_mimo.py mamba3_mimo_bwd.py mamba3_mimo_fwd.py; do \
        fetch "ops/tilelang/mamba3/$f"; \
    done && \
    fetch ops/cute/mamba3/mamba3_step_fn.py && \
    touch "$SITE/ops/triton/mamba3/__init__.py" "$SITE/ops/tilelang/__init__.py" \
          "$SITE/ops/tilelang/mamba3/__init__.py" "$SITE/ops/cute/__init__.py" \
          "$SITE/ops/cute/mamba3/__init__.py" && \
    python - <<'PY'
import sys
from pathlib import Path
path = Path('/opt/conda/lib/python3.11/site-packages/mamba_ssm/modules/mamba3.py')
text = path.read_text()
patches = [
    (
        'from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn',
        'try:\n    from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn\nexcept Exception:\n    mamba3_step_fn = None',
    ),
    (
        '        # in_proj\n        zxBCdt = self.in_proj(u)',
        '        if mamba3_step_fn is None:\n            raise RuntimeError("Mamba3 step() requires optional CUTLASS/CuTe dependencies")\n\n        # in_proj\n        zxBCdt = self.in_proj(u)',
    ),
]
for old, new in patches:
    if old not in text:
        print('[dockerfile] WARNING: mamba3.py patch anchor not found: ' + repr(old), file=sys.stderr)
    text = text.replace(old, new)
path.write_text(text)
PY

# Triton 3.2 is required for A10 driver discovery, but the upstream Mamba3 SISO
# forward uses tl.make_tensor_descriptor (a Triton 3.5 API). Only the combined
# SISO wrapper is overwritten, with the Torch/autograd implementation copied to
# /tmp earlier, so the public mamba3_siso_combined API used by Mamba3.forward
# stays stable. The replacement is byte-compiled to catch syntax errors at
# build time.
RUN DEST=/opt/conda/lib/python3.11/site-packages/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py && \
    cp /tmp/mamba3_siso_combined_torch_fallback.py "$DEST" && \
    python -m py_compile "$DEST"

# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3.
# The shipped __init__.py eagerly imports selective_scan_cuda.so, which has a
# libtorch C++ ABI mismatch on this base image
# ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
# Mamba3 (grafted from the pinned release commit, not from main), all
# compiled-CUDA imports are skipped.
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py

# Structural check only: confirm the grafted files exist and the replacement
# __init__.py is non-empty. Triton is deliberately not initialised here
# because the builder has no GPU.
RUN PKG=/opt/conda/lib/python3.11/site-packages/mamba_ssm \
 && [ -f "$PKG/modules/mamba3.py" ] \
 && [ -f "$PKG/ops/triton/mamba3/mamba3_siso_combined.py" ] \
 && [ -s "$PKG/__init__.py" ] \
 && echo "mamba3 graft + __init__ override verified"

# Optional tilelang for the MIMO path; SISO Mamba3 works without it, so a
# failed install is logged and the build continues (deliberate best-effort).
# The log message is plain ASCII so it renders correctly in any build log.
RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed - continuing"

# Pin Triton to the release that matches torch 2.6.0. A10 diagnostics showed
# Triton 3.5.1 reporting 0 active drivers, while torch 2.6 + Triton 3.2.0 sees
# the A10G. --force-reinstall --no-deps replaces whatever Triton version the
# earlier installs left behind without touching any other package; the
# installed version is then printed into the build log.
RUN <<'EOF'
set -e
pip install --force-reinstall --no-deps 'triton==3.2.0'
python -c "import triton; print(f'triton={triton.__version__} torch2.6-compatible')"
EOF

# Project sources: the overlay directory becomes /workspace/feather and the
# entrypoint script goes to /app. The remaining build steps run from the
# project root.
WORKDIR /workspace
COPY overlay ./feather/
COPY entrypoint.py /app/entrypoint.py
WORKDIR /workspace/feather

# Build-time syntax checks: byte-compile the training Python sources and
# parse (without executing) the pretraining shell script.
RUN python -m py_compile \
      hydra/training.py \
      prepare.py \
      train.py \
 && bash -n scripts/run_domain_expanded_pretrain.sh

# Build the htm_rust extension with its `gpu` feature and install the wheel.
#
# HTM_CUDA_ARCH selects the target GPU architecture (default sm_86); override
# with `--build-arg HTM_CUDA_ARCH=...`. It is exported for the build,
# presumably for the crate's build script - confirm against htm_rust.
#
# The CUDA library directory is prepended to LD_LIBRARY_PATH. The
# `${VAR:+:...}` form adds the ':' separator only when LD_LIBRARY_PATH is
# already set; a bare trailing ':' would create an empty entry, which the
# dynamic loader treats as "current directory".
ARG HTM_CUDA_ARCH=sm_86
RUN export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" && \
    export HTM_CUDA_ARCH="${HTM_CUDA_ARCH}" && \
    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
    pip install htm_rust/target/wheels/htm_rust-*.whl

# Default command: run the entrypoint script copied to /app. Exec (JSON-array)
# form is used, so Python is started directly rather than through a shell.
CMD ["python", "/app/entrypoint.py"]