PIVOT / scripts /setup_gears_env.sh
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
1.43 kB
#!/usr/bin/env bash
#
# setup_gears_env.sh - (re)build the conda environment used to run GEARS.
#
# Steps: create a fresh Python 3.10 conda env, install torch 2.2.0 +
# torchvision 0.17.0 from the CUDA 11.8 wheel index, install torch_geometric
# and cell-gears, then run an import + GPU smoke test. Prints REBUILD_DONE on
# success. Each install step writes its output to $LOG_DIR/g_*.log; on failure
# the script stops immediately and prints the tail of the failing step's log.
#
# Environment overrides (defaults reproduce the original hard-coded setup):
#   CONDA_SH  path to conda's profile.d/conda.sh
#             (default: /home/bcheng/miniconda3/etc/profile.d/conda.sh)
#   ENV_NAME  name of the conda env to create (default: pivot_gears)
#   LOG_DIR   directory for the per-step g_*.log files (default: /tmp)
#
# NOTE: `set -u` is deliberately NOT enabled: conda's shell functions and
# activation scripts commonly reference unset variables.
set -eo pipefail

readonly CONDA_SH="${CONDA_SH:-/home/bcheng/miniconda3/etc/profile.d/conda.sh}"
readonly ENV_NAME="${ENV_NAME:-pivot_gears}"
readonly LOG_DIR="${LOG_DIR:-/tmp}"
# numpy is pinned below 2 up front; the pin is repeated on later installs so a
# transitive dependency cannot silently upgrade numpy underneath torch 2.2.0.
readonly NUMPY_PIN='numpy<2'
readonly TORCH_INDEX='https://download.pytorch.org/whl/cu118'

#######################################
# Print an error message to stderr and exit with status 1.
# Arguments: message words
#######################################
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

#######################################
# Run a command with stdout+stderr redirected to a log file.
# Unlike `cmd >log 2>&1 && echo ok` under `set -e`, a failure here is never
# silently ignored: the script exits with the command's status after showing
# the end of the log.
# Arguments: $1 - log file path; $2.. - command and its arguments
# Outputs:   on failure, an error line and the last 20 log lines to stderr
# Returns:   0 on success; exits the script on failure
#######################################
run_logged() {
  local log=$1
  shift
  local rc=0
  "$@" >"$log" 2>&1 || rc=$?
  if (( rc != 0 )); then
    printf 'ERROR: command failed (exit %d): %s\n' "$rc" "$*" >&2
    printf '=== last lines of %s ===\n' "$log" >&2
    tail -n 20 -- "$log" >&2 || true
    exit "$rc"
  fi
}

main() {
  [[ -r "$CONDA_SH" ]] || die "conda init script not found: $CONDA_SH (set CONDA_SH)"
  # shellcheck source=/dev/null
  source "$CONDA_SH"
  mkdir -p -- "$LOG_DIR"

  # fresh env, gears needs py3.10
  echo "=== create fresh env ${ENV_NAME} (py3.10) ==="
  run_logged "$LOG_DIR/g_env.log" conda create -y -n "$ENV_NAME" python=3.10
  echo "env created"
  conda activate "$ENV_NAME"

  # torch cu118 bundles its own cuda 11.8 runtime so it works on the old 470 driver
  echo "=== torch 2.2.0 +cu118 WITH deps (bundles CUDA 11.8 runtime; works on 470 driver) ==="
  # `python -m pip` guarantees the activated env's pip is used.
  run_logged "$LOG_DIR/g_np.log" python -m pip install --no-input "$NUMPY_PIN"
  run_logged "$LOG_DIR/g_torch.log" python -m pip install --no-input \
    torch==2.2.0 torchvision==0.17.0 --index-url "$TORCH_INDEX"
  echo "torch+cu118 installed"
  python -c "import torch; print('TORCH', torch.__version__, 'CUDA avail:', torch.cuda.is_available(), 'devices:', torch.cuda.device_count())"

  echo "=== torch_geometric + cell-gears ==="
  run_logged "$LOG_DIR/g_pyg.log" python -m pip install --no-input torch_geometric "$NUMPY_PIN"
  echo "pyg ok"
  run_logged "$LOG_DIR/g_gears.log" python -m pip install --no-input cell-gears "$NUMPY_PIN"
  echo "gears ok"

  echo "=== final import + GPU test ==="
  # Quoted delimiter: the Python below is passed verbatim (no shell expansion).
  # The `if` body must be indented, otherwise Python raises IndentationError.
  python - <<'PY'
import torch, torch_geometric, gears, numpy
print("numpy", numpy.__version__)
print("torch", torch.__version__, "cuda", torch.cuda.is_available())
print("PyG", torch_geometric.__version__, "GEARS", getattr(gears,"__version__","?"))
if torch.cuda.is_available():
    x=torch.randn(1000,1000,device="cuda"); print("GPU matmul ok:", (x@x).sum().item()!=0)
PY
  echo REBUILD_DONE
}

main "$@"