CausalGrok / code /scripts /install_flash_attn.sh
nileshsarkar-ai's picture
Upload code/scripts
42c0d23 verified
#!/usr/bin/env bash
# scripts/install_flash_attn.sh
#
# Build/install flash-attention into the active conda env, detached
# under nohup so SSH disconnects don't kill the build (which can take
# 30-60 minutes on a fresh checkout).
#
# Usage:
# conda activate causalgrok # do this in your shell first
# bash scripts/install_flash_attn.sh # use the default FLASH_ATTN_VERSION (2.7.4)
# FLASH_ATTN_VERSION=2.7.4 bash scripts/install_flash_attn.sh
#
# The build/install logs land in:
# logs/install/<UTC-stamp>_flash_attn/{logs/train.log,logs/train.err,run.pid,env.txt}
# Watch progress with:
# tail -f logs/install/<stamp>_flash_attn/logs/train.log
# Kill if needed:
# kill "$(cat logs/install/<stamp>_flash_attn/run.pid)"
#
# Requirements:
# - active conda env with PyTorch already installed (matching CUDA)
# - nvcc / CUDA toolkit visible (the build needs it)
# - Ampere+ GPU (A100/H100/RTX30xx/40xx). On A100 (sm_80) flash-attn
# ships prebuilt wheels for most PyTorch×CUDA combos, so the
# install is usually fast.
# Strict mode: abort on a failing command, on use of an unset variable,
# and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Resolve the repository root (the parent of the directory holding this
# script) and work from there, so the relative paths used below do not
# depend on the caller's working directory.
script_dir="$(dirname "$0")"
ROOT="$(cd "${script_dir}/.." && pwd)"
cd "${ROOT}"

# Shared helper library. Presumably the source of launch_detached, which
# is called further down — confirm against scripts/lib/nohup_runner.sh.
. "${ROOT}/scripts/lib/nohup_runner.sh"

# Honour a caller-supplied FLASH_ATTN_VERSION; fall back to 2.7.4 when it
# is unset or empty.
: "${FLASH_ATTN_VERSION:=2.7.4}"
# ── Sanity checks ─────────────────────────────────────────────────────
# Refuse to continue unless a `python` is on PATH and a conda env looks
# active. Both failures print the activation hint to stderr and exit 1.
if ! command -v python >/dev/null 2>&1; then
  echo " python not found in PATH. Activate the conda env first:" >&2
  echo " conda activate causalgrok" >&2
  exit 1
fi

# An unset or empty CONDA_PREFIX is treated as "no env active"; the
# ':-' default keeps the test safe under `set -u`.
if [[ -z "${CONDA_PREFIX:-}" ]]; then
  echo " CONDA_PREFIX is empty — no conda env appears to be active." >&2
  echo " conda activate causalgrok" >&2
  exit 1
fi
# ── Install log location ──────────────────────────────────────────────
# Environment / dependency setup output goes under logs/install/, kept
# apart from experiments/runs/ so that an install is never mistaken for
# a training run. Each invocation gets its own directory, named after
# the current UTC time (second resolution).
STAMP="$(date -u '+%Y%m%d-%H%M%S')"
printf -v INSTALL_DIR 'logs/install/%s_flash_attn' "${STAMP}"
mkdir -p -- "${INSTALL_DIR}"
# Snapshot the env we're installing into (so we can debug later)

