arcisvlm / scripts /train_single_gpu.sh
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
1.81 kB
#!/bin/bash
# Single-GPU training to avoid NCCL inter-GPU communication issues on vast.ai
# Uses 1x H100 with gradient accumulation to match effective batch size
#
# Environment:
#   HF_TOKEN  Hugging Face access token. It is exported to every child
#             process started below (presumably read by the training and
#             push scripts -- verify against those scripts).

# Abort on the first unhandled command failure.
set -e

# Export the token BEFORE xtrace is switched on. With `set -x` active, an
# assignment such as `export HF_TOKEN=$HF_TOKEN` is echoed to stderr with the
# secret expanded, which would leak it into any captured training log.
# `${HF_TOKEN:-}` keeps the original behaviour of exporting an empty value
# when the variable is unset.
export HF_TOKEN="${HF_TOKEN:-}"
if [[ -z "$HF_TOKEN" ]]; then
  echo "WARNING: HF_TOKEN is empty or unset" >&2
fi

# From here on, trace every command for easier post-mortem debugging.
set -x

cd /root/arcisvlm

# Expose only the first GPU to the processes launched by this script.
export CUDA_VISIBLE_DEVICES=0
# Start from a clean slate: drop every stage 2 epoch/final checkpoint and
# every stage 3 checkpoint left behind by an earlier run. `-f` keeps this
# quiet (and successful) when a pattern matches nothing.
rm -f \
  checkpoints/stage2_epoch*.pt \
  checkpoints/stage2_final.pt \
  checkpoints/stage3_*.pt
# ---------------------------------------------------------------------------
# Stage 2: train on a single GPU, then push the final checkpoint.
# ---------------------------------------------------------------------------
echo "=== Stage 2: Single GPU Training ==="
date
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# Single GPU: the trainer is still launched through torchrun, but with exactly
# one process (--nproc_per_node=1), so the DDP process group is trivial.
#
# The `|| STAGE2_EXIT=$?` form is deliberate. This script runs under `set -e`,
# so a bare failing `torchrun` followed by `STAGE2_EXIT=$?` would abort the
# script on the spot and the failure diagnostics below would never execute.
STAGE2_EXIT=0
torchrun --nproc_per_node=1 --master_port=29503 \
  scripts/train_stage2_ddp.py \
  --config configs/scale_1.3b.yaml \
  --stage1_ckpt checkpoints/v3_stage1_final.pt \
  || STAGE2_EXIT=$?
echo "Stage 2 exit code: $STAGE2_EXIT"
date

if (( STAGE2_EXIT != 0 )); then
  echo "!!! Stage 2 FAILED with exit code $STAGE2_EXIT !!!"
  echo "Checking for partial checkpoints..."
  # List whatever stage 2 checkpoints exist; ls errors are silenced on purpose
  # so that "nothing found" is reported with a single readable line instead.
  ls -lh checkpoints/stage2_*.pt 2>/dev/null || echo "No stage2 checkpoints found"
  exit 1
fi

echo "=== Pushing Stage 2 to HF ==="
# Arguments are passed exactly as written; how push_to_hf.py resolves the bare
# file name `stage2_final.pt` is not visible from this script.
python3 scripts/push_to_hf.py stage2_final.pt v4_stage2_final.pt
# ---------------------------------------------------------------------------
# Stage 3: continue from the stage 2 checkpoint, then push the final result.
# ---------------------------------------------------------------------------
echo "=== Stage 3: Single GPU Training ==="
date

# Launched through torchrun with a single process (--nproc_per_node=1).
#
# The `|| STAGE3_EXIT=$?` form is deliberate. This script runs under `set -e`,
# so a bare failing `torchrun` followed by `STAGE3_EXIT=$?` would abort the
# script on the spot and the failure diagnostics below would never execute.
STAGE3_EXIT=0
torchrun --nproc_per_node=1 --master_port=29503 \
  scripts/train_stage3_ddp.py \
  --config configs/scale_1.3b.yaml \
  --stage2_ckpt checkpoints/stage2_final.pt \
  || STAGE3_EXIT=$?
echo "Stage 3 exit code: $STAGE3_EXIT"
date

if (( STAGE3_EXIT != 0 )); then
  echo "!!! Stage 3 FAILED with exit code $STAGE3_EXIT !!!"
  # List whatever stage 3 checkpoints exist; ls errors are silenced on purpose
  # so that "nothing found" is reported with a single readable line instead.
  ls -lh checkpoints/stage3_*.pt 2>/dev/null || echo "No stage3 checkpoints found"
  exit 1
fi

echo "=== Pushing Stage 3 to HF ==="
# Arguments are passed exactly as written; how push_to_hf.py resolves the bare
# file name `stage3_final.pt` is not visible from this script.
python3 scripts/push_to_hf.py stage3_final.pt v4_stage3_final.pt
# Final report: completion banner, a timestamp, and a listing of everything
# currently in the checkpoints directory. The `ls` is the last command, so its
# exit status becomes the exit status of the script.
printf '%s\n' "=== ALL TRAINING COMPLETE ==="
date
printf '%s\n' "Checkpoints:"
ls -lh checkpoints/