File size: 4,832 Bytes
42c0d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env bash
# scripts/install_flash_attn.sh
#
# Build/install flash-attention into the active conda env, detached
# under nohup so SSH disconnects don't kill the build (which can take
# 30-60 minutes on a fresh checkout).
#
# Usage:
#   conda activate causalgrok                       # do this in your shell first
#   bash scripts/install_flash_attn.sh              # default FLASH_ATTN_VERSION (2.7.4)
#   FLASH_ATTN_VERSION=2.7.4 bash scripts/install_flash_attn.sh
#
# The build/install logs land in:
#   logs/install/<UTC-stamp>_flash_attn/{logs/train.log,logs/train.err,run.pid,env.txt}
# Watch progress with:
#   tail -f logs/install/<stamp>_flash_attn/logs/train.log
# Kill if needed:
#   kill "$(cat logs/install/<stamp>_flash_attn/run.pid)"
#
# Requirements:
#   - active conda env with PyTorch already installed (matching CUDA)
#   - nvcc / CUDA toolkit visible (the build needs it)
#   - Ampere+ GPU (A100/H100/RTX30xx/40xx). On A100 (sm_80) flash-attn
#     ships prebuilt wheels for most PyTorch×CUDA combos, so the
#     install is usually fast.

# Strict mode: abort on any unhandled failure, unset variable, or failed
# pipeline stage.
set -euo pipefail

# Repo root = parent of the directory holding this script. Everything
# below uses paths relative to it (logs/, scripts/lib/).
ROOT="$(
    cd "$(dirname "$0")/.." && pwd
)"
cd "${ROOT}"

# Shared helper library; launch_detached (called further down) is
# expected to be defined here.
source "${ROOT}/scripts/lib/nohup_runner.sh"

# Pinned default; override from the environment, e.g.
#   FLASH_ATTN_VERSION=2.7.4 bash scripts/install_flash_attn.sh
: "${FLASH_ATTN_VERSION:=2.7.4}"

# ── Sanity checks ─────────────────────────────────────────────────────
# Refuse to run outside an activated conda env: we need a `python` on
# PATH, and CONDA_PREFIX is the signal that an env is actually active.
# Both failures print the activation hint to stderr and exit 1.
if ! command -v python >/dev/null 2>&1; then
    echo "  python not found in PATH. Activate the conda env first:" >&2
    echo "      conda activate causalgrok" >&2
    exit 1
fi
if [[ -z "${CONDA_PREFIX:-}" ]]; then
    echo "  CONDA_PREFIX is empty — no conda env appears to be active." >&2
    echo "      conda activate causalgrok" >&2
    exit 1
fi

# ── Where the install logs live ───────────────────────────────────────
# logs/install/ is reserved for environment / dependency setup output.
# It is deliberately separate from experiments/runs/ so an install
# never gets confused with a training run.
STAMP="$(date -u +%Y%m%d-%H%M%S)"
INSTALL_DIR="logs/install/${STAMP}_flash_attn"
mkdir -p "${INSTALL_DIR}"

# Snapshot the env we're installing into (so we can debug later).
# This file is diagnostic only, so every probe is best-effort. In
# particular a failing nvidia-smi (e.g. no driver loaded, or an older
# driver that rejects the compute_cap query field) is recorded in the
# snapshot instead of aborting the whole script via set -e/pipefail.
{
    echo "# captured: $(date -u +%FT%TZ)"
    echo "# host:        $(hostname)"
    echo "# CONDA_PREFIX:${CONDA_PREFIX}"
    echo "# CONDA_DEFAULT_ENV: ${CONDA_DEFAULT_ENV:-}"
    echo "# python:    $(python --version 2>&1)"
    echo "# which python: $(command -v python)"
    if command -v nvcc >/dev/null 2>&1; then
        # Last line of `nvcc --version` carries the build/release string.
        echo "# nvcc:    $(nvcc --version | tail -1)"
    else
        echo "# nvcc:    NOT FOUND (CUDA toolkit may be missing)"
    fi
    if command -v nvidia-smi >/dev/null 2>&1; then
        echo "# nvidia-smi:"
        # 2>&1 so the failure reason (if any) also lands in the snapshot.
        nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader 2>&1 \
            | sed 's/^/#   /' \
            || echo "#   (nvidia-smi query failed; see lines above)"
    fi
    # Falls through to the "NOT INSTALLED" line whenever the import or
    # any of the prints fails (not only when torch is truly absent).
    python -c "import torch; print(f'# torch:      {torch.__version__}'); print(f'# torch.cuda: {torch.version.cuda}'); print(f'# CUDA avail: {torch.cuda.is_available()}')" 2>/dev/null \
        || echo "# torch:      NOT INSTALLED — install torch first"
} > "${INSTALL_DIR}/env.txt"

cat "${INSTALL_DIR}/env.txt"
echo

# ── Pre-flight: torch must be present ─────────────────────────────────
# flash-attn is installed with --no-build-isolation against whatever
# torch this env already has, so bail out early (exit 1, hint on stderr)
# when `import torch` fails.
python -c "import torch" >/dev/null 2>&1 || {
    echo "  torch is not installed in this env. Install it first:" >&2
    echo "      pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118" >&2
    exit 1
}

# ── Launch the build under nohup ──────────────────────────────────────
# packaging/ninja/wheel are required for the source build path; the
# resolver will pick the prebuilt wheel where one exists.
#
# - The child script is single-quoted (nothing expands here) and reads
#   FLASH_ATTN_VERSION from its environment. The value is therefore never
#   spliced into shell source, so a stray quote or `;` cannot break or
#   inject into the command.
# - `python -m pip` (not bare `pip`) installs into the same interpreter
#   that the checks above found on PATH and that the final import test
#   uses. A bare `pip` can resolve to a different Python.
# - The ERR trap prints an "ERROR:" line to stdout (presumably captured
#   in logs/train.log by launch_detached) so a failed build is visible
#   in the log, not only through a missing DONE.
export FLASH_ATTN_VERSION
INSTALL_SCRIPT="$(cat <<'EOF'
set -euo pipefail
trap 'echo "ERROR: flash-attn install failed (exit $?)"' ERR
echo '== installing build deps =='
python -m pip install --upgrade pip wheel setuptools packaging ninja
echo
echo "== installing flash-attn ${FLASH_ATTN_VERSION} =="
python -m pip install --no-build-isolation "flash-attn==${FLASH_ATTN_VERSION}"
echo
echo '== sanity check =='
python -c 'import flash_attn; print("flash-attn version:", flash_attn.__version__)'
echo 'DONE'
EOF
)"

echo "Starting flash-attn ${FLASH_ATTN_VERSION} install (detached)..."
launch_detached "${INSTALL_DIR}" bash -c "${INSTALL_SCRIPT}"

# Tell the user where everything went. INSTALL_DIR is relative, so the
# paths below are relative to the repo root printed first.
echo
echo "Outputs (relative to ${ROOT}):"
echo "  env snapshot : ${INSTALL_DIR}/env.txt"
echo "  build log    : ${INSTALL_DIR}/logs/train.log"
echo "  build err    : ${INSTALL_DIR}/logs/train.err"
echo "  PID          : ${INSTALL_DIR}/run.pid"
echo
echo "Verify completion with:"
# pip reports failures as "ERROR: ..." on stderr. Given the separate
# "build err" file above, stderr presumably lands in train.err rather
# than train.log, so search both files.
echo "  grep -E 'DONE|ERROR' ${INSTALL_DIR}/logs/train.log ${INSTALL_DIR}/logs/train.err"