#######################################
# Print a '#'-prefixed description of the install environment.
# The snapshot is purely diagnostic, so a probe that fails is recorded
# in the output instead of aborting the caller under `set -e`/`pipefail`.
# Globals:   CONDA_PREFIX (read), CONDA_DEFAULT_ENV (read, may be unset)
# Arguments: none
# Outputs:   snapshot lines on stdout
#######################################
write_env_snapshot() {
  echo "# captured: $(date -u +%FT%TZ)"
  echo "# host: $(hostname)"
  echo "# CONDA_PREFIX:${CONDA_PREFIX}"
  echo "# CONDA_DEFAULT_ENV: ${CONDA_DEFAULT_ENV:-}"
  echo "# python: $(python --version 2>&1)"
  echo "# which python: $(command -v python)"

  if command -v nvcc >/dev/null 2>&1; then
    echo "# nvcc: $(nvcc --version | tail -1)"
  else
    echo "# nvcc: NOT FOUND (CUDA toolkit may be missing)"
  fi

  if command -v nvidia-smi >/dev/null 2>&1; then
    echo "# nvidia-smi:"
    # nvidia-smi can exit non-zero (driver not loaded, query field not
    # supported by this driver). Fold its stderr into the snapshot and
    # fall back to a marker line rather than letting pipefail + errexit
    # kill the installer here.
    nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader 2>&1 \
      | sed 's/^/# /' \
      || echo "# (nvidia-smi query failed)"
  fi

  # Any failure of the torch probe is reported as "not installed"; the
  # probe's own stderr is intentionally discarded.
  python -c "import torch; print(f'# torch: {torch.__version__}'); print(f'# torch.cuda: {torch.version.cuda}'); print(f'# CUDA avail: {torch.cuda.is_available()}')" 2>/dev/null \
    || echo "# torch: NOT INSTALLED — install torch first"
}

# Persist the snapshot next to the build logs, then echo it to the
# terminal so the user sees what is being installed into.
write_env_snapshot > "${INSTALL_DIR}/env.txt"
cat "${INSTALL_DIR}/env.txt"
echo
# ── Pre-flight: torch must be present ─────────────────────────────────
# Stop with exit status 1 and an install hint on stderr when
# `import torch` fails in the active interpreter.
python -c "import torch" >/dev/null 2>&1 || {
  echo " torch is not installed in this env. Install it first:" >&2
  echo " pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118" >&2
  exit 1
}
# ── Launch the build under nohup ──────────────────────────────────────
# packaging/ninja/wheel are required for the source build path; the
# resolver will pick the prebuilt wheel where one exists.
#
# The version reaches the detached shell through the environment rather
# than being spliced into the script text, so an unusual value (quotes,
# spaces) cannot change what the script executes.
# NOTE(review): assumes launch_detached does not scrub the environment.
export FLASH_ATTN_VERSION

# Script executed by the detached bash. The quoted heredoc delimiter
# keeps it literal; nothing is expanded until the detached shell runs it.
#   - `python -m pip` installs into the same interpreter the checks
#     above validated, even if a different `pip` comes first on PATH.
#   - The ERR trap prints an ERROR line on stdout when a step fails, so
#     the 'DONE|ERROR' completion grep suggested at the end of this
#     script has something to match (presumably stdout is what lands in
#     the build log — confirm against launch_detached).
install_script="$(
  cat <<'INSTALL_SCRIPT'
set -euo pipefail
trap 'echo "ERROR: flash-attn install failed (exit $?)"' ERR
echo '== installing build deps =='
python -m pip install --upgrade pip wheel setuptools packaging ninja
echo
echo "== installing flash-attn ${FLASH_ATTN_VERSION} =="
python -m pip install --no-build-isolation "flash-attn==${FLASH_ATTN_VERSION}"
echo
echo '== sanity check =='
python -c 'import flash_attn; print("flash-attn version:", flash_attn.__version__)'
echo 'DONE'
INSTALL_SCRIPT
)"

echo "Starting flash-attn ${FLASH_ATTN_VERSION} install (detached)..."
launch_detached "${INSTALL_DIR}" bash -c "${install_script}"
# Tell the user where the detached job's artefacts are and how to check
# whether it has finished. The unquoted heredoc delimiter lets
# ${INSTALL_DIR} expand while the single quotes stay literal.
cat <<EOF

Outputs:
 env snapshot : ${INSTALL_DIR}/env.txt
 build log : ${INSTALL_DIR}/logs/train.log
 build err : ${INSTALL_DIR}/logs/train.err
 PID : ${INSTALL_DIR}/run.pid

Verify completion with:
 grep -E 'DONE|ERROR' ${INSTALL_DIR}/logs/train.log
EOF