FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel

ARG HTM_CUDA_ARCH=sm_86
ARG TORCH_CUDA_ARCH_LIST=8.6

ENV DEBIAN_FRONTEND=noninteractive \
    PIP_NO_CACHE_DIR=1 \
    PYTHONUNBUFFERED=1 \
    CARGO_HOME=/root/.cargo \
    RUSTUP_HOME=/root/.rustup \
    HTM_CUDA_ARCH=${HTM_CUDA_ARCH} \
    TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} \
    PATH=/root/.cargo/bin:${PATH}

RUN apt-get update && apt-get install -y --no-install-recommends \
    git curl ca-certificates build-essential pkg-config libssl-dev && \
    rm -rf /var/lib/apt/lists/*

RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --profile minimal --default-toolchain stable

RUN pip install --upgrade pip setuptools wheel && \
    pip install \
      maturin \
      huggingface_hub \
      datasets \
      requests \
      pyarrow \
      rustbpe \
      pandas \
      tiktoken \
      pydantic \
      ninja \
      packaging \
      einops

# Mamba-3 fused CUDA kernel stack (mandatory β€” NO fallback allowed).
#
# We install PRE-BUILT manylinux wheels from the official state-spaces/mamba
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
# on HF Spaces' cpu-basic builder (~16GB RAM) gets OOM-killed even with MAX_JOBS=1:
# nvcc needs 8–12GB per translation unit for the templated selective-scan/chunk-scan kernels.
#
# Wheel selection for base image pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel:
#   - Python 3.11 (cp311)                       β€” matches PyTorch 2.6.0 image
#   - CUDA 12.x wheels (cu12)                   β€” matches host CUDA 12.4
#   - PyTorch 2.6 ABI (torch2.6)                β€” exact torch match
#   - cxx11abiFALSE                             β€” standard PyTorch pip build
#
# Versions: mamba_ssm 2.3.1 (latest tagged release; the Mamba3 class itself is
# grafted from main in Step B below) + causal_conv1d 1.6.1.post4 (matching ABI).
# Both wheels ship pre-compiled CUDA ops, so no build toolchain is needed on the
# Space builder.
#
# Step A: install the published prebuilt wheels (compiled CUDA ops for
# selective_scan, layernorm_gated, ssd_*, causal_conv1d, etc.).
RUN pip install \
      'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
      'https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' && \
    python -c "import importlib.metadata as m; print('installed mamba_ssm=' + m.version('mamba_ssm') + ' causal_conv1d=' + m.version('causal_conv1d'))"

#
# Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
# main. v2.3.1 is the latest release but Mamba3 landed post-release; the new
# files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
# zero compiled-CUDA dependencies (verified: every import in that subtree is
# triton/torch/python β€” no .so files, no nvcc). So we install the v2.3.1 wheel
# (for its compiled ops) and overlay the main-branch Mamba3 sources on top.
#
# This avoids the source-build OOM on the cpu-basic HF Space builder and the
# missing-file error the smoke test hit on the previous attempt. A spot-check of
# the grafted subtree follows the download below.
#
# Download the grafted mamba3 module + its Triton ops subtree.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE=https://raw.githubusercontent.com/state-spaces/mamba/main && \
    curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
    mkdir -p "$SITE/ops/triton/mamba3" && \
    for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f"; \
    done
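
# Optional spot-check (a sketch; the grepped module names are assumptions based
# on the compiled ops named above): fail the build if anything under the grafted
# subtree references a compiled-CUDA extension module.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    ! grep -RlE 'selective_scan_cuda|causal_conv1d_cuda' "$SITE/ops/triton/mamba3" && \
    echo "mamba3 graft references no compiled-CUDA modules"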

# Replace mamba_ssm/__init__.py with a minimal one that only imports Mamba3
# (pure-Triton, works). The shipped __init__.py eagerly imports
# selective_scan_cuda.so which has a libtorch C++ ABI mismatch on this base
# image ("undefined symbol: _ZN3c107WarningC1E..."). Since training only needs
# Mamba3 (grafted from main), we skip all compiled-CUDA imports.
COPY mamba_ssm_init.py /opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py
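# The overlay file is not reproduced here; a rough sketch of what it is expected
# to contain (illustrative only, and assuming mamba_ssm/modules/__init__.py has
# no compiled imports of its own):
#
#   # Shim the AttrsDescriptor symbol that triton 3.4+ removed but that
#   # torch 2.6's _inductor still imports from triton.compiler.compiler.
#   import triton.compiler.compiler as _tcc
#   if not hasattr(_tcc, "AttrsDescriptor"):
#       class AttrsDescriptor:  # inert stub; Mamba3 itself never uses it
#           pass
#       _tcc.AttrsDescriptor = AttrsDescriptor
#
#   # Expose only the pure-Triton Mamba3 path; skip every compiled-CUDA import.
#   from mamba_ssm.modules.mamba3 import Mamba3
#   __all__ = ["Mamba3"]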

# Structural check only (no Triton import/initialization: the builder has no GPU)
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    test -f "$SITE/modules/mamba3.py" && \
    test -f "$SITE/ops/triton/mamba3/mamba3_siso_combined.py" && \
    test -s "$SITE/__init__.py" && \
    echo "mamba3 graft + __init__ override verified"

# Optional: tilelang for the MIMO path (pure-python, cheap); SISO Mamba3 works without it.
RUN pip install tilelang || echo "[dockerfile] tilelang optional install failed β€” continuing"

# Triton version decision: FORCE 3.5.1 β€” the only version with both mamba3
# APIs (set_allocator + tl.make_tensor_descriptor). torch 2.6's _inductor
# imports AttrsDescriptor from triton.compiler.compiler which was removed in
# triton 3.4+, but mamba_ssm/__init__.py shims AttrsDescriptor as a stub
# before any torch._inductor import path runs, so the incompatibility is
# neutralized. Build-time assert verifies mamba3's two required APIs.
RUN pip install --force-reinstall --no-deps 'triton==3.5.1' && \
    python -c "import triton; from triton import language as tl; \
               assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
               assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
               print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"

WORKDIR /workspace
COPY overlay /workspace/feather
COPY overlay/scripts /app/scripts
COPY entrypoint.py /app/entrypoint.py
WORKDIR /workspace/feather

RUN test -f /app/scripts/htm_gpu_micro_canary.py && \
    python -m py_compile hydra/training.py prepare.py train.py /app/scripts/htm_gpu_micro_canary.py && \
    bash -n scripts/run_domain_expanded_pretrain.sh

RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
    echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \
    if maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml; then \
        pip install htm_rust/target/wheels/htm_rust-*.whl && \
        python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegionGpu'), 'htm_rust missing HTMRegionGpu GPU binding'"; \
    else \
        echo "[dockerfile] htm_rust GPU wheel build failed; building CPU wheel so A10 compromise/fresh-eval jobs can still run with explicit CPU fallback" && \
        rm -rf htm_rust/target/wheels && \
        maturin build --release --manifest-path htm_rust/Cargo.toml && \
        pip install htm_rust/target/wheels/htm_rust-*.whl && \
        python -c "import htm_rust; assert hasattr(htm_rust, 'HTMRegion'), 'htm_rust missing CPU HTMRegion binding'"; \
    fi
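
# Downstream code can detect which binding the installed wheel ended up with; a
# minimal sketch (HTMRegionGpu/HTMRegion come from the asserts above, the
# selection logic itself is illustrative):
#
#   import htm_rust
#   if hasattr(htm_rust, "HTMRegionGpu"):
#       Region = htm_rust.HTMRegionGpu   # fused GPU kernels
#   else:
#       Region = htm_rust.HTMRegion      # explicit CPU fallback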

CMD ["python", "/app/entrypoint.py"]