#!/usr/bin/env bash

#######################################
# Print the number of GPUs to launch on, as a single integer on stdout.
# The first source that yields a usable value wins:
#   1. NUM_GPUS override (a non-negative integer; surrounding blanks allowed).
#   2. Number of non-empty comma-separated entries in CUDA_VISIBLE_DEVICES.
#   3. torch.cuda.device_count() via `python` (or `python3` if `python` is
#      not installed), when it reports more than zero devices.
#   4. Number of "GPU <n>:" lines printed by `nvidia-smi -L`.
#   5. Fallback: 1.
# Globals:   NUM_GPUS, CUDA_VISIBLE_DEVICES (read only)
# Arguments: none
# Outputs:   the count on stdout; a warning on stderr if NUM_GPUS is set
#            but is not an integer
# Returns:   0
#######################################
detect_num_gpus() {
  local -r int_re='^[[:space:]]*([0-9]+)[[:space:]]*$'

  # 1. Explicit override. A value that is not a plain integer is reported and
  #    skipped so that callers never receive non-numeric text.
  if [[ "${NUM_GPUS:-}" =~ ${int_re} ]]; then
    echo "${BASH_REMATCH[1]}"
    return
  elif [[ -n "${NUM_GPUS:-}" ]]; then
    echo "detect_num_gpus: ignoring non-numeric NUM_GPUS='${NUM_GPUS}'" >&2
  fi

  # 2. CUDA_VISIBLE_DEVICES: count the non-empty entries of the comma list.
  #    The array is local so nothing leaks into the caller's namespace.
  if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    local -a devices=()
    local device
    local visible=0
    IFS=',' read -ra devices <<< "${CUDA_VISIBLE_DEVICES}"
    for device in "${devices[@]}"; do
      device="${device// /}"
      if [[ -n "${device}" ]]; then
        visible=$((visible + 1))
      fi
    done
    if [[ "${visible}" -gt 0 ]]; then
      echo "${visible}"
      return
    fi
  fi

  # 3. Ask torch. Best effort: a missing interpreter, a missing torch module
  #    or any runtime error is deliberately silenced and simply falls through.
  local py=""
  if command -v python >/dev/null 2>&1; then
    py="python"
  elif command -v python3 >/dev/null 2>&1; then
    py="python3"
  fi
  if [[ -n "${py}" ]]; then
    local torch_count
    torch_count="$("${py}" - <<'PY' 2>/dev/null || true
import torch
print(torch.cuda.device_count())
PY
)"
    if [[ "${torch_count}" =~ ^[0-9]+$ ]] && [[ "${torch_count}" -gt 0 ]]; then
      echo "${torch_count}"
      return
    fi
  fi

  # 4. nvidia-smi. Count only lines that start with "GPU <n>"; a raw line
  #    count would also include indented MIG sub-device lines and any error
  #    text the tool prints on stdout. `|| true` covers grep's exit status 1
  #    on zero matches (and a failing nvidia-smi under pipefail).
  if command -v nvidia-smi >/dev/null 2>&1; then
    local smi_count
    smi_count="$(nvidia-smi -L 2>/dev/null | grep -c '^GPU [0-9]' || true)"
    if [[ "${smi_count}" =~ ^[0-9]+$ ]] && [[ "${smi_count}" -gt 0 ]]; then
      echo "${smi_count}"
      return
    fi
  fi

  # 5. Nothing detected: assume a single process.
  echo 1
}
#######################################
# Print the `--num_processes <N>` flag, where N is whatever detect_num_gpus
# reports.
# Arguments: none
# Outputs:   one line, "--num_processes <N>", on stdout
#######################################
launch_num_processes_flag() {
  local num_gpus
  num_gpus="$(detect_num_gpus)"
  printf '%s\n' "--num_processes ${num_gpus}"
}
#######################################
# Print the accelerate config file name to use.
# A non-empty ACCELERATE_CONFIG is echoed verbatim. Otherwise the choice
# depends on the GPU count reported by detect_num_gpus:
#   8 or more -> default_config_8gpu.yaml
#   otherwise -> default_config.yaml
# Globals:   ACCELERATE_CONFIG (read only)
# Arguments: none
# Outputs:   the config name on stdout
#######################################
resolve_accelerate_config() {
  # An explicit override always wins.
  if [[ -n "${ACCELERATE_CONFIG:-}" ]]; then
    echo "${ACCELERATE_CONFIG}"
    return
  fi

  local num_gpus
  num_gpus="$(detect_num_gpus)"

  # Compare only when the value is purely digits, and force base 10 so a
  # zero-padded count such as "08" is not rejected as invalid octal. Any
  # other text falls back to the default config instead of being evaluated
  # as an arithmetic expression.
  if [[ "${num_gpus}" =~ ^[0-9]+$ ]] && ((10#${num_gpus} >= 8)); then
    echo "default_config_8gpu.yaml"
  else
    echo "default_config.yaml"
  fi
}
#######################################
# Print a human-readable summary of the launch settings on stdout: the
# resolved --num_processes value, the resolved accelerate config, and,
# when available, CUDA_VISIBLE_DEVICES, the `nvidia-smi -L` listing and
# torch's own device count.
# Globals:   CUDA_VISIBLE_DEVICES (read only)
# Arguments: none
# Outputs:   the summary, framed by two separator lines
#######################################
print_launch_plan() {
  local -r rule="============================================================"
  local num_gpus
  local accel_config
  num_gpus="$(detect_num_gpus)"
  accel_config="$(resolve_accelerate_config)"

  echo "${rule}"
  echo "Launch plan: --num_processes ${num_gpus}"
  echo "accelerate config: ${accel_config} (DDP/MULTI_GPU unless overridden)"

  if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
  fi

  # Diagnostics below are best effort: failures are silenced on purpose so
  # that printing the plan never aborts a launch.
  if command -v nvidia-smi >/dev/null 2>&1; then
    echo "nvidia-smi -L:"
    nvidia-smi -L 2>/dev/null || true
  fi

  # Use `python` when present, otherwise fall back to `python3`.
  local py=""
  if command -v python >/dev/null 2>&1; then
    py="python"
  elif command -v python3 >/dev/null 2>&1; then
    py="python3"
  fi
  if [[ -n "${py}" ]]; then
    "${py}" - <<'PY' 2>/dev/null || true
import torch
print(f"torch.cuda.device_count()={torch.cuda.device_count()}")
PY
  fi

  echo "${rule}"
}