File size: 6,408 Bytes
e317e25
0f7408a
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6618931
 
 
e317e25
b814e8c
e317e25
 
b814e8c
e317e25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel

# Default target is HF Jobs a10g-large (NVIDIA A10G, Ampere GA102, sm_86).
# For other cards override BOTH args together so the two arch settings agree,
# e.g. --build-arg FEATHER_GPU_ARCH=sm_90a --build-arg FEATHER_TORCH_CUDA_ARCH_LIST=9.0a
ARG FEATHER_GPU_ARCH=sm_86
ARG FEATHER_TORCH_CUDA_ARCH_LIST=8.6

# Build-time only: keeps apt-get non-interactive during RUN steps without
# persisting DEBIAN_FRONTEND into the runtime environment of the container
# (ARG values are exported to RUN commands but not stored in the image config).
ARG DEBIAN_FRONTEND=noninteractive

# HTM_CUDA_ARCH / TORCH_CUDA_ARCH_LIST are echoed by the htm_rust GPU build
# step further down (presumably consumed by its build script — confirm).
# PATH exposes the rustup-installed cargo/rustc under CARGO_HOME.
ENV PIP_NO_CACHE_DIR=1 \
    PYTHONUNBUFFERED=1 \
    CARGO_HOME=/root/.cargo \
    RUSTUP_HOME=/root/.rustup \
    HTM_CUDA_ARCH=${FEATHER_GPU_ARCH} \
    TORCH_CUDA_ARCH_LIST=${FEATHER_TORCH_CUDA_ARCH_LIST} \
    PATH=/root/.cargo/bin:${PATH}

