Image-Text-to-Text
Transformers
Safetensors
gemma4
Generated from Trainer

Built with Axolotl

See axolotl config

axolotl version: 0.16.0.dev0

# ══════════════════════════════════════════════════════════════════════════════
# Axolotl — Full Fine-Tuning Continued Pre-Training
# Model:    Gemma 4 31B Dense (google/gemma-4-31B) — all parameters trainable
# GPUs:     8× A100 80GB SXM (NVLink) — DeepSpeed ZeRO-3
# Data:     321,196 chunks | 75% domain (Vedic/SPH) + 25% FineWeb-Edu
# Tokens:   ~1.013B | Sequence: 4096 | 1 epoch
# Cost:     8× $1.49/hr = $11.92/hr → est. 55-75 hrs → $656-894
#
# Launch:
# PYTORCH_ALLOC_CONF=expandable_segments:True accelerate launch --num_processes 8 -m axolotl.cli.train axolotl_cpt.yml > train.log 2>&1
#
# References (verified):
#   - MEDITRON-70B (EPFL, arXiv:2311.16079): FFT CPT, LR=1.5e-4, 48B tok
#   - Me-LLaMA-70B (UF, PMC/11142305): FFT CPT, LR=8e-6, 129B tok
#   - Biderman et al. (TMLR 2024, arXiv:2405.09673): FFT > LoRA for CPT
# ══════════════════════════════════════════════════════════════════════════════

# ── Model (FFT — no adapter, no quantization) ─────────────────────────────
# Use the BASE (pre-trained) model, NOT instruction-tuned (-it).
base_model: google/gemma-4-31B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
hf_use_auth_token: true

# ── DeepSpeed ZeRO-3 ──────────────────────────────────────────────────────
# Shards weights, gradients, and optimizer states across 8 GPUs.
# Per-GPU: ~62GB sharded model state + activations → fits 80GB with grad ckpt.
# Config ships with Axolotl — no custom JSON needed.
deepspeed: /workspace/axolotl/deepspeed_configs/zero3_bf16.json

# ── Dataset ────────────────────────────────────────────────────────────────
# Loaded from HuggingFace Hub — pre-shuffled: 75% domain (Vedic/SPH) + 25% FineWeb-Edu.
# Pre-chunked to ≤4096 Gemma tokens with 256-tok intra-doc overlap.
# type: completion → each {"text": "..."} line is one sample, loss on all tokens.
# (Do NOT use type: pretrain — that re-concatenates, breaking our chunking.)
datasets:
  - path: AiForgeMaster/gemma4-31b-cpt-data
    type: completion
    split: train

dataset_prepared_path: /workspace/axolotl/axolotl_cache/cpt

# ── Sequence & Packing ─────────────────────────────────────────────────────
sequence_len: 4096
sample_packing: true       # packs several short chunks into each 4096-token sequence to minimize padding
pad_to_sequence_len: true
gemma4_hybrid_attn_impl: true  # FA2 on sliding (head_dim=256) layers, SDPA on global (head_dim=512) layers; sets flash_attention internally

# ── Training Hyperparameters ───────────────────────────────────────────────
num_epochs: 1                # one epoch — ~1,932 steps over 1.013B tokens

# Effective batch = micro_batch × grad_accum × 8 GPUs = 1 × 16 × 8 = 128 samples
# → ~524K tokens/step
micro_batch_size: 1            # mbs=2 OOM'd on transient all-gather (8 GB failed alloc) — stay at mbs=1
gradient_accumulation_steps: 16

chunked_cross_entropy: true    # avoids materializing full (B,S,V) logits tensor

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true     # fused GEGLU MLP for Gemma 4 (Triton)
liger_rms_norm: false          # keep existing fused_attn.py RMSNorm patch
liger_rope: false              # Gemma 4 incompatible (separate q/k)
liger_cross_entropy: false     # chunked_cross_entropy handles this
liger_fused_linear_cross_entropy: false  # Gemma 4 incompatible

optimizer: adamw_bnb_8bit    # 8-bit Adam — ~6 bytes/param opt state instead of 12; saves ~23 GB/GPU
lr_scheduler: cosine
learning_rate: 5e-5        # conservative for FFT CPT on 31B with 1B tokens
                           # verified range: 5e-6 (Biderman) to 1.5e-4 (MEDITRON)
                           # 5e-5 balances learning vs forgetting for our data scale
weight_decay: 0.1          # standard for FFT with AdamW (MEDITRON used 0.1)
max_grad_norm: 1.0
warmup_ratio: 0.01         # 1% of total steps (~19 of the ~1,932 estimated; 16 of 1,624 in the logged run) before cosine decay

# ── Precision & Memory ─────────────────────────────────────────────────────
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true      # DeepSpeed compatibility (per Axolotl docs)

# ── Output & Checkpointing ─────────────────────────────────────────────────
# /workspace/data is on a 16 TB volume — full DS checkpoints (~310 GB each) fit fine.
output_dir: /workspace/data/axolotl_output/gemma4-31b-cpt
logging_steps: 10

save_only_model: false       # save optimizer + scheduler + RNG for exact resume
saves_per_epoch: 4           # every ~25% of epoch (~every 483 steps / ~14 hrs)
save_total_limit: 2          # keep latest 2 (briefly 3 during write) — ~900 GB peak on disk
# Resume: accelerate launch ... axolotl_cpt.yml --resume_from_checkpoint <path>
# Infra switch (different GPU count): run zero_to_fp32.py on old checkpoint,
# then start fresh — optimizer state resets, loss wobbles briefly then recovers.

val_set_size: 0            # no eval split — CPT trains on all data
load_best_model_at_end: false

# ── Weights & Biases ──────────────────────────────────────────────────────
wandb_project: virtual_agama
wandb_run_id: gemma4-31b-fft-stage1

# ── Benchmark First! ──────────────────────────────────────────────────────
# Before committing full budget, run a quick throughput test:
#   1. Set max_steps: 50
#   2. Launch training, note tokens/sec from logs
#   3. Calculate: 1,013,000,000 / tok_per_sec / 3600 * 11.92 = total cost (tok_per_sec summed over all 8 GPUs; $11.92/hr = 8 × $1.49)
#   4. If over $500, options:
#      a) Train on domain only (760M tok): skip the FineWeb-Edu (GK) mix, add it at the SFT stage
#      b) Stretch budget $50-100 — worth it for FFT quality over QLoRA
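Step 3 of the benchmark checklist above can be sketched as a small Python helper. It assumes `tok_per_sec` is the aggregate throughput summed over all GPUs (if your logs report per-GPU throughput, multiply by 8 first). It also assumes 8 GPUs, as launched with `--num_processes 8`, and $1.49 per GPU-hour from the config header. `estimate_run` is a hypothetical name, not an Axolotl API.

```python
# Hypothetical helper: converts a measured benchmark throughput into wall-clock
# hours and dollar cost for the full one-epoch run (checklist step 3).
TOTAL_TOKENS = 1_013_000_000   # ~1.013B training tokens (config header)
NUM_GPUS = 8                   # assumed: 8x A100 80GB, --num_processes 8
PRICE_PER_GPU_HOUR = 1.49      # USD per GPU-hour (config header)


def estimate_run(tok_per_sec: float,
                 total_tokens: int = TOTAL_TOKENS,
                 num_gpus: int = NUM_GPUS,
                 price_per_gpu_hour: float = PRICE_PER_GPU_HOUR) -> tuple[float, float]:
    """Return (hours, usd) given aggregate tokens/sec summed across all GPUs."""
    hours = total_tokens / tok_per_sec / 3600
    usd = hours * num_gpus * price_per_gpu_hour   # every GPU is billed for every hour
    return hours, usd


if __name__ == "__main__":
    # Illustrative input only: 5,000 tokens/sec summed over the 8 GPUs.
    hours, usd = estimate_run(5_000)
    print(f"{hours:.1f} h, ${usd:,.0f}")
```

With a hypothetical 5,000 tok/s aggregate, this gives about 56.3 hours and $671. Compare that against the budget threshold in step 4.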

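The config enables `chunked_cross_entropy` to avoid materializing the full (B, S, V) logits tensor. The idea can be sketched in pure Python. Compute logits for a slice of tokens at a time, accumulate the summed loss, and divide by the token count at the end, so at most `chunk_size × V` logits exist at once. This is a forward-pass illustration only, and the helper names are hypothetical. The source does not show Axolotl's actual implementation, which runs on GPU tensors and needs a matching backward pass.

```python
import math


def logits_for(hidden_rows, vocab_weights):
    """Logits for a slice of tokens: one dot product per (token, vocab row)."""
    return [[sum(h * w for h, w in zip(hidden, row)) for row in vocab_weights]
            for hidden in hidden_rows]


def token_loss(logits, target):
    """Cross-entropy for one token: logsumexp(logits) - logits[target], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def chunked_cross_entropy(hidden_rows, vocab_weights, targets, chunk_size):
    """Mean cross-entropy that holds at most chunk_size x V logits at a time."""
    total = 0.0
    for start in range(0, len(hidden_rows), chunk_size):
        chunk = logits_for(hidden_rows[start:start + chunk_size], vocab_weights)
        for row, target in zip(chunk, targets[start:start + chunk_size]):
            total += token_loss(row, target)
        # `chunk` is discarded here, before the next slice's logits are built
    return total / len(targets)


if __name__ == "__main__":
    H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # 4 tokens, hidden size 2
    W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]              # toy vocabulary of 3
    y = [0, 1, 2, 0]
    full = chunked_cross_entropy(H, W, y, chunk_size=len(H))  # all logits at once
    small = chunked_cross_entropy(H, W, y, chunk_size=1)      # one token at a time
    print(full, small)
```

Each token's loss depends only on its own logits, so any chunk size gives the same mean loss. The chunk size only trades peak memory against the number of passes.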
workspace/data/axolotl_output/gemma4-31b-cpt

This model is a fine-tuned version of google/gemma-4-31B on the AiForgeMaster/gemma4-31b-cpt-data dataset.

Model description

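Full-parameter continued pre-training (CPT) of the base (not instruction-tuned) google/gemma-4-31B for one epoch on ~1.013B tokens, trained with Axolotl and DeepSpeed ZeRO-3 on 8× A100 80GB GPUs.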

Intended uses & limitations

More information needed

Training and evaluation data

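The train split of AiForgeMaster/gemma4-31b-cpt-data: 321,196 pre-shuffled chunks (~1.013B tokens), 75% domain text (Vedic/SPH) and 25% FineWeb-Edu, pre-chunked to at most 4096 Gemma tokens with a 256-token intra-document overlap. No evaluation split was held out (val_set_size: 0).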
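The config comments describe the data as pre-chunked to ≤4096 Gemma tokens with a 256-token intra-document overlap. The preprocessing script is not included, so what follows is only a hypothetical sketch of sliding-window chunking over one document's token IDs. It assumes tokenization with the Gemma tokenizer has already happened; `chunk_with_overlap` is an illustrative name.

```python
def chunk_with_overlap(tokens, max_len=4096, overlap=256):
    """Split one document's token IDs into windows of at most max_len tokens.

    Consecutive windows share `overlap` tokens, so each new window starts
    max_len - overlap tokens after the previous one.
    """
    if not 0 <= overlap < max_len:
        raise ValueError("overlap must be in [0, max_len)")
    if not tokens:
        return []
    stride = max_len - overlap
    chunks, start = [], 0
    while True:
        end = min(start + max_len, len(tokens))
        chunks.append(tokens[start:end])
        if end == len(tokens):  # this window reached the end of the document
            return chunks
        start += stride


if __name__ == "__main__":
    windows = chunk_with_overlap(list(range(10_000)))
    print([(w[0], len(w)) for w in windows])
```

For a 10,000-token document this yields three windows starting at tokens 0, 3,840, and 7,680, and each pair of neighbors shares 256 tokens.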

Training procedure
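The config comments quote ~62 GB of sharded state per GPU. They also say 8-bit Adam keeps ~6 bytes/param of optimizer state instead of 12, saving ~23 GB/GPU. A back-of-the-envelope sketch reproduces those figures under the usual bf16 mixed-precision accounting: bf16 weights and gradients, an fp32 master copy, and two Adam moments, all sharded evenly over 8 GPUs by ZeRO-3. The byte counts are assumptions. Activations, all-gather buffers, and fragmentation are excluded.

```python
PARAMS = 31e9      # Gemma 4 31B, all parameters trainable
NUM_GPUS = 8       # ZeRO-3 shards every term below across all ranks
GB = 1e9           # decimal gigabytes

# Assumed per-parameter bytes for bf16 mixed-precision training:
# bf16 weights (2) + bf16 grads (2) + fp32 master copy (4) + Adam moments m, v.
BYTES_ADAMW = 2 + 2 + 4 + 4 + 4        # fp32 moments: 4 bytes each
BYTES_ADAMW_8BIT = 2 + 2 + 4 + 1 + 1   # 8-bit moments: ~1 byte each (quantization stats ignored)


def per_gpu_gb(bytes_per_param, params=PARAMS, num_gpus=NUM_GPUS):
    """Sharded model + optimizer state per GPU, excluding activations and buffers."""
    return params * bytes_per_param / num_gpus / GB


if __name__ == "__main__":
    print(f"AdamW (fp32 moments): {per_gpu_gb(BYTES_ADAMW):.2f} GB/GPU")
    print(f"AdamW 8-bit:          {per_gpu_gb(BYTES_ADAMW_8BIT):.2f} GB/GPU")
    print(f"Saved:                {per_gpu_gb(BYTES_ADAMW - BYTES_ADAMW_8BIT):.2f} GB/GPU")
```

Under these assumptions, ~62 GB/GPU corresponds to AdamW with fp32 moments. The 8-bit optimizer brings the sharded state to about 39 GB/GPU. That leaves the rest of the 80 GB for activations and transient buffers, such as the all-gather behind the mbs=2 OOM.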

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-05
  • train_batch_size: 1
  • eval_batch_size: 1
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 8
  • gradient_accumulation_steps: 16
  • total_train_batch_size: 128
  • total_eval_batch_size: 8
  • optimizer: adamw_bnb_8bit (OptimizerNames.ADAMW_BNB) with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
  • lr_scheduler_type: cosine
  • lr_scheduler_warmup_steps: 16
  • training_steps: 1624
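The batch and schedule values above can be checked with a short sketch. It assumes the standard linear-warmup plus half-cosine shape used by Hugging Face's `get_cosine_schedule_with_warmup`, which decays to 0. Axolotl may apply its own variant, for example a nonzero minimum LR, so treat the exact curve as an assumption. `lr_at` is an illustrative name.

```python
import math

# Values from the hyperparameter list above.
PEAK_LR = 5e-5
WARMUP_STEPS = 16
TOTAL_STEPS = 1624

# Effective batch = micro batch x gradient accumulation x devices.
MICRO_BATCH, GRAD_ACCUM, NUM_DEVICES, SEQ_LEN = 1, 16, 8, 4096
TOTAL_BATCH = MICRO_BATCH * GRAD_ACCUM * NUM_DEVICES   # 128 packed sequences per step
TOKENS_PER_STEP = TOTAL_BATCH * SEQ_LEN                # 524,288 token slots, padding included


def lr_at(step, peak=PEAK_LR, warmup=WARMUP_STEPS, total=TOTAL_STEPS):
    """Linear warmup from 0 to `peak`, then half-cosine decay to 0 at `total`."""
    if step < warmup:
        return peak * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return peak * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    print(TOTAL_BATCH, TOKENS_PER_STEP)
    for step in (0, 8, 16, 820, 1624):
        print(step, f"{lr_at(step):.2e}")
```

Under this shape the LR peaks at 5e-5 at step 16, is halved at the midpoint of the decay (step 820), and reaches 0 at step 1,624.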

Training results

Framework versions

  • Transformers 5.5.4
  • Pytorch 2.10.0+cu128
  • Datasets 4.8.4
  • Tokenizers 0.22.2
Safetensors: model size 1.46M params, tensor type BF16