Spaces:
Runtime error
Runtime error
# syntax=docker/dockerfile:1
# The directive above must be the very first line of the Dockerfile. It selects
# a Dockerfile frontend with heredoc support, which the RUN that patches
# mamba3.py relies on (`python - <<'PY'`).
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
# Build-time only: ARG values are exported to RUN steps (so apt-get stays
# non-interactive) but, unlike ENV, do not leak into the running container.
ARG DEBIAN_FRONTEND=noninteractive
# Runtime/build environment: no pip cache in layers, unbuffered Python output,
# and a Rust toolchain under /root (installed later) on PATH for maturin.
ENV PIP_NO_CACHE_DIR=1 \
    PYTHONUNBUFFERED=1 \
    CARGO_HOME=/root/.cargo \
    RUSTUP_HOME=/root/.rustup \
    PATH=/root/.cargo/bin:${PATH}
# OS prerequisites: git/curl/CA certificates for fetching sources, plus a C
# toolchain, pkg-config and OpenSSL headers for the native builds later on.
# The apt lists are removed in the same layer so they never reach the image.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        curl \
        git \
        libssl-dev \
        pkg-config \
 && rm -rf /var/lib/apt/lists/*
# Rust toolchain (minimal profile, stable) used by `maturin build` for htm_rust.
# The installer is downloaded to a file instead of piped into a shell. With
# `curl … | bash`, the pipeline reports bash's exit status, and bash exits 0
# on empty stdin, so a failed download would pass here. The build would then
# break much later with "maturin: command not found".
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs -o /tmp/rustup-init.sh && \
    sh /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \
    rm -f /tmp/rustup-init.sh && \
    rustc --version && \
    cargo --version
# Python build tooling plus the data/tokenizer/training dependencies.
# cuda-python is held to the 12.x series: its major version tracks the CUDA
# major version, and this image ships CUDA 12.4. An unpinned install now
# resolves to the 13.x (CUDA 13) bindings.
RUN pip install --upgrade pip setuptools wheel && \
    pip install \
        'cuda-python>=12,<13' \
        datasets \
        einops \
        huggingface_hub \
        maturin \
        ninja \
        packaging \
        pandas \
        pyarrow \
        pydantic \
        requests \
        rustbpe \
        tiktoken
# Mamba-3 fused CUDA kernel stack.
#
# Mamba3 is required: there is no pure-PyTorch replacement for the model. The
# one substitution made in this file is the SISO combined wrapper, which is
# swapped for a torch/autograd implementation further down.
#
# Prebuilt manylinux wheels come from the official state-spaces/mamba and
# Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source on
# the HF Spaces cpu-basic builder (~16 GB RAM) gets OOM-killed even with
# MAX_JOBS=1, because nvcc needs roughly 8-12 GB per translation unit for the
# templated selective-scan/chunk-scan kernels.
#
# Wheel tags selected for pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
#   cp311          - Python 3.11, the interpreter in this base image
#   cu12           - CUDA 12.x build, compatible with the image's CUDA 12.4
#   torch2.6       - built against the exact torch release in the image
#   cxx11abiFALSE  - pre-C++11 ABI, as in the standard PyTorch pip build
# NOTE(review): the selective_scan_cuda "undefined symbol" ABI error noted
# later in this file suggests checking
# `python -c 'import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)'`.
# If it prints True, the cxx11abiTRUE wheels are the matching ones.
#
# Step A: install mamba_ssm 2.3.1 and causal_conv1d 1.6.1.post4 (same ABI
# tags). They are already CUDA-compiled, so no CUDA build toolchain is needed
# on the builder. The Mamba3 module itself is grafted separately in Step B.
RUN CAUSAL_CONV1D_WHEEL='https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    MAMBA_SSM_WHEEL='https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    pip install "$CAUSAL_CONV1D_WHEEL" "$MAMBA_SSM_WHEEL" && \
    python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
# Step B: graft the Mamba3 module and its pure-Python/Triton helper tree from
# the pinned Mamba-3 release commit (5235bdcd…). Do NOT graft from main:
# - main requires Triton APIs such as tl.make_tensor_descriptor that force
#   Triton 3.5.x;
# - Triton 3.5.x fails driver discovery on HF A10 Jobs with "0 active drivers";
# - the pinned commit works with the Triton 3.2 runtime that matches torch 2.6.
#
# Fetching raw sources on top of the prebuilt wheel avoids the source-build
# OOM on the cpu-basic HF Space builder.
#
# Every download is checked individually. A plain `for …; do curl …; done`
# only reports the status of its LAST curl, so a failed fetch earlier in the
# list would otherwise be silently ignored.
#
# The heredoc then patches mamba3.py in two ways:
# - the CuTe step-kernel import becomes optional (mamba3_step_fn = None when
#   it cannot be imported);
# - the `# in_proj` / `zxBCdt = self.in_proj(u)` site raises a clear
#   RuntimeError when that kernel is unavailable.
# Matching ignores indentation, and the build fails if either pattern is not
# found, rather than shipping an unpatched mamba3.py.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE=https://raw.githubusercontent.com/state-spaces/mamba/5235bdcd3fca41e336f17322acbfe8d8abb6c93f && \
    fetch() { curl -fsSL "$BASE/mamba_ssm/$1" -o "$SITE/$1"; } && \
    mkdir -p "$SITE/ops/triton/mamba3" "$SITE/ops/tilelang/mamba3" "$SITE/ops/cute/mamba3" && \
    fetch modules/mamba3.py && \
    for f in angle_cumsum.py k_activations.py layer_norm.py layernorm_gated.py selective_state_update.py softplus.py ssd_bmm.py ssd_chunk_scan.py ssd_chunk_state.py ssd_combined.py ssd_state_passing.py; do \
        fetch "ops/triton/$f" || exit 1; \
    done && \
    for f in angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        fetch "ops/triton/mamba3/$f" || exit 1; \
    done && \
    for f in mamba3_mimo.py mamba3_mimo_bwd.py mamba3_mimo_fwd.py; do \
        fetch "ops/tilelang/mamba3/$f" || exit 1; \
    done && \
    fetch ops/cute/mamba3/mamba3_step_fn.py && \
    touch "$SITE/ops/triton/mamba3/__init__.py" "$SITE/ops/tilelang/__init__.py" \
          "$SITE/ops/tilelang/mamba3/__init__.py" "$SITE/ops/cute/__init__.py" \
          "$SITE/ops/cute/mamba3/__init__.py" && \
    python - <<'PY'
import re
from pathlib import Path
path = Path('/opt/conda/lib/python3.11/site-packages/mamba_ssm/modules/mamba3.py')
text = path.read_text()
def optional_import(m):
    ind = m.group(1)
    return (f'{ind}try:\n'
            f'{ind}    from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn\n'
            f'{ind}except Exception:\n'
            f'{ind}    mamba3_step_fn = None')
def step_guard(m):
    ind, body_ind = m.group(1), m.group(2)
    return (f'{ind}if mamba3_step_fn is None:\n'
            f'{ind}    raise RuntimeError("Mamba3 step() requires optional CUTLASS/CuTe dependencies")\n'
            '\n'
            f'{ind}# in_proj\n'
            f'{body_ind}zxBCdt = self.in_proj(u)')
text, n_import = re.subn(r'^([ \t]*)from mamba_ssm\.ops\.cute\.mamba3\.mamba3_step_fn import mamba3_step_fn[ \t]*$', optional_import, text, flags=re.M)
text, n_guard = re.subn(r'^([ \t]*)# in_proj\n([ \t]*)zxBCdt = self\.in_proj\(u\)', step_guard, text, flags=re.M)
if not (n_import and n_guard):
    raise SystemExit(f'mamba3.py patch failed: import sites={n_import}, in_proj sites={n_guard}')
path.write_text(text)
print(f'patched mamba3.py: import sites={n_import}, in_proj sites={n_guard}')
PY
# Triton 3.2 is required for A10 driver discovery, but the upstream Mamba3
# SISO forward uses tl.make_tensor_descriptor (a Triton 3.5 API). Overwrite
# only the combined SISO wrapper with a CUDA torch/autograd implementation,
# keeping the public mamba3_siso_combined API that Mamba3.forward uses.
# The file is copied straight into place AFTER the graft RUN. Editing the
# fallback therefore no longer invalidates the graft layer's cache and
# re-triggers all of its downloads.
COPY mamba3_siso_combined_torch_fallback.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py
RUN python -m py_compile /opt/conda/lib/python3.11/site-packages/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py
# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3.
# The wheel's own __init__.py eagerly imports selective_scan_cuda.so, which
# has a libtorch C++ ABI mismatch on this base image
# ("undefined symbol: _ZN3c107WarningC1E..."). Training only needs the Mamba3
# module grafted from the pinned upstream commit (not main), so all
# compiled-CUDA imports are skipped.
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
# Structural check only. Importing mamba_ssm would initialise Triton, and the
# builder has no GPU. importlib.util.find_spec locates a top-level package
# without executing its __init__.py. This confirms the interpreter resolves
# mamba_ssm to the hard-coded site-packages path that the graft and the COPY
# above write into.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    test -f "$SITE/modules/mamba3.py" && \
    test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
    test -s "$SITE/__init__.py" && \
    python -c "import importlib.util, os, sys; spec = importlib.util.find_spec('mamba_ssm'); got = spec.origin if spec else None; want = os.path.join(sys.argv[1], '__init__.py'); sys.exit(0 if got and os.path.realpath(got) == os.path.realpath(want) else f'mamba_ssm resolves to {got!r}, expected {want!r}')" "$SITE" && \
    echo "mamba3 graft + __init__ override verified"
# Optional tilelang for the Mamba3 MIMO path; SISO Mamba3 works without it.
# The currently installed torch version is passed to pip as a constraint, so
# tilelang's dependencies cannot silently replace torch 2.6. Both the
# torch2.6-ABI mamba_ssm/causal_conv1d wheels and the Triton 3.2 pin below
# depend on it. If tilelang needs a different torch, its install fails and the
# build continues without it, exactly like any other tilelang install failure.
RUN python -c "import torch; print('torch==' + torch.__version__.split('+')[0])" > /tmp/torch-constraint.txt && \
    { pip install --constraint /tmp/torch-constraint.txt tilelang || echo "[dockerfile] tilelang optional install failed β continuing"; } && \
    rm -f /tmp/torch-constraint.txt
# Keep Triton matched to torch 2.6.0. A10 diagnostics showed Triton 3.5.1
# reports 0 active drivers while torch 2.6 + Triton 3.2.0 sees the A10G.
# --no-deps so this reinstall cannot pull a different torch either.
RUN pip install --force-reinstall --no-deps 'triton==3.2.0' && \
    python -c "import triton; print(f'triton={triton.__version__} torch2.6-compatible')"
# Runtime entrypoint, then the project overlay. Both destinations are absolute,
# so no intermediate WORKDIR is needed before the copies.
COPY entrypoint.py /app/entrypoint.py
COPY overlay /workspace/feather
WORKDIR /workspace/feather
# Fail the build early on syntax errors. The step only byte-compiles the
# Python entry scripts and parses the launch script with `bash -n`; nothing
# is executed.
RUN python -m py_compile hydra/training.py prepare.py train.py \
 && bash -n scripts/run_domain_expanded_pretrain.sh
# CUDA architecture for the htm_rust GPU build. sm_86 (compute capability 8.6)
# matches the A10/A10G GPUs this image targets; override with
# --build-arg HTM_CUDA_ARCH=… for other GPUs.
ARG HTM_CUDA_ARCH=sm_86
# Build the htm_rust extension (release, `gpu` feature) and install it.
# - The wheel is written to a fresh directory via --out. Globbing
#   htm_rust/target/wheels/ could also pick up stale or foreign-platform
#   wheels that arrived with the copied overlay, and with several matches pip
#   errors out or installs the wrong one.
# - The staging directory is removed in the same layer.
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
    export HTM_CUDA_ARCH=${HTM_CUDA_ARCH} && \
    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml --out /tmp/htm_wheels && \
    pip install /tmp/htm_wheels/htm_rust-*.whl && \
    rm -rf /tmp/htm_wheels
# Exec form: python runs as PID 1 and receives SIGTERM from `docker stop`.
CMD ["python", "/app/entrypoint.py"]