#!/usr/bin/env bash
#
# Launch Spider-FLEXITOKENS training on a remote GPU host.
#
# Run from the directory that contains scripts/train_spider.py (the path
# below is relative to the current working directory).
#
# Every setting can be overridden from the environment; defaults are shown
# in parentheses:
#   CKPT_DIR       checkpoint directory, created if missing
#                  (checkpoints-spider-remote)
#   SEQ_LEN        sequence length, positive integer          (2048)
#   MICRO_BATCH    micro batch size, positive integer         (8)
#   GRAD_ACCUM     gradient accumulation, positive integer    (4)
#   TARGET_TOKENS  token budget, non-negative integer         (10000000000)
#   N_LOOPS        forwarded as --n_loops                     (6)
#   LR             forwarded as --lr                          (3e-4)
#   CKPT_EVERY     exported only                              (500)
#   PRECISION      forwarded as --precision                   (mxfp8)
#   RESUME         checkpoint file forwarded as --resume when it exists
#                  (empty)
#   OMP_NUM_THREADS                                           (8)

set -euo pipefail

#######################################
# Print an error message to stderr and abort the script.
# Arguments: $* - message text
#######################################
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

#######################################
# Abort unless a value is a plain decimal integer >= a minimum.
# Guards the arithmetic below: bash evaluates non-numeric strings inside
# $(( )) as expressions, and a zero divisor aborts with a bare
# "division by 0".
# Arguments: $1 - variable name (for the message)
#            $2 - value to check
#            $3 - smallest accepted value
#######################################
require_int() {
  local name=$1 value=$2 min=$3
  # No leading zeros: bash arithmetic would read them as octal.
  local -r int_re='^(0|[1-9][0-9]*)$'
  if ! [[ "${value}" =~ ${int_re} ]] || (( value < min )); then
    die "${name} must be an integer >= ${min} (got '${value}')"
  fi
}

# ---------------------------------------------------------------------------
# Configuration (environment overrides, otherwise defaults)
# ---------------------------------------------------------------------------
export CKPT_DIR="${CKPT_DIR:-checkpoints-spider-remote}"
export SEQ_LEN="${SEQ_LEN:-2048}"
export MICRO_BATCH="${MICRO_BATCH:-8}"
export GRAD_ACCUM="${GRAD_ACCUM:-4}"
export TARGET_TOKENS="${TARGET_TOKENS:-10000000000}"
export N_LOOPS="${N_LOOPS:-6}"
export LR="${LR:-3e-4}"
export CKPT_EVERY="${CKPT_EVERY:-500}"
export PRECISION="${PRECISION:-mxfp8}"
export RESUME="${RESUME:-}"

require_int SEQ_LEN "${SEQ_LEN}" 1
require_int MICRO_BATCH "${MICRO_BATCH}" 1
require_int GRAD_ACCUM "${GRAD_ACCUM}" 1
require_int TARGET_TOKENS "${TARGET_TOKENS}" 0

# Derived values: tokens consumed per optimizer step and the resulting
# (integer-truncated) number of steps needed to reach TARGET_TOKENS.
export GLOBAL_BATCH_TOK=$(( MICRO_BATCH * GRAD_ACCUM * SEQ_LEN ))
export TOTAL_STEPS=$(( TARGET_TOKENS / GLOBAL_BATCH_TOK ))

# ---------------------------------------------------------------------------
# Banner
# ---------------------------------------------------------------------------
# nvidia-smi is optional here: fall back to N/A when it is missing or fails.
gpu_name=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')
gpu_vram=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader 2>/dev/null || echo 'N/A')

