Spaces:
Runtime error
Runtime error
File size: 7,622 Bytes
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
# Build-only: keeps apt-get non-interactive during the build without baking
# DEBIAN_FRONTEND into the environment of the running container
# (ARG values are visible to RUN steps but are not persisted like ENV).
ARG DEBIAN_FRONTEND=noninteractive
# Environment shared by the build and the runtime container:
#   PIP_NO_CACHE_DIR=1   - pip does not keep its download/wheel cache in layers
#   PYTHONUNBUFFERED=1   - Python stdout/stderr are flushed immediately
#   CARGO_HOME/RUSTUP_HOME - where rustup installs the Rust toolchain
#   PATH                 - puts cargo/rustc on PATH for later RUN steps
ENV PIP_NO_CACHE_DIR=1 \
    PYTHONUNBUFFERED=1 \
    CARGO_HOME=/root/.cargo \
    RUSTUP_HOME=/root/.rustup \
    PATH=/root/.cargo/bin:${PATH}
# System packages: source-fetching tools (git, curl, CA certificates) plus a
# native build toolchain (build-essential, pkg-config, libssl-dev). The apt
# package lists are deleted in the same layer so they never reach the image.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        curl \
        git \
        libssl-dev \
        pkg-config \
 && rm -rf /var/lib/apt/lists/*
# Install the stable Rust toolchain (minimal profile) into CARGO_HOME/RUSTUP_HOME;
# maturin needs it to build htm_rust further down.
# The installer is saved to a file instead of being piped into bash: RUN uses
# /bin/sh without pipefail, so with `curl | bash` a failed download feeds an
# empty script to bash, which exits 0 and leaves Rust silently uninstalled.
# The final version check fails the build right here if the toolchain is missing.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs -o /tmp/rustup-init.sh && \
    bash /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \
    rm -f /tmp/rustup-init.sh && \
    rustc --version && \
    cargo --version
# Python tooling and libraries used by the training code (unpinned; the base
# image already provides torch). pip/setuptools/wheel are upgraded first, in
# the same layer, so the second install uses the fresh packaging toolchain.
RUN pip install --upgrade pip setuptools wheel \
 && pip install \
        cuda-python \
        datasets \
        einops \
        huggingface_hub \
        maturin \
        ninja \
        packaging \
        pandas \
        pyarrow \
        pydantic \
        requests \
        rustbpe \
        tiktoken
# Mamba-3 fused CUDA kernel stack (mandatory - NO fallback allowed).
#
# Prebuilt manylinux wheels are installed from the official
# Dao-AILab/causal-conv1d and state-spaces/mamba GitHub releases rather than
# compiling mamba_ssm from source: a source build on the HF Spaces cpu-basic
# builder (~16 GB RAM) is OOM-killed even with MAX_JOBS=1, since nvcc needs
# roughly 8-12 GB per translation unit for the templated selective-scan /
# chunk-scan kernels.
#
# Wheel tags chosen for the base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
#   cp311         - Python 3.11, as shipped in the image
#   cu12          - CUDA 12.x builds (the image has CUDA 12.4)
#   torch2.6      - built against PyTorch 2.6
#   cxx11abiFALSE - the pre-C++11 libstdc++ ABI of the standard PyTorch pip build
# NOTE(review): a later step reports an "undefined symbol" libtorch ABI error
# from selective_scan_cuda.so on this image; it may be worth checking
# torch.compiled_with_cxx11_abi() here against the cxx11abiFALSE tag.
#
# Versions: causal_conv1d 1.6.1 (release tag v1.6.1.post4) and mamba_ssm 2.3.1.
# Both wheels already contain their compiled CUDA code, so nothing is compiled
# on the Space builder. The Mamba3 module itself is grafted from a pinned
# upstream commit in Step B.
#
# Step A: install both wheels (causal_conv1d first), then print the installed
# versions read from package metadata. The check does not import the compiled
# extension modules.
RUN CAUSAL_CONV1D_WHL='https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    MAMBA_SSM_WHL='https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    pip install "$CAUSAL_CONV1D_WHL" "$MAMBA_SSM_WHL" && \
    python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
#
# Step B: graft the Mamba3 module and its pure-Python/Triton helper tree from
# the pinned Mamba-3 release commit (5235bdc...) on top of the installed wheel.
# Do NOT graft from main: main now requires Triton APIs such as
# tl.make_tensor_descriptor that force Triton 3.5.x, and Triton 3.5.x fails
# driver discovery on HF A10 Jobs with "0 active drivers". The release commit
# works with the Triton 3.2 runtime that matches torch 2.6.
#
# Fetching individual files avoids the source-build OOM on the cpu-basic HF
# Space builder. Every download must succeed: inside each for-loop, curl is
# followed by `|| exit 1`. A `for` loop's exit status is that of its last
# command only, so without it a failed download of any earlier file would be
# silently ignored and could surface later as an ImportError.
#
# The Torch fallback for the SISO combined kernel is staged in /tmp here; a
# later RUN step copies it over ops/triton/mamba3/mamba3_siso_combined.py.
COPY mamba3_siso_combined_torch_fallback.py /tmp/mamba3_siso_combined_torch_fallback.py
# After downloading, the RUN below makes sure __init__.py exists in the new
# package directories (touch) and patches modules/mamba3.py with inline Python:
#   - the mamba3_step_fn (CuTe) import is wrapped in try/except and set to
#     None when it cannot be imported;
#   - a guard that raises RuntimeError when mamba3_step_fn is None is inserted
#     before the "# in_proj / zxBCdt = self.in_proj(u)" lines.
# NOTE(review): str.replace() is a silent no-op when its anchor text is absent,
# so any upstream formatting drift (including the indentation inside the
# anchors) leaves mamba3.py unpatched without failing the build. It also
# patches EVERY occurrence: confirm the in_proj anchor appears only in
# Mamba3.step and not in forward(), or forward() would raise whenever the
# CuTe dependencies are unavailable.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE=https://raw.githubusercontent.com/state-spaces/mamba/5235bdcd3fca41e336f17322acbfe8d8abb6c93f && \
    curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
    mkdir -p "$SITE/ops/triton/mamba3" "$SITE/ops/tilelang/mamba3" "$SITE/ops/cute/mamba3" && \
    for f in angle_cumsum.py k_activations.py layer_norm.py layernorm_gated.py selective_state_update.py softplus.py ssd_bmm.py ssd_chunk_scan.py ssd_chunk_state.py ssd_combined.py ssd_state_passing.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/triton/$f" -o "$SITE/ops/triton/$f" || exit 1; \
    done && \
    for f in angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f" || exit 1; \
    done && \
    for f in mamba3_mimo.py mamba3_mimo_bwd.py mamba3_mimo_fwd.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/tilelang/mamba3/$f" -o "$SITE/ops/tilelang/mamba3/$f" || exit 1; \
    done && \
    curl -fsSL "$BASE/mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py" -o "$SITE/ops/cute/mamba3/mamba3_step_fn.py" && \
    touch "$SITE/ops/triton/mamba3/__init__.py" "$SITE/ops/tilelang/__init__.py" \
          "$SITE/ops/tilelang/mamba3/__init__.py" "$SITE/ops/cute/__init__.py" \
          "$SITE/ops/cute/mamba3/__init__.py" && \
    python - <<'PY'
from pathlib import Path
path = Path('/opt/conda/lib/python3.11/site-packages/mamba_ssm/modules/mamba3.py')
text = path.read_text()
text = text.replace(
'from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn',
'try:\n from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn\nexcept Exception:\n mamba3_step_fn = None',
)
text = text.replace(
' # in_proj\n zxBCdt = self.in_proj(u)',
' if mamba3_step_fn is None:\n raise RuntimeError("Mamba3 step() requires optional CUTLASS/CuTe dependencies")\n\n # in_proj\n zxBCdt = self.in_proj(u)',
)
path.write_text(text)
PY
# Triton 3.2 is required for A10 driver discovery, but the upstream Mamba3
# SISO forward path uses tl.make_tensor_descriptor (a Triton 3.5 API). Only
# the combined SISO wrapper is swapped out, for the Torch/autograd
# implementation staged in /tmp earlier; the public mamba3_siso_combined API
# used by Mamba3.forward is meant to stay unchanged. py_compile fails the
# build early if the copied file does not parse.
RUN DEST=/opt/conda/lib/python3.11/site-packages/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py \
 && cp /tmp/mamba3_siso_combined_torch_fallback.py "$DEST" \
 && python -m py_compile "$DEST"
# Overwrite mamba_ssm/__init__.py with a minimal module intended to import
# only Mamba3. The __init__.py shipped in the wheel eagerly imports the
# compiled selective_scan_cuda extension, which fails on this base image with
# a libtorch C++ ABI error ("undefined symbol: _ZN3c107WarningC1E...").
# Training only needs the Mamba3 module grafted from the pinned release commit
# (not from main), so all compiled-CUDA imports are skipped.
COPY ["mamba_ssm_init.py", "/opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py"]
# Structural check of the grafted package. Nothing here imports triton or
# mamba_ssm, because the builder has no GPU for triton to initialize against.
# Instead the key files must exist, the __init__.py override must be
# non-empty, and both the patched modules/mamba3.py and the __init__.py
# override must parse. py_compile only parses the files and does not import
# them, so a syntax or indentation error introduced by the heredoc patch fails
# the build here rather than at container start.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    test -f "$SITE/modules/mamba3.py" && \
    test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
    test -s "$SITE/__init__.py" && \
    python -m py_compile "$SITE/__init__.py" "$SITE/modules/mamba3.py" && \
    echo "mamba3 graft + __init__ override verified"
# Optional tilelang for the MIMO path; SISO Mamba3 works without it, so a
# failed install is logged and the build continues (deliberate best-effort).
# The install runs under a pip constraint pinning torch to the version already
# in the image. An unconstrained install may upgrade torch if tilelang (or one
# of its dependencies) asks for a newer one, and that would break the torch2.6
# mamba/causal-conv1d wheels and the Triton 3.2 pin. NOTE(review): confirm
# tilelang's torch requirement for the version pip resolves. If tilelang cannot
# be satisfied with this torch, pip fails and the step is skipped as before.
RUN python -c "import importlib.metadata as m; print('torch==' + m.version('torch'))" > /tmp/torch-pin.txt && \
    { pip install -c /tmp/torch-pin.txt tilelang \
      || echo "[dockerfile] tilelang optional install failed β continuing"; } && \
    rm -f /tmp/torch-pin.txt
# Pin Triton to the release paired with torch 2.6.0. A10 diagnostics showed
# Triton 3.5.1 reporting 0 active drivers, while torch 2.6 + Triton 3.2.0
# detects the A10G. --no-deps keeps pip away from every other package; the
# one-liner afterwards prints the resulting Triton version to the build log.
RUN pip install --no-deps --force-reinstall 'triton==3.2.0' \
 && python -c "import triton; print(f'triton={triton.__version__} torch2.6-compatible')"
# Application code: the overlay tree becomes /workspace/feather (the working
# directory from here on) and the Space entrypoint is placed in /app. The
# py_compile and `bash -n` checks fail the build on syntax errors in the main
# training scripts before the Rust extension is built.
# NOTE(review): htm_rust/ arrives with the overlay, so any change to the
# overlay also invalidates the cached htm_rust build layer below.
WORKDIR /workspace
COPY overlay ./feather
COPY entrypoint.py /app/
WORKDIR /workspace/feather
RUN python -m py_compile hydra/training.py prepare.py train.py \
 && bash -n scripts/run_domain_expanded_pretrain.sh
# Build and install the htm_rust extension (Rust, via maturin) with its "gpu"
# feature. HTM_CUDA_ARCH defaults to sm_86 (compute capability 8.6, e.g.
# A10/A10G) and can be overridden with --build-arg; presumably it is read by
# htm_rust's build script - confirm in htm_rust/build.rs. It is declared this
# late so that changing it only invalidates this layer.
ARG HTM_CUDA_ARCH=sm_86
# - LD_LIBRARY_PATH is extended without leaving a trailing empty entry (which
#   would add the current directory to the search path) when it is unset.
# - The wheel is written to a fresh directory (--out) rather than
#   htm_rust/target/wheels. That directory sits inside the copied overlay, so a
#   stale wheel from a local build would also match the glob, and pip would be
#   handed two htm_rust wheels at once. The temporary directory is removed in
#   the same layer.
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} && \
    export HTM_CUDA_ARCH=${HTM_CUDA_ARCH} && \
    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml --out /tmp/htm_rust_wheels && \
    pip install /tmp/htm_rust_wheels/htm_rust-*.whl && \
    rm -rf /tmp/htm_rust_wheels
# Default command (exec form): run the Space entrypoint copied to /app.
CMD ["python", "/app/entrypoint.py"]
|