# syntax=docker/dockerfile:1
# HF Space training image (A10G): PyTorch 2.6 / CUDA 12.4 devel base + Mamba-3 + htm_rust.
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel

# GPU target: A10G (Ampere, compute capability 8.6). Override at build time
# for other GPUs, e.g. --build-arg HTM_CUDA_ARCH=sm_80 TORCH_CUDA_ARCH_LIST=8.0.
ARG HTM_CUDA_ARCH=sm_86
ARG TORCH_CUDA_ARCH_LIST=8.6

# Runtime/build environment. DEBIAN_FRONTEND is deliberately NOT exported here:
# it is applied inline on the apt-get layer below so it does not leak into the
# running container's environment.
ENV PIP_NO_CACHE_DIR=1 \
    PYTHONUNBUFFERED=1 \
    CARGO_HOME=/root/.cargo \
    RUSTUP_HOME=/root/.rustup \
    HTM_CUDA_ARCH=${HTM_CUDA_ARCH} \
    TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} \
    PATH=/root/.cargo/bin:${PATH}

# OS toolchain for the Rust/maturin build. update+install in one layer (no
# stale apt cache), recommends skipped, package lists removed in the same
# layer, package names sorted for diffability.
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        curl \
        git \
        libssl-dev \
        pkg-config \
    && rm -rf /var/lib/apt/lists/*
# Install Rust via rustup. Download the installer to a file instead of
# `curl | bash`: the default /bin/sh (dash) has no pipefail, so a failed
# download in a pipe would be masked and silently yield an image with no
# Rust toolchain.
RUN curl -fsSL -o /tmp/rustup-init.sh https://sh.rustup.rs && \
    bash /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \
    rm /tmp/rustup-init.sh
# Python tooling + data/training deps. PIP_NO_CACHE_DIR=1 (set in ENV above)
# keeps this layer slim. All specifiers are resolved by a single pip
# invocation, so the alphabetical listing below (for diffability) does not
# change the resolved set.
RUN pip install --upgrade pip setuptools wheel && \
    pip install \
        datasets \
        einops \
        huggingface_hub \
        maturin \
        ninja \
        packaging \
        pandas \
        pyarrow \
        pydantic \
        requests \
        rustbpe \
        tiktoken
# Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
#
# We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
# on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
# nvcc on the templated selective-scan/chunk-scan kernels needs 8-12GB per TU.
#
# Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
# - Python 3.11 (cp311) — matches PyTorch 2.6.0 image
# - CUDA 12.x wheels (cu12) — matches host CUDA 12.4
# - PyTorch 2.6 ABI (torch2.6) — exact torch match
# - cxx11abiFALSE — standard PyTorch pip build
#
# Versions: mamba_ssm 2.3.1 (first stable with Mamba3 class) + causal_conv1d
# 1.6.1.post4 (matching ABI). Both are CUDA-compiled, no build toolchain needed
# on the Space builder.
#
# Step A: install the published v2.3.1 prebuilt wheel (compiled CUDA ops
# for selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc).
# Pin by exact release-asset URL so the layer is reproducible and pip can never
# fall back to a source build; the trailing python -c proves both dists are
# actually registered in the environment.
RUN pip install \
    'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
    'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"
#
# Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
# main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
# files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
# zero compiled-CUDA dependencies (verified: every import in that subtree is
# triton/torch/python — no .so files, no nvcc). So we install the v2.3.1 wheel
# (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
#
# This avoids the source-build OOM on the cpu-basic HF Space builder and the
# missing-file error the smoke hit on the last attempt.
# Download grafted mamba3 module + triton ops subtree
# Graft the Mamba3 module + triton ops subtree from mamba main onto the
# installed v2.3.1 package. Each loop download is individually checked with
# `|| exit 1`: in plain sh only the LAST iteration's exit status survives the
# for-loop, so without it a failed curl mid-loop would leave a silently
# incomplete graft and still pass the build.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
    curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
    mkdir -p "$SITE/ops/triton/mamba3" && \
    for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f" || exit 1; \
    done
# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
# (pure-Triton, works). The shipped __init__.py eagerly imports
# selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
# image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
# Mamba3 (grafted from main), we skip all compiled-CUDA imports.
# NOTE(review): mamba_ssm_init.py must live next to this Dockerfile in the
# build context — confirm it is not excluded by .dockerignore.
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
# Structural check (no triton init — triton has no GPU on the builder)
# Cheap filesystem-level verification of the graft. Deliberately avoids
# importing mamba_ssm (triton would try to initialize on a GPU-less builder).
# Any missing or empty file fails this layer, and therefore the build.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    for p in modules/mamba3.py ops/triton/mamba3/mamba3_siso_combined.py; do \
        test -f "$SITE/$p" || exit 1; \
    done && \
    test -s "$SITE/__init__.py" && \
    echo "mamba3 graft + __init__ override verified"
# Optional tilelang for MIMO path — pure-python, cheap; SISO Mamba3 works without.
# Best-effort install: a tilelang failure only disables the MIMO path, so log
# and continue. (Fixes the mojibaked dash in the original log message.)
RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed - continuing"
# Triton version decision: FORCE 3.5.1 — the only version with both mamba3
# APIs (set_allocator + tl.make_tensor_descriptor). torch 2.6's _inductor
# imports AttrsDescriptor from triton.compiler.compiler which was removed in
# triton 3.4+, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
# before any torch._inductor import path runs, so the incompatibility is
# neutralized. Build-time assert verifies mamba3's two required APIs.
# Pin triton exactly; --no-deps keeps pip from touching the torch install
# while swapping triton out from under it. The python -c probe fails the
# build unless both mamba3-required APIs are present.
RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \
    python -c "import triton; from triton import language as tl; assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"
# App layout: training code under /workspace/feather (the final working dir),
# runner scripts and the Space entrypoint under /app.
WORKDIR /workspace
COPY overlay /workspace/feather
COPY overlay/scripts /app/scripts
COPY entrypoint.py /app/entrypoint.py
WORKDIR /workspace/feather
# Build-time smoke checks: byte-compile the Python entry modules and syntax-
# check the pretrain launcher so trivial errors fail the (cheap) image build
# instead of the (expensive) GPU job.
RUN test -f /app/scripts/htm_gpu_micro_canary.py && \
    python -m py_compile hydra/training.py prepare.py train.py /app/scripts/htm_gpu_micro_canary.py && \
    bash -n scripts/run_domain_expanded_pretrain.sh
# Build htm_rust with GPU kernels; if that fails (arch/nvcc mismatch), fall
# back to a CPU-only wheel so A10 compromise/fresh-eval jobs can still run
# with an explicit CPU fallback. The ${LD_LIBRARY_PATH:+:...} expansion avoids
# a trailing ':' when the variable is unset — an empty path entry would put
# the current directory on the library search path.
RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} && \
    echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \
    if maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml; then \
        pip install htm_rust/target/wheels/htm_rust-*.whl && \
        python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegionGpu'), 'htm_rust missing HTMRegionGpu GPU binding'"; \
    else \
        echo "[dockerfile] htm_rust GPU wheel build failed; building CPU wheel so A10 compromise/fresh-eval jobs can still run with explicit CPU fallback" && \
        rm -rf htm_rust/target/wheels && \
        maturin build --release --manifest-path htm_rust/Cargo.toml && \
        pip install htm_rust/target/wheels/htm_rust-*.whl && \
        python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegion'), 'htm_rust missing CPU HTMRegion binding'"; \
    fi
# Exec-form CMD: entrypoint.py runs as PID 1 and receives SIGTERM directly.
# (The stray '|' scrape-residue line that followed — invalid Dockerfile
# syntax — has been removed.)
# NOTE(review): the image runs as root (no USER directive); confirm whether
# the training job actually needs root before adding a non-root user.
CMD ["python", "/app/entrypoint.py"]