| # Spider-FLEXITOKENS Remote Training Guide |
|
|
| ## Target Hardware: NVIDIA RTX 6000 Pro (Blackwell) |
|
|
| - **GPU**: RTX 6000 Pro (Blackwell architecture, sm120+) |
| - **VRAM**: 48GB GDDR7 |
| - **Precision**: MXFP8 (rowwise_with_gw_hp recipe) — primary; FP8_DYNAMIC fallback |
| - **Expected peak VRAM**: ~15-20GB (model ~4GB FP8, optimizer ~8GB standard AdamW, activations ~4-8GB with gradient checkpointing) |
|
|
| ## Quick Start |
|
|
| ```bash |
| # 1. Clone/transfer the repo to the remote machine |
| # 2. Install dependencies (see below) |
| # 3. Run the launch script |
| bash scripts/train_remote.sh |
| ``` |
|
|
| ## Environment Setup |
|
|
| ### Required Dependencies |
|
|
| ```bash |
| pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 |
| pip install "torchao>=0.17.0" |
| pip install datasets transformers |
| pip install bitsandbytes # optional — only used for BF16 fallback |
| ``` |
|
|
| ### Optional (Recommended) |
|
|
| ```bash |
| pip install unsloth # MoE kernel optimizations + memory-efficient GC |
| ``` |
|
|
| ### Verify Installation |
|
|
| ```bash |
| python3 -c " |
| import torch |
| print(f'PyTorch: {torch.__version__}') |
| print(f'CUDA: {torch.version.cuda}') |
| print(f'GPU: {torch.cuda.get_device_name(0)}') |
| print('Compute capability: sm{}{}'.format(*torch.cuda.get_device_capability(0))) |
| |
| import torchao |
| print(f'torchao: {torchao.__version__}') |
| |
| from torchao.float8 import Float8LinearConfig |
| from torchao.float8.config import Float8LinearRecipeName |
| print('FP8 training: available') |
| print(f'Recipes: {[r.value for r in Float8LinearRecipeName]}') |
| " |
| ``` |
|
|
| Expected output on the RTX 6000 Pro: compute capability `sm120` or higher, with all three recipes listed (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`). |
|
|
| ## Configuration |
|
|
| ### Environment Variables |
|
|
| | Variable | Default | Description | |
| |---|---|---| |
| | `PRECISION` | `mxfp8` | Training precision: `mxfp8`, `fp8_dynamic`, `bf16` | |
| | `SEQ_LEN` | `2048` | Sequence length per sample | |
| | `MICRO_BATCH` | `8` | Batch size per forward pass | |
| | `GRAD_ACCUM` | `4` | Gradient accumulation steps | |
| | `TARGET_TOKENS` | `10000000000` | Total training tokens (10B) | |
| | `N_LOOPS` | `6` | Recurrent loop iterations | |
| | `LR` | `3e-4` | Peak learning rate | |
| | `CKPT_EVERY` | `500` | Save checkpoint every N steps | |
| | `CKPT_DIR` | `checkpoints-spider-remote` | Checkpoint output directory | |
| | `RESUME` | _(empty)_ | Path to checkpoint for manual resume | |
|
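As a rough illustration of how the variables in the table could be consumed, the sketch below reads them from the environment and applies the documented defaults. The `TrainConfig` dataclass and `load_config` function are hypothetical names used only for this example; the actual launch script may parse its settings differently.

```python
import os
from dataclasses import dataclass

VALID_PRECISIONS = ("mxfp8", "fp8_dynamic", "bf16")


@dataclass
class TrainConfig:
    """Hypothetical container mirroring the environment-variable table."""
    precision: str
    seq_len: int
    micro_batch: int
    grad_accum: int
    target_tokens: int
    n_loops: int
    lr: float
    ckpt_every: int
    ckpt_dir: str
    resume: str


def load_config(env=None) -> TrainConfig:
    """Build a config from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    precision = env.get("PRECISION", "mxfp8").lower()
    if precision not in VALID_PRECISIONS:
        raise ValueError(f"PRECISION must be one of {VALID_PRECISIONS}, got {precision!r}")
    return TrainConfig(
        precision=precision,
        seq_len=int(env.get("SEQ_LEN", "2048")),
        micro_batch=int(env.get("MICRO_BATCH", "8")),
        grad_accum=int(env.get("GRAD_ACCUM", "4")),
        target_tokens=int(env.get("TARGET_TOKENS", "10000000000")),
        n_loops=int(env.get("N_LOOPS", "6")),
        lr=float(env.get("LR", "3e-4")),
        ckpt_every=int(env.get("CKPT_EVERY", "500")),
        ckpt_dir=env.get("CKPT_DIR", "checkpoints-spider-remote"),
        resume=env.get("RESUME", ""),  # empty string means no manual resume
    )


if __name__ == "__main__":
    print(load_config())
```

Passing a plain dictionary instead of `os.environ` makes it easy to check a combination of settings before exporting them.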
|
| ### Recommended Settings for RTX 6000 Pro (48GB) |
|
|
| ```bash |
| # MXFP8 — maximum accuracy, best VRAM efficiency |
| export PRECISION=mxfp8 |
| export MICRO_BATCH=8 |
| export GRAD_ACCUM=4 |
| # Global batch: 8 * 4 * 2048 = 65,536 tokens/step |
| # ~10B tokens ≈ 152,000 steps |
| ``` |
|
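The batch arithmetic in the comments above can be checked with a few lines of Python. This is a minimal sketch; the function names are illustrative, and rounding the step count up is an assumption about how the script treats a partial final step.

```python
def tokens_per_step(micro_batch: int, grad_accum: int, seq_len: int) -> int:
    """Tokens consumed by one optimizer step (the global batch in tokens)."""
    return micro_batch * grad_accum * seq_len


def total_steps(target_tokens: int, per_step: int) -> int:
    """Optimizer steps needed to reach target_tokens, rounded up (assumed)."""
    return -(-target_tokens // per_step)  # integer ceiling division


recommended = tokens_per_step(micro_batch=8, grad_accum=4, seq_len=2048)
conservative = tokens_per_step(micro_batch=4, grad_accum=8, seq_len=2048)

print(recommended)                               # 65536
print(conservative)                              # 65536
print(total_steps(10_000_000_000, recommended))  # 152588
```

Halving `MICRO_BATCH` while doubling `GRAD_ACCUM` leaves the tokens per step, and therefore the step count, unchanged.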
|
| ### Conservative Settings (if VRAM-constrained) |
|
|
| ```bash |
| export PRECISION=fp8_dynamic |
| export MICRO_BATCH=4 |
| export GRAD_ACCUM=8 |
| # Global batch: 4 * 8 * 2048 = 65,536 tokens/step (same global batch, lower peak VRAM) |
| ``` |
|
|
| ## Launch |
|
|
| ### Fresh Training Run |
|
|
| ```bash |
| bash scripts/train_remote.sh |
| ``` |
|
|
| ### Resume from Checkpoint |
|
|
| ```bash |
| # Auto-resume (picks latest from CKPT_DIR) |
| bash scripts/train_remote.sh |
| |
| # Manual resume from specific checkpoint |
| export RESUME=checkpoints-spider-remote/spider-step5000.pt |
| bash scripts/train_remote.sh |
| ``` |
|
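To preview which file an auto-resume would likely pick, the sketch below scans a checkpoint directory for `spider-step{N}.pt` files and returns the one with the highest step number. It assumes that "latest" means the highest step; the training script's own selection logic is not shown in this guide and may differ (for example, it might use file modification times). The helper name `latest_step_checkpoint` is hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# File name pattern taken from the step checkpoints, e.g. spider-step5000.pt
STEP_PATTERN = re.compile(r"^spider-step(\d+)\.pt$")


def latest_step_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the step checkpoint with the highest step number, or None."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    best_step = -1
    best_path = None
    for path in directory.iterdir():
        match = STEP_PATTERN.match(path.name)
        if match is None:
            continue  # skip spider-best.pt, spider-ep*.pt, and unrelated files
        step = int(match.group(1))
        if step > best_step:
            best_step, best_path = step, path
    return best_path


if __name__ == "__main__":
    print(latest_step_checkpoint("checkpoints-spider-remote"))
```

Comparing step numbers as integers matters here: a plain string sort would rank `spider-step500.pt` after `spider-step5000.pt`.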
|
| ### Resume from Local Smoke Test |
|
|
| Transfer the local checkpoint to the remote machine, then: |
|
|
| ```bash |
| export RESUME=checkpoints-spider-real/spider-final-ep1.pt |
| bash scripts/train_remote.sh |
| ``` |
|
|
| **Note**: The local checkpoint was trained in BF16 with 8-bit AdamW. When resuming with MXFP8/FP8, the training script will: |
| 1. Load the model weights (always compatible) |
| 2. Skip the 8-bit optimizer state with a warning (8-bit AdamW state cannot be loaded into standard AdamW) |
| 3. Continue training with a freshly initialized standard AdamW optimizer state |
|
|
| This is by design: the script handles the optimizer state mismatch gracefully. |
|
|
| ## Monitoring |
|
|
| ### Training Logs |
|
|
| The script outputs structured logs every 10 steps: |
|
|
| ``` |
| Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens |
| ``` |
|
|
| Key metrics: |
| - **loss**: Total loss (lm + aux + bp) |
| - **lm**: Language modeling loss |
| - **aux**: MoE load-balancing auxiliary loss |
| - **bp**: Boundary predictor loss; the bracketed tag shows the mode (`FIXED` = 30% curriculum, `ADAPTIVE` = learned) |
| - **gnorm**: Gradient norm (should stabilize around 1-5) |
| - **tok/s**: Throughput (expect 0.5-1.0M tok/s on RTX 6000 Pro) |
|
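For tracking these metrics outside the terminal, a log line can be parsed with a regular expression. The sketch below assumes that every line follows exactly the layout of the sample above; the pattern and the `parse_log_line` helper are illustrative and would need adjusting if the script's real format differs (for instance, if a loss is printed as `nan`).

```python
import re

# Pattern written against the sample log line; the field layout is an assumption.
LOG_PATTERN = re.compile(
    r"Epoch (?P<epoch>\d+) \| step (?P<step>\d+)/(?P<total_steps>\d+)"
    r" \| loss (?P<loss>[\d.]+) \| lm (?P<lm>[\d.]+) \| aux (?P<aux>[\d.]+)"
    r" \| bp (?P<bp>[\d.]+) \[(?P<bp_mode>[^\]]+)\]"
    r" \| gnorm (?P<gnorm>[\d.]+) \| lr (?P<lr>[\d.eE+-]+)"
    r" \| (?P<mtok_per_s>[\d.]+)M tok/s \| (?P<btokens>[\d.]+)B tokens"
)

INT_FIELDS = ("epoch", "step", "total_steps")


def parse_log_line(line: str):
    """Return a dict of metrics, or None if the line does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    record = {}
    for name, value in match.groupdict().items():
        if name in INT_FIELDS:
            record[name] = int(value)
        elif name == "bp_mode":
            record[name] = value  # keep the mode tag as text
        else:
            record[name] = float(value)
    return record


SAMPLE = (
    "Epoch 1 | step 10/152000 | loss 3.2140 | lm 3.1020 | aux 0.0312 | "
    "bp 0.0808 [FIXED/FROZEN] | gnorm 1.23 | lr 3.00e-04 | 0.42M tok/s | 0.07B tokens"
)

if __name__ == "__main__":
    rec = parse_log_line(SAMPLE)
    # The total loss should equal the sum of its three components.
    print(abs(rec["loss"] - (rec["lm"] + rec["aux"] + rec["bp"])) < 1e-3)
```

The final check mirrors the definition of **loss** as `lm + aux + bp`, which holds for the sample line (3.1020 + 0.0312 + 0.0808 = 3.2140).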
|
| ### VRAM Monitoring |
|
|
| ```bash |
| watch -n 5 nvidia-smi |
| ``` |
|
|
| Expected on RTX 6000 Pro with MXFP8: |
| - Model: ~2GB (weights in FP8) |
| - Optimizer: ~8GB (standard AdamW, FP32 states) |
| - Activations: ~4-8GB (gradient checkpointing enabled) |
| - **Peak**: ~15-20GB total |
|
|
| ### Health Warnings |
|
|
| The `RecurrentMonitor` checks for: |
| - **Representation drift**: Loop hidden states diverging (cosine sim < 0.5) |
| - **Collapse**: All experts producing identical outputs (std < 1e-6) |
|
|
| If you see these warnings, consider reducing `N_LOOPS` or lowering the learning rate (`LR`). |
|
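The two thresholds can be illustrated with plain Python. This is only a sketch of the checks described above, not the `RecurrentMonitor` implementation: it assumes drift is measured as cosine similarity between the hidden states of consecutive loop iterations, and that collapse is measured as the standard deviation across experts for each output dimension. The function names are hypothetical.

```python
import math
from statistics import pstdev

DRIFT_COSINE_THRESHOLD = 0.5   # from the guide: cosine sim < 0.5
COLLAPSE_STD_THRESHOLD = 1e-6  # from the guide: std < 1e-6


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def has_drift(prev_hidden, curr_hidden):
    """Assumed check: consecutive loop states point in very different directions."""
    return cosine_similarity(prev_hidden, curr_hidden) < DRIFT_COSINE_THRESHOLD


def has_collapse(expert_outputs):
    """Assumed check: every output dimension is near-identical across experts."""
    per_dim_std = [pstdev(column) for column in zip(*expert_outputs)]
    return max(per_dim_std) < COLLAPSE_STD_THRESHOLD


if __name__ == "__main__":
    print(has_drift([1.0, 0.0], [0.0, 1.0]))         # orthogonal states
    print(has_collapse([[1.0, 2.0], [1.0, 2.0]]))    # identical expert outputs
```

In a real training loop these quantities would be computed on tensors; lists are used here so the logic can be read without any framework.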
|
| ## Precision Fallback Chain |
|
|
| The training script automatically falls back if precision setup fails: |
|
|
| ``` |
| MXFP8 (sm120+ Blackwell) → FP8_DYNAMIC (sm89+ Ada) → BF16 (all GPUs) |
| ``` |
|
|
| - **MXFP8**: Row-wise scaling with the weight-gradient computation kept in high precision. Best accuracy. |
| - **FP8_DYNAMIC**: Row-wise dynamic scaling. Good accuracy/performance tradeoff. |
| - **BF16**: No quantization. Most VRAM, but simplest path. |
| |
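The chain can be sketched as a lookup on the GPU's compute capability. Note the assumption: the guide says the script falls back when precision setup fails, so the real logic may catch setup errors rather than compare capabilities. The thresholds below are taken from the chain diagram, and `resolve_precision` is a hypothetical helper name.

```python
# Minimum compute capability per precision, from the fallback chain diagram.
FALLBACK_ORDER = ("mxfp8", "fp8_dynamic", "bf16")
MIN_SM = {"mxfp8": 120, "fp8_dynamic": 89, "bf16": 0}


def resolve_precision(requested: str, capability: tuple) -> str:
    """Return the first precision at or after `requested` that the GPU supports.

    `capability` is a (major, minor) tuple, e.g. the value returned by
    torch.cuda.get_device_capability(0).
    """
    if requested not in FALLBACK_ORDER:
        raise ValueError(f"unknown precision: {requested!r}")
    major, minor = capability
    sm = major * 10 + minor  # (12, 0) -> 120, (8, 9) -> 89
    start = FALLBACK_ORDER.index(requested)
    for precision in FALLBACK_ORDER[start:]:
        if sm >= MIN_SM[precision]:
            return precision
    return "bf16"  # not reached: bf16 has no minimum in this sketch


if __name__ == "__main__":
    print(resolve_precision("mxfp8", (12, 0)))  # mxfp8
    print(resolve_precision("mxfp8", (8, 9)))   # fp8_dynamic
    print(resolve_precision("mxfp8", (8, 0)))   # bf16
```

The function only ever moves down the chain, so requesting `bf16` on a Blackwell GPU still yields `bf16`.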
| ## Checkpoint Files |
| |
| | File | Description | |
| |---|---| |
| | `spider-step{N}.pt` | Step checkpoint (every `CKPT_EVERY` steps) | |
| | `spider-ep{N}.pt` | Epoch boundary checkpoint | |
| | `spider-best.pt` | Best loss checkpoint (updated when epoch loss improves) | |
| | `spider-final-ep{N}.pt` | Final checkpoint at training end | |
| |
| Each checkpoint contains: |
| - Model state dict |
| - Optimizer state dict |
| - Training step, epoch, config |
| - `best_loss` value |
| - BP optimizer state (if active) |
| |
| ## Troubleshooting |
| |
| ### `mat2 shape must be divisible by 16` |
| |
| This error is fixed by setting `pad_inner_dim=True` in `Float8LinearConfig` (torchao v0.17.0+). The training script handles this automatically. |
| |
| ### `CUDA out of memory` |
| |
| Reduce `MICRO_BATCH` and increase `GRAD_ACCUM` by the same factor to lower peak VRAM while keeping the global batch size unchanged: |
| |
| ```bash |
| export MICRO_BATCH=4 # was 8 |
| export GRAD_ACCUM=8 # was 4 (same 65,536 tok/step) |
| ``` |
| |
| ### Optimizer state mismatch on resume |
| |
| Expected when resuming a BF16 (8-bit Adam) checkpoint on FP8/MXFP8 (standard AdamW). The script logs a warning and continues — model weights load fine, optimizer restarts from scratch. |
| |
| ### Slower than expected throughput |
| |
| - Ensure `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` is set (default in script) |
| - Check that `torch.compile` is not enabled inadvertently (it adds compilation overhead) |
| - Verify torchao version >= 0.17.0 for optimal FP8 kernels |
| |