# System prerequisites: git/curl/CA certs for fetching sources, plus the C
# toolchain, pkg-config and OpenSSL headers. The apt lists are removed in the
# same layer so they never land in the image.
RUN apt-get update \
 && apt-get install -y --no-install-recommends \
      build-essential \
      ca-certificates \
      curl \
      git \
      libssl-dev \
      pkg-config \
 && rm -rf /var/lib/apt/lists/*

# Rust toolchain (minimal profile, stable) for building htm_rust via maturin.
# The installer is saved to a file instead of being piped into a shell:
# `curl | bash` under the default /bin/sh has no pipefail, so a failed download
# just feeds an empty script to bash and the step "succeeds", deferring the
# failure to the much later maturin build. The version checks make a broken
# install fail here.
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs -o /tmp/rustup-init.sh && \
    sh /tmp/rustup-init.sh -y --profile minimal --default-toolchain stable && \
    rm -f /tmp/rustup-init.sh && \
    rustc --version && \
    cargo --version

# Python tooling. First refresh pip/setuptools/wheel, then install the
# remaining (currently unpinned) build and data dependencies in a single
# pip invocation.
RUN pip install --upgrade pip setuptools wheel \
 && pip install \
      datasets \
      einops \
      huggingface_hub \
      maturin \
      ninja \
      packaging \
      pandas \
      pyarrow \
      pydantic \
      requests \
      rustbpe \
      tiktoken

# Mamba-3 fused CUDA kernel stack (mandatory — NO fallback allowed).
#
# We install PRE-BUILT linux_x86_64 wheels from the official state-spaces/mamba
# and Dao-AILab/causal-conv1d GitHub releases. Compiling mamba_ssm from source
# on HF Spaces' cpu-basic builder (~16GB RAM) OOMKills even with MAX_JOBS=1 —
# nvcc needs 8–12GB per translation unit for the templated selective-scan/chunk-scan kernels.
#
# Wheel selection for base image pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel:
#   - Python 3.11 (cp311)                       — matches PyTorch 2.5.1 image
#   - CUDA 12.x wheels (cu12)                   — compatible with CUDA 12.1 base
#   - PyTorch 2.5 ABI (torch2.5)                — matches the installed torch
#   - cxx11abiFALSE                             — standard PyTorch pip build
#
# Versions: mamba_ssm 2.3.0 + causal_conv1d 1.6.0 (both built for the torch2.5 ABI).
# Both wheels ship precompiled CUDA extensions, so the Space builder needs no
# CUDA build toolchain.
#
# Step A: install the published prebuilt wheels (causal_conv1d v1.6.0 and
# mamba_ssm v2.3.0). They carry the compiled CUDA ops (selective_scan,
# layernorm_gated, ssd_*, causal_conv1d, etc). The follow-up check only reads
# installed package metadata, so nothing is imported from either package here.
RUN pip install \
      'https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.0/causal_conv1d-1.6.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
      'https://github.com/state-spaces/mamba/releases/download/v2.3.0/mamba_ssm-2.3.0+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl' \
 && python -c "from importlib.metadata import version as pkgver; print('installed mamba_ssm=' + pkgver('mamba_ssm') + ' causal_conv1d=' + pkgver('causal_conv1d'))"

#
# Step B: graft the Mamba3 class + its pure-Triton ops subtree from mamba-ssm
# main. v2.3.1 is the latest release, but Mamba3 landed after it. The new
# files under ops/triton/mamba3/ are ALL pure Python @triton.jit kernels with
# zero compiled-CUDA dependencies (verified: every import in that subtree is
# triton/torch/python — no .so files, no nvcc). So we keep the v2.3.0 wheel
# installed in Step A (for its compiled ops) and overlay the main-branch Mamba3
# sources on top.
#
# This avoids the source-build OOM on the cpu-basic HF Space builder, and it
# avoids the missing-file error that a previous smoke-test build hit.
# Download the grafted mamba3 module + its triton ops subtree into the installed
# mamba_ssm package.
# MAMBA_GRAFT_REF selects the upstream git ref to fetch from. It defaults to
# `main` (the previous behavior). Pin it to a commit SHA for reproducible
# builds, since `main` can change between builds.
# The `|| exit 1` inside the loop is required: under /bin/sh a for-loop's exit
# status is that of its LAST command only, so a failed curl for any earlier
# file would otherwise be silently ignored.
ARG MAMBA_GRAFT_REF=main
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    BASE="https://raw.githubusercontent.com/state-spaces/mamba/${MAMBA_GRAFT_REF}" && \
    curl -fsSL "$BASE/mamba_ssm/modules/mamba3.py" -o "$SITE/modules/mamba3.py" && \
    mkdir -p "$SITE/ops/triton/mamba3" && \
    for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        curl -fsSL "$BASE/mamba_ssm/ops/triton/mamba3/$f" -o "$SITE/ops/triton/mamba3/$f" || exit 1; \
    done

# Overwrite the wheel's mamba_ssm/__init__.py with the project's minimal
# mamba_ssm_init.py, which only imports the pure-Triton Mamba3. The shipped
# __init__.py eagerly imports selective_scan_cuda.so. On this base image that
# import fails with a libtorch C++ ABI mismatch ("undefined symbol:
# _ZN3c107WarningC1E..."). Training only needs the grafted Mamba3, so every
# compiled-CUDA import is skipped.
COPY ["mamba_ssm_init.py", "/opt/conda/lib/python3.11/site-packages/mamba_ssm/__init__.py"]

# Structural check of the Mamba3 graft and the __init__ override. Triton is NOT
# initialised here because the builder has no GPU. Instead:
#   - mamba3.py and the overridden __init__.py must be non-empty;
#   - every grafted ops/triton/mamba3 file must exist (its __init__.py may
#     legitimately be empty, so only existence is required);
#   - all of them must byte-compile, which catches truncated/corrupt downloads
#     without importing triton or torch.
RUN SITE=/opt/conda/lib/python3.11/site-packages/mamba_ssm && \
    test -s "$SITE/modules/mamba3.py" && \
    for f in __init__.py angle_dt.py mamba3_mimo_rotary_step.py mamba3_mimo_utils.py mamba3_siso_bwd.py mamba3_siso_combined.py mamba3_siso_fwd.py mamba3_siso_step.py utils.py; do \
        test -f "$SITE/ops/triton/mamba3/$f" || { echo "missing grafted file: ops/triton/mamba3/$f" >&2; exit 1; }; \
    done && \
    test -s "$SITE/__init__.py" && \
    python -m compileall -q "$SITE/__init__.py" "$SITE/modules/mamba3.py" "$SITE/ops/triton/mamba3" && \
    echo "mamba3 graft + __init__ override verified"

# Optional tilelang for the MIMO path; SISO Mamba3 works without it.
# A constraints file pins torch to the version already in the image, so
# resolving tilelang's dependencies can never replace torch 2.5.x underneath
# the torch2.5-ABI mamba_ssm/causal_conv1d wheels. If tilelang cannot be
# installed under that constraint, the step degrades to the same "optional"
# warning and the build continues. The temp constraints file is always removed.
# NOTE(review): an earlier comment described tilelang as "pure-python, cheap";
# confirm this, since its wheels may bundle native libraries.
RUN ( python -c "import importlib.metadata as m; print('torch==' + m.version('torch'))" > /tmp/torch-constraint.txt && \
      pip install -c /tmp/torch-constraint.txt tilelang ) \
    || echo "[dockerfile] tilelang optional install failed — continuing"; \
    rm -f /tmp/torch-constraint.txt

# Triton is FORCED to 3.4.0. It is the first line that provides both APIs
# mamba3 needs (triton.set_allocator and tl.make_tensor_descriptor). It also
# sidesteps the 3.5.x driver-discovery regression seen on HF A10G, which
# reports `0 active drivers` even though torch.cuda is available.
# torch 2.5's _inductor expects older Triton internals. mamba_ssm/__init__.py
# neutralises that mismatch by stubbing AttrsDescriptor before any
# torch._inductor import path runs. The inline check asserts at build time
# that both required APIs exist.
RUN pip install --no-deps --force-reinstall 'triton==3.4.0' && \
    python -c "import triton; import triton.language as tl; \
               assert hasattr(triton, 'set_allocator'), 'missing triton.set_allocator'; \
               assert hasattr(tl, 'make_tensor_descriptor'), 'missing tl.make_tensor_descriptor'; \
               print(f'triton={triton.__version__} set_allocator+make_tensor_descriptor OK, AttrsDescriptor shimmed in mamba_ssm/__init__.py')"

# Stage the project sources (overlay -> /workspace/feather) and the job
# entrypoint (-> /app), then work from the project root.
WORKDIR /workspace
COPY ["overlay", "/workspace/feather"]
COPY ["entrypoint.py", "/app/entrypoint.py"]
WORKDIR /workspace/feather

# Fail-fast syntax gate: byte-compile the Python entry modules, then parse
# the pretraining launcher script without executing it.
RUN python -m py_compile \
        hydra/training.py \
        prepare.py \
        train.py \
 && bash -n scripts/run_domain_expanded_pretrain.sh

# Build and install the htm_rust extension with its `gpu` feature enabled.
# - CUDA libs are prepended to LD_LIBRARY_PATH. The `${VAR:+:$VAR}` form avoids
#   a trailing empty entry (which the dynamic linker treats as the current
#   directory) when LD_LIBRARY_PATH is unset or empty.
# - Any pre-existing wheels (e.g. a target/ directory shipped inside overlay)
#   are removed first, so the install glob matches only the wheel this build
#   produced.
RUN export LD_LIBRARY_PATH="/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" && \
    echo "building htm_rust GPU kernels for HTM_CUDA_ARCH=${HTM_CUDA_ARCH} TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}" && \
    rm -rf htm_rust/target/wheels && \
    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
    pip install htm_rust/target/wheels/htm_rust-*.whl

# Exec form: python runs as PID 1 and receives stop signals directly.
CMD ["python", "/app/entrypoint.py"]