| FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel |
|
| # GPU architecture selection for the htm_rust CUDA kernels and torch |
| # extension builds; both default to compute capability 8.6. |
| ARG HTM_CUDA_ARCH=sm_86 |
| ARG TORCH_CUDA_ARCH_LIST=8.6 |
|
| # Non-interactive apt, no pip download cache, unbuffered Python output. |
| ENV DEBIAN_FRONTEND=noninteractive \ |
|     PIP_NO_CACHE_DIR=1 \ |
|     PYTHONUNBUFFERED=1 |
|
| # Rust toolchain locations used by the rustup install below; its bin |
| # directory is prepended to PATH so cargo/maturin can find rustc. |
| ENV CARGO_HOME=/root/.cargo \ |
|     RUSTUP_HOME=/root/.rustup |
| ENV PATH=${CARGO_HOME}/bin:${PATH} |
|
| # Expose the architecture build args as environment variables so later |
| # RUN steps and the running container see the same selection. |
| ENV HTM_CUDA_ARCH=${HTM_CUDA_ARCH} \ |
|     TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} |
|
|
| # OS-level build prerequisites for the native (Rust / Python extension) |
| # builds below. The apt lists are deleted in the same layer so they do not |
| # persist in the image. |
| RUN apt-get update && \ |
|     apt-get install -y --no-install-recommends \ |
|         build-essential \ |
|         ca-certificates \ |
|         curl \ |
|         git \ |
|         libssl-dev \ |
|         pkg-config && \ |
|     rm -rf /var/lib/apt/lists/* |
|
|
| # Install a minimal stable Rust toolchain (maturin needs it to build htm_rust). |
| # The installer is saved to a file instead of being piped into a shell: RUN |
| # uses /bin/sh without pipefail, so a failed download would only fail curl, |
| # the pipeline would still exit 0, and the build would continue without Rust. |
| # The version checks at the end make a broken install fail this layer. |
| RUN curl --proto '=https' --tlsv1.2 -fsSL https://sh.rustup.rs -o /tmp/rustup-init.sh && \ |
|     sh /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \ |
|     rm -f /tmp/rustup-init.sh && \ |
|     rustc --version && \ |
|     cargo --version |
|
|
| # Refresh the packaging toolchain, then install the Python build tools |
| # (maturin, ninja, packaging) and the project's Python dependencies. |
| # NOTE(review): these are unpinned, so versions are whatever resolves at |
| # build time; consider pinning for reproducible images. |
| RUN pip install --upgrade pip setuptools wheel && \ |
|     pip install \ |
|         datasets \ |
|         einops \ |
|         huggingface_hub \ |
|         maturin \ |
|         ninja \ |
|         packaging \ |
|         pandas \ |
|         pyarrow \ |
|         pydantic \ |
|         requests \ |
|         rustbpe \ |
|         tiktoken |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| # Install prebuilt causal-conv1d and mamba_ssm release wheels. The wheel tags |
| # must match the base image: CUDA 12 (cu12), torch 2.6 (torch2.6), the |
| # pre-cxx11 C++ ABI (cxx11abiFALSE) and CPython 3.11 on linux x86_64 (cp311). |
| # NOTE(review): presumably chosen to avoid compiling these CUDA extensions |
| # from source; confirm the image's torch reports |
| # torch._C._GLIBCXX_USE_CXX11_ABI == False, or the extensions will not load. |
| # The trailing python call prints both installed versions into the build log. |
| RUN pip install \ |
| 'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \ |
| 'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \ |
| python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))" |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| # Graft the Mamba-3 module and its Triton kernels from the upstream |
| # state-spaces/mamba repository into the installed mamba_ssm package |
| # (presumably the 2.3.1 wheel does not ship them). |
| # MAMBA3_REF picks the upstream git ref. It defaults to `main`, which can |
| # change at any time and may come to depend on mamba_ssm internals newer than |
| # the installed wheel; pass a commit SHA for a reproducible build. |
| # Each download is checked on its own: in a `for` loop a failed curl does not |
| # stop the loop, and the loop's exit status is only that of the last |
| # iteration, so without `|| exit 1` a missing kernel file would go unnoticed. |
| ARG MAMBA3_REF=main |
| RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \ |
|     BASE="https://raw.githubusercontent.com/state-spaces/mamba/${MAMBA3_REF}" && \ |
|     curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \ |
|     mkdir -p "$SITE/ops/triton/mamba3" && \ |
|     for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \ |
|         curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f" || exit 1; \ |
|     done |
|
|
| |
| |
| |
| |
| |
| # Overwrite the installed mamba_ssm package __init__.py with the project's |
| # mamba_ssm_init.py from the build context. |
| # NOTE(review): presumably this adapts the package imports to the grafted |
| # mamba3 files and the newer triton installed later; confirm against |
| # mamba_ssm_init.py. |
| COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py |
|
|
| |
| # Fail the build early unless the grafted mamba3 module and a key kernel file |
| # exist and the overridden __init__.py is non-empty. |
| RUN PKG_DIR=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \ |
|     [ -f "$PKG_DIR/modules/mamba3.py" ] && \ |
|     [ -f "$PKG_DIR/ops/triton/mamba3/mamba3_siso_combined.py" ] && \ |
|     [ -s "$PKG_DIR/__init__.py" ] && \ |
|     echo "mamba3 graft + __init__ override verified" |
|
|
| |
| # tilelang is optional: if it cannot be installed, the build continues. |
| # torch is pinned to the version already in the image so that resolving |
| # tilelang's dependencies can never swap it out. The mamba_ssm and |
| # causal-conv1d wheels above are built for torch 2.6. If tilelang needs a |
| # different torch, the install fails and the fallback message is printed. |
| RUN TORCH_VERSION="$(python -c 'import importlib.metadata as m; print(m.version("torch"))')" && \ |
|     pip install tilelang "torch==${TORCH_VERSION}" || \ |
|     echo "[dockerfile] tilelang optional install failed — continuing" |
|
|
| |
| |
| |
| |
| |
| |
| # Force triton 3.5.1 in place without touching any other package (--no-deps), |
| # then assert that triton.set_allocator and tl.make_tensor_descriptor exist. |
| # NOTE(review): torch 2.6 declares its own triton requirement, so pip will |
| # likely report a dependency conflict here; this appears intentional. |
| RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \ |
|     python -c "import triton, triton.language as tl; \ |
| assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \ |
| assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \ |
| print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')" |
|
|
| # Project sources go into /workspace/feather, which is also the working |
| # directory for the remaining steps. overlay/scripts and the entrypoint are |
| # additionally placed under /app. |
| WORKDIR /workspace/feather |
| COPY overlay . |
| COPY overlay/scripts /app/scripts |
| COPY entrypoint.py /app/entrypoint.py |
|
|
| # Cheap build-time sanity checks: the canary script is present, the main |
| # Python files byte-compile, and the pretrain launcher parses as bash. |
| RUN [ -f /app/scripts/htm_gpu_micro_canary.py ] && \ |
|     python -m py_compile hydra/training.py prepare.py train.py /app/scripts/htm_gpu_micro_canary.py && \ |
|     bash -n scripts/run_domain_expanded_pretrain.sh |
|
|
| # Build and install the htm_rust extension with maturin. The GPU build |
| # (--features gpu) is tried first and must expose HTMRegionGpu. If it fails, |
| # a CPU-only wheel is built instead and must expose HTMRegion. |
| # Any pre-existing wheels (e.g. a local htm_rust/target copied in with the |
| # overlay) are removed before the GPU build, so the htm_rust-*.whl glob matches |
| # only the freshly built wheel. Both builds also target the same `python` |
| # whose pip installs the result. |
| RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \ |
|     echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \ |
|     rm -rf htm_rust/target/wheels && \ |
|     if maturin build --release --features gpu --interpreter python --manifest-path htm_rust/Cargo.toml; then \ |
|         pip install htm_rust/target/wheels/htm_rust-*.whl && \ |
|         python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegionGpu'), 'htm_rust missing HTMRegionGpu GPU binding'"; \ |
|     else \ |
|         echo "[dockerfile] htm_rust GPU wheel build failed; building CPU wheel so A10 compromise/fresh-eval jobs can still run with explicit CPU fallback" && \ |
|         rm -rf htm_rust/target/wheels && \ |
|         maturin build --release --interpreter python --manifest-path htm_rust/Cargo.toml && \ |
|         pip install htm_rust/target/wheels/htm_rust-*.whl && \ |
|         python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegion'), 'htm_rust missing CPU HTMRegion binding'"; \ |
|     fi |
|
|
| # Default command: run the job entrypoint copied to /app/entrypoint.py. |
| # Exec (JSON-array) form, so no /bin/sh wrapper is involved. |
| CMD ["python", "/app/entrypoint.py"] |
|
|