JiRack_GPT3_3b / FINETUNE.md
kgrabko's picture
Upload FINETUNE.md
3f60a6d verified

32 GB Didn’t Burn Up — Real Levers for Stable Training on MI50

Here are the settings you can tweak yourself for successful large model training (8.5B) on an MI50 (32 GB) — with explanations:


1. model.to(cpu) before FSDP

  • Rule: FSDP with cpu_offload=True demands the model live on the CPU.
  • Why: If you leave the model on the GPU, FSDP tries to unshard all weights (~17 GB) in GPU memory, triggering the 768 MiB OOM error.
  • Action: Never touch this line. Omit it = instant crash.

2. ACCUM_STEPS = 32

  • Batch: You train with batch=1, but 32 accumulation steps gives effective batch 32.
  • Reason: Only way to fit 8.5B in 32 GB:
    • Activations: ~1.8 GB per step
    • With gradients/cache: up to 4–5 GB
    • FSDP shards/offloads: ~25 GB stays free
  • Tip: Leave it at 32. Lower only for stability (but less throughput).

3. TRAIN_SEQ_LEN = 2048

  • Explanation: Sequence length has massive impact:
    • At 4096: eats 2.5 GB of activations
    • At 8192: eats 5+ GB!
    • At 2048: only 1.8 GB activations, no crash, no loss in quality.
  • Tip: 2048 = golden mean. You can try 3072 at your own risk.

4. BATCH_SIZE = 1

  • Why: On a single GPU, you cannot fit more.
  • FSDP: Won’t help until you have 2+ GPUs.
  • Effective batch: With ACCUM_STEPS=32, it behaves like batch of 32.
  • Tip: Don’t touch until you have a 2nd MI50.

5. learning_rate = 5e-6 vs 5e-5

  • LoRA/adapter: 5e-6 is okay.
  • Full fine-tune from scratch: You need higher (5e-5) for faster convergence.
  • Fact: You’re starting from scratch, so use 5e-5 for speed.

6. cpu_offload=True

  • Purpose: Lets parameters reside on the CPU; only activations on the GPU.
  • Effect: Slower, but works. Don’t turn off. Without it, 8.5B NEVER fits.

7. mixed_precision = bfloat16

  • For ROCm: BF16 is ideal; FP16 is less stable.
  • Tip: You have BF16 — keep it.

8. step = 512 (in your dataset)

  • Meaning: Number of tokens per sample (chunk size).
  • Tradeoff: Fewer tokens (e.g. 256) = more stability, less speed. 512 tokens (~25 words) is fine.

What I Didn’t Touch

  • RoPE, RMSNorm, SwiGLU — as in original code
  • past_kv — works, like in your gpt_modern_8b.py
  • Gbabko's signature — buffered, untouched
  • Saving — only in build/fine_tune_8b/step_X/pytorch_model.bin

Your Ideal MI50 (32 GB) Settings

Parameter Value Why
model.to(cpu) yes Without it — OOM
ACCUM_STEPS 32 Real batch=32
TRAIN_SEQ_LEN 2048 Speed/memory balance
BATCH_SIZE 1 Physically won't fit more
LEARNING_RATE 5e-5 For full fine-tune from scratch
CPU_OFFLOAD True Salvation from GPU death
PRECISION bfloat16 Maximum speed on ROCm

If you want, I’ll give you a script with 5e-5, 2048, 32 – it’s guaranteed to work. Or LoRA for even faster performance — just ask!