#!/bin/bash
# scripts/bootstrap_venvs.sh
# Dual-venv bootstrap for Synesthesia ROCm 7.2.1
#
# Creates two Python virtual environments in the current working directory:
#   .venv-jax    JAX / TensorFlow-ROCm / IREE / ONNX conversion tooling
#   .venv-torch  PyTorch-ROCm plus the HuggingFace training stack
# Only some of the installed packages are version-pinned; see the install
# functions below for the exact requirement strings.
#
# Usage:
#   scripts/bootstrap_venvs.sh [--force]
#
# Arguments:
#   --force   Delete each venv before processing it, forcing a rebuild.
#             Only recognised as the first argument; anything else is ignored.
#
# Environment:
#   SKIP_ROCM_CHECK=1   Skip the ROCm presence check AND all GPU visibility
#                       verification (existing venvs are then reused as-is).

set -euo pipefail

# --- Configuration ---
FORCE=0
if [[ "${1:-}" == "--force" ]]; then
  FORCE=1
fi
readonly FORCE

# Default to "0" so the comparisons below are safe under `set -u`.
readonly SKIP_ROCM_CHECK="${SKIP_ROCM_CHECK:-0}"

readonly VENV_JAX=".venv-jax"
readonly VENV_TORCH=".venv-torch"

#######################################
# Print an informational message with a blue [BOOTSTRAP] prefix.
# Arguments: message text
# Outputs:   writes the message to stdout
#######################################
log() {
  # The message is passed through %s so it is never interpreted as escapes.
  printf '\033[1;34m[BOOTSTRAP]\033[0m %s\n' "$*"
}

#######################################
# Print an error message with a red [ERROR] prefix.
# Arguments: message text
# Outputs:   writes the message to stderr
#######################################
error() {
  printf '\033[1;31m[ERROR]\033[0m %s\n' "$*" >&2
}

#######################################
# Abort with a clear message when python3 is not on PATH.
# Called only right before a venv has to be deleted or created, so a run in
# which both venvs are reused does not require python3 on PATH.
#######################################
require_python3() {
  if ! command -v python3 &>/dev/null; then
    error "python3 not found. Install Python 3 and ensure it is in your PATH."
    exit 1
  fi
}

# --- 1. ROCm 7.2.1 Presence Check ---

#######################################
# Check that a ROCm installation is present before doing any work.
# Only presence is tested (/opt/rocm directory and rocm-smi on PATH); the
# installed ROCm version is not inspected.
# Globals:   SKIP_ROCM_CHECK (read)
# Returns:   0 when the check passes or is skipped; exits 1 otherwise
#######################################
check_rocm() {
  if [[ "$SKIP_ROCM_CHECK" == "1" ]]; then
    log "Skipping ROCm check (SKIP_ROCM_CHECK=1)."
    return 0
  fi

  log "Checking for ROCm 7.2.1..."
  if [[ ! -d "/opt/rocm" ]]; then
    error "ROCm not found at /opt/rocm. Please install ROCm 7.2.1."
    exit 1
  fi
  if ! command -v rocm-smi &>/dev/null; then
    error "rocm-smi not found. Ensure ROCm binaries are in your PATH."
    exit 1
  fi
  log "ROCm 7.2.1 check passed."
}

# --- 2. Venv Verification Functions ---
# Both verifiers are only ever invoked in a conditional context, so a
# non-zero status selects a branch instead of tripping `set -e`.

#######################################
# Verify that JAX inside a venv reports at least one 'gpu' or 'rocm' device.
# Arguments: $1 - path to the venv
# Outputs:   the device list on stdout
# Returns:   0 when such a device is reported, non-zero otherwise
#            (including when the interpreter or the jax import fails)
#######################################
verify_jax() {
  local venv_path=$1
  log "Verifying GPU visibility in $venv_path..."
  # stderr is discarded on purpose: a failure here is reported by the caller.
  "$venv_path/bin/python" -c "import sys, jax; devices = jax.devices(); print(f'Devices: {devices}'); sys.exit(0 if any(d.platform in ('gpu', 'rocm') for d in devices) else 1)" 2>/dev/null
}

#######################################
# Verify that PyTorch inside a venv reports torch.cuda.is_available().
# Arguments: $1 - path to the venv
# Outputs:   the availability flag on stdout
# Returns:   0 when available, non-zero otherwise
#            (including when the interpreter or the torch import fails)
#######################################
verify_torch() {
  local venv_path=$1
  log "Verifying GPU visibility in $venv_path..."
  # stderr is discarded on purpose: a failure here is reported by the caller.
  "$venv_path/bin/python" -c "import sys, torch; available = torch.cuda.is_available(); print(f'CUDA Available: {available}'); sys.exit(0 if available else 1)" 2>/dev/null
}

# --- 3. Package Installation Functions ---

#######################################
# Install the JAX / TensorFlow-ROCm / conversion-tool stack into a venv.
# This venv is used for JAX-based inference and IREE/ONNX model exports.
# Arguments: $1 - path to the venv
#######################################
install_jax_stack() {
  local venv_path=$1

  # NOTE(review): the log text says "GitHub" but the --find-links URL is a
  # storage.googleapis.com index - confirm which source is intended.
  # NOTE(review): jax 0.4.30 together with jaxlib 0.4.19 looks like a
  # mismatched pair - confirm these pins against the ROCm wheel index.
  log "Installing JAX-ROCm 7.2 wheels from GitHub..."
  "$venv_path/bin/pip" install \
    jax==0.4.30 "jaxlib[rocm]==0.4.19+rocm7.2" \
    --find-links https://storage.googleapis.com/jax-releases/jax_rocm_releases.html \
    --extra-index-url https://pypi.org/simple

  log "Installing TensorFlow-ROCm and conversion tools..."
  "$venv_path/bin/pip" install \
    --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/ \
    tensorflow-rocm==2.16.1 tf2jax==0.3.8 tf2onnx \
    iree-compiler iree-runtime optax flax huggingface_hub
}

#######################################
# Install the PyTorch-ROCm and HuggingFace stack into a venv.
# This venv is used for PyTorch-based training and fine-tuning
# (Gemma 3, etc.).
# Arguments: $1 - path to the venv
#######################################
install_torch_stack() {
  local venv_path=$1

  # torch/torchvision/torchaudio are unpinned: pip resolves whatever the
  # rocm7.2 index currently serves.
  log "Installing PyTorch-ROCm 7.2..."
  "$venv_path/bin/pip" install \
    --index-url https://download.pytorch.org/whl/rocm7.2 \
    torch torchvision torchaudio

  log "Installing HuggingFace stack and utilities..."
  "$venv_path/bin/pip" install \
    transformers trl peft "bitsandbytes>=0.43" accelerate \
    onnx onnxruntime-rocm tensorboard trackio huggingface_hub \
    streamlit rich python-dotenv
}

# --- 4. Venv Bootstrap ---

#######################################
# Ensure a venv exists, is populated, and (unless skipped) sees the GPU.
# An existing venv is reused when it passes verification, or unconditionally
# when SKIP_ROCM_CHECK=1; otherwise it is deleted and rebuilt from scratch.
# Globals:   FORCE (read), SKIP_ROCM_CHECK (read)
# Arguments: $1 - path to the venv
#            $2 - name of the verification function (takes the venv path)
#            $3 - name of the installation function (takes the venv path)
# Returns:   0 on success; exits 1 when post-install verification fails
#######################################
ensure_venv() {
  local venv_path=$1
  local verify_fn=$2
  local install_fn=$3

  if (( FORCE )); then
    # Fail before deleting anything if the venv could not be rebuilt anyway.
    require_python3
    log "Force flag detected. Deleting $venv_path..."
    rm -rf -- "${venv_path:?}"
  fi

  if [[ -d "$venv_path" ]] \
    && { [[ "$SKIP_ROCM_CHECK" == "1" ]] || "$verify_fn" "$venv_path"; }; then
    log "$venv_path already exists and passes verification. Skipping."
    return 0
  fi

  require_python3
  log "Creating $venv_path..."
  # Remove any stale or broken venv so the rebuild starts clean.
  rm -rf -- "${venv_path:?}"
  python3 -m venv "$venv_path"
  "$venv_path/bin/pip" install --upgrade pip setuptools wheel

  "$install_fn" "$venv_path"

  if [[ "$SKIP_ROCM_CHECK" != "1" ]] && ! "$verify_fn" "$venv_path"; then
    error "$venv_path verification FAILED: GPU not visible."
    exit 1
  fi
  log "$venv_path PASS"
}

#######################################
# Entry point: check ROCm, then bootstrap both venvs in order.
#######################################
main() {
  check_rocm

  ensure_venv "$VENV_JAX" verify_jax install_jax_stack
  ensure_venv "$VENV_TORCH" verify_torch install_torch_stack

  if [[ "$SKIP_ROCM_CHECK" == "1" ]]; then
    # Do not claim verification that never ran.
    log "Dual-venv bootstrap complete. GPU verification skipped (SKIP_ROCM_CHECK=1)."
  else
    log "Dual-venv bootstrap complete. Both venvs verified."
  fi
}

main "$@"