echo "============================================="
echo " Spider-FLEXITOKENS Remote Training"
echo "============================================="
echo " GPU: ${gpu_name}"
echo " VRAM: ${gpu_vram}"
echo " Precision: ${PRECISION}"
echo " Seq length: ${SEQ_LEN}"
echo " Micro batch: ${MICRO_BATCH}"
echo " Grad accum: ${GRAD_ACCUM}"
echo " Global batch: ${GLOBAL_BATCH_TOK} tokens/step"
echo " Target tokens: ${TARGET_TOKENS} ($(( TARGET_TOKENS / 1000000000 ))B)"
echo " Total steps: ~${TOTAL_STEPS}"
echo " LR: ${LR}"
echo " N loops: ${N_LOOPS}"
echo " Checkpoint dir: ${CKPT_DIR}"
echo " Resume from: ${RESUME:-none (auto-resume from ${CKPT_DIR})}"
echo "============================================="

# ---------------------------------------------------------------------------
# Runtime environment for the training process
# ---------------------------------------------------------------------------
# These three are set unconditionally and override any caller-provided value.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export UNSLOTH_MOE_BACKEND="grouped_mm"
export CUDA_DEVICE_ORDER="PCI_BUS_ID"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-8}"

# ---------------------------------------------------------------------------
# Dependency checks
# ---------------------------------------------------------------------------
# Only the PyTorch probe is fatal. It also queries the CUDA device
# capability, so it fails when torch imports but no CUDA device is usable.
python3 -c "import torch; print(f'PyTorch {torch.__version__} | CUDA {torch.version.cuda} | sm{torch.cuda.get_device_capability()[0]}')" \
  || die "PyTorch not found or CUDA device query failed"
# The remaining probes are informational; the script continues without them.
python3 -c "import torchao; print(f'torchao {torchao.__version__}')" || echo "WARNING: torchao not found — FP8/MXFP8 unavailable, will use BF16"
python3 -c "from torchao.float8 import Float8LinearConfig; print(f'Float8LinearConfig OK (recipes: {[n.value for n in __import__(\"torchao.float8.config\", fromlist=[\"Float8LinearRecipeName\"]).Float8LinearRecipeName])}')" || echo "WARNING: torchao.float8 not available"
python3 -c "import bitsandbytes; print(f'bitsandbytes {bitsandbytes.__version__}')" || echo "INFO: bitsandbytes not found — using standard AdamW (expected for FP8+ modes)"
python3 -c "import unsloth; print('Unsloth available')" 2>/dev/null || echo "INFO: Unsloth not available — using standard PyTorch training"

mkdir -p "${CKPT_DIR}"

# ---------------------------------------------------------------------------
# Build and launch the training command
# ---------------------------------------------------------------------------
# An array keeps each value a single argument even when it contains spaces
# or glob characters (e.g. a CKPT_DIR path with a space).
# NOTE(review): GRAD_ACCUM, TARGET_TOKENS, TOTAL_STEPS and CKPT_EVERY are
# exported but not passed as flags; presumably train_spider.py reads them
# from the environment — confirm.
cmd=(
  python3 scripts/train_spider.py
  --precision "${PRECISION}"
  --ckpt_dir "${CKPT_DIR}"
  --seq_len "${SEQ_LEN}"
  --micro_batch "${MICRO_BATCH}"
  --n_loops "${N_LOOPS}"
  --lr "${LR}"
)

# --resume is only forwarded for an existing regular file; say so when a
# requested checkpoint is dropped instead of ignoring it silently.
if [[ -n "${RESUME}" ]]; then
  if [[ -f "${RESUME}" ]]; then
    cmd+=(--resume "${RESUME}")
  else
    printf 'WARNING: RESUME=%s is not an existing file; --resume not passed\n' \
      "${RESUME}" >&2
  fi
fi

# Shell-quoted rendering so the logged command can be copy-pasted verbatim.
printf -v cmd_str '%q ' "${cmd[@]}"

echo ""
echo "Launching training..."
echo "Command: ${cmd_str% }"
echo ""

# Replace this shell with the trainer so signals and the exit status go
# straight to/from the Python process.
exec "${cmd[@]}"
|
|