File size: 2,058 Bytes
0f7408a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env bash
set -euo pipefail

# Bootstrap a WSL CUDA Python env capable of running train.py TPS checks.
#
# Usage:
#   bash scripts/wsl_bootstrap_tps.sh [cuda-tag]
# Example:
#   bash scripts/wsl_bootstrap_tps.sh cu121
#
# Arguments:
#   cuda-tag    CUDA tag selecting the PyTorch wheel index (default: cu121).
# Environment:
#   PYTHON_BIN  Interpreter used to create the virtualenv (default: python3).
#   VENV_DIR    Directory of the virtualenv to create (default: .venv-wsl).

CUDA_TAG="${1:-cu121}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
VENV_DIR="${VENV_DIR:-.venv-wsl}"

# All diagnostics below go to stderr so stdout carries only progress output.

# Soft check: the kernel version string mentions "microsoft"/"wsl" under WSL.
# A miss (or an unreadable /proc/version) only warns; the script continues.
if ! grep -qiE "microsoft|wsl" /proc/version 2>/dev/null; then
  echo "[bootstrap] warning: not running inside WSL; continuing anyway" >&2
fi

# Hard requirement: nvidia-smi must be on PATH.
if ! command -v nvidia-smi >/dev/null 2>&1; then
  echo "[bootstrap] error: nvidia-smi not found. Install NVIDIA driver + WSL GPU support first." >&2
  exit 1
fi

# Hard requirement: the interpreter used to build the virtualenv must exist.
if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then
  echo "[bootstrap] error: Python binary not found: $PYTHON_BIN" >&2
  exit 1
fi

# Resolve the PyTorch wheel index from the CUDA tag BEFORE creating anything,
# so an unsupported tag fails fast instead of leaving a freshly created
# virtualenv (and an upgraded pip) behind.
case "$CUDA_TAG" in
  cu118 | cu121 | cu124)
    TORCH_INDEX_URL="https://download.pytorch.org/whl/${CUDA_TAG}"
    ;;
  *)
    echo "[bootstrap] error: unsupported cuda tag '$CUDA_TAG' (supported: cu118, cu121, cu124)" >&2
    exit 1
    ;;
esac

# Create the virtualenv and activate it in this shell; the `python` commands
# that follow are expected to resolve to the virtualenv's interpreter.
"$PYTHON_BIN" -m venv "$VENV_DIR"
# shellcheck source=/dev/null
source "$VENV_DIR/bin/activate"

# Refresh the packaging toolchain first.
python -m pip install --upgrade pip wheel setuptools

# Install torch from the CUDA-specific wheel index selected above.
python -m pip install "torch" --index-url "$TORCH_INDEX_URL"
# Editable install of the project in the current directory with its "dev" extra.
python -m pip install -e ".[dev]"

# IMPORTANT: these packages are installed with --no-build-isolation so pip
# does not pull torch-cpu into an isolated build env, which would break
# mamba-ssm extension builds. Order matters: causal-conv1d goes first.
no_isolation_pkgs=(
  "causal-conv1d>=1.4.0"
  "mamba-ssm"
)
for pkg_spec in "${no_isolation_pkgs[@]}"; do
  python -m pip install "$pkg_spec" --no-build-isolation
done

# Smoke test: report the torch build, abort if torch cannot see CUDA, then
# confirm that mamba_ssm imports. The heredoc delimiter is quoted so the
# shell performs no expansion inside the Python snippet.
python - <<'PYEOF'
import torch

print(f"[bootstrap] torch={torch.__version__}")
print(f"[bootstrap] torch_cuda={torch.version.cuda}")

cuda_ok = torch.cuda.is_available()
print(f"[bootstrap] cuda_available={cuda_ok}")
if not cuda_ok:
    # A string passed to SystemExit is printed to stderr and exits with status 1.
    raise SystemExit("[bootstrap] error: CUDA not available to torch")

import mamba_ssm  # noqa: F401

print("[bootstrap] mamba_ssm import OK")
PYEOF

# Tell the user how to re-enter the environment in a new shell.
printf '%s\n' "[bootstrap] done. Activate env with: source $VENV_DIR/bin/activate"