---
library_name: transformers
license: apache-2.0
base_model: google/gemma-4-31B
tags:
- generated_from_trainer
datasets:
- AiForgeMaster/gemma4-31b-cpt-data
model-index:
- name: workspace/data/axolotl_output/gemma4-31b-cpt
  results: []
---

[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)
See axolotl config

axolotl version: `0.16.0.dev0`

```yaml
# ══════════════════════════════════════════════════════════════════════════════
# Axolotl — Full Fine-Tuning Continued Pre-Training
# Model:  Gemma 4 31B Dense (google/gemma-4-31B) — all parameters trainable
# GPUs:   8× A100 80GB SXM (NVLink) — DeepSpeed ZeRO-3
# Data:   321,196 chunks | 75% domain (Vedic/SPH) + 25% FineWeb-Edu
# Tokens: ~1.013B | Sequence: 4096 | 1 epoch
# Cost:   8× $1.49/hr = $11.92/hr → est. 55-75 hrs → $656-894
#
# Launch:
#   PYTORCH_ALLOC_CONF=expandable_segments:True accelerate launch --num_processes 8 -m axolotl.cli.train axolotl_cpt.yml > train.log 2>&1
#
# References (verified):
#   - MEDITRON-70B (EPFL, arXiv:2311.16079): FFT CPT, LR=1.5e-4, 48B tok
#   - Me-LLaMA-70B (UF, PMC/11142305): FFT CPT, LR=8e-6, 129B tok
#   - Biderman et al. (TMLR 2024, arXiv:2405.09673): FFT > LoRA for CPT
# ══════════════════════════════════════════════════════════════════════════════

# ── Model (FFT — no adapter, no quantization) ─────────────────────────────
# Use the BASE (pre-trained) model, NOT instruction-tuned (-it).
base_model: google/gemma-4-31B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
hf_use_auth_token: true

# ── DeepSpeed ZeRO-3 ──────────────────────────────────────────────────────
# Shards weights, gradients, and optimizer states across 8 GPUs.
# Per-GPU: ~62GB sharded model state + activations → fits 80GB with grad ckpt.
# Config ships with Axolotl — no custom JSON needed.
deepspeed: /workspace/axolotl/deepspeed_configs/zero3_bf16.json

# ── Dataset ────────────────────────────────────────────────────────────────
# Loaded from HuggingFace Hub — pre-shuffled: 75% domain (Vedic/SPH) + 25% FineWeb-Edu.
# Pre-chunked to ≤4096 Gemma tokens with 256-tok intra-doc overlap.
# type: completion → each {"text": "..."} line is one sample, loss on all tokens.
# (Do NOT use type: pretrain — that re-concatenates, breaking our chunking.)
datasets:
  - path: AiForgeMaster/gemma4-31b-cpt-data
    type: completion
    split: train
dataset_prepared_path: /workspace/axolotl/axolotl_cache/cpt

# ── Sequence & Packing ─────────────────────────────────────────────────────
sequence_len: 4096
sample_packing: true            # packs shorter chunks together — eliminates padding waste
pad_to_sequence_len: true
gemma4_hybrid_attn_impl: true   # FA2 on sliding (head_dim=256) layers, SDPA on global (head_dim=512) layers; sets flash_attention internally

# ── Training Hyperparameters ───────────────────────────────────────────────
num_epochs: 1                   # one epoch — ~1,932 steps over 1.013B tokens
# Effective batch = micro_batch × grad_accum × 8 GPUs = 1 × 16 × 8 = 128 samples
#   → ~524K tokens/step
micro_batch_size: 1             # mbs=2 OOM'd on transient all-gather (8 GB failed alloc) — stay at mbs=1
gradient_accumulation_steps: 16
chunked_cross_entropy: true     # avoids materializing full (B,S,V) logits tensor

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true              # fused GEGLU MLP for Gemma 4 (Triton)
liger_rms_norm: false                   # keep existing fused_attn.py RMSNorm patch
liger_rope: false                       # Gemma 4 incompatible (separate q/k)
liger_cross_entropy: false              # chunked_cross_entropy handles this
liger_fused_linear_cross_entropy: false # Gemma 4 incompatible

optimizer: adamw_bnb_8bit       # 8-bit Adam — ~6 bytes/param opt state instead of 12; saves ~23 GB/GPU
lr_scheduler: cosine
learning_rate: 5e-5             # conservative for FFT CPT on 31B with 1B tokens
                                # verified range: 5e-6 (Biderman) to 1.5e-4 (MEDITRON)
                                # 5e-5 balances learning vs forgetting for our data scale
weight_decay: 0.1               # standard for FFT with AdamW (MEDITRON used 0.1)
max_grad_norm: 1.0
warmup_ratio: 0.01              # 1% of total steps (~19 at ~1,932 steps) before cosine decay

# ── Precision & Memory ─────────────────────────────────────────────────────
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true           # DeepSpeed compatibility (per Axolotl docs)

# ── Output & Checkpointing ─────────────────────────────────────────────────
# /workspace/data is on a 16 TB volume — full DS checkpoints (~310 GB each) fit fine.
output_dir: /workspace/data/axolotl_output/gemma4-31b-cpt
logging_steps: 10
save_only_model: false          # save optimizer + scheduler + RNG for exact resume
saves_per_epoch: 4              # every ~25% of epoch (~every 483 steps / ~14 hrs)
save_total_limit: 2             # keep latest 2 (briefly 3 during write) — ~900 GB peak on disk
# Resume: accelerate launch ... axolotl_cpt.yml --resume_from_checkpoint
# Infra switch (different GPU count): run zero_to_fp32.py on old checkpoint,
#   then start fresh — optimizer state resets, loss wobbles briefly then recovers.

val_set_size: 0                 # no eval split — CPT trains on all data
load_best_model_at_end: false

# ── Weights & Biases ──────────────────────────────────────────────────────
wandb_project: virtual_agama
wandb_run_id: gemma4-31b-fft-stage1

# ── Benchmark First! ──────────────────────────────────────────────────────
# Before committing full budget, run a quick throughput test:
#   1. Set max_steps: 50
#   2. Launch training, note tokens/sec from logs
#   3. Calculate: 1,013,000,000 / tok_per_sec / 3600 * 11.92 = total cost
#   4. If over $500, options:
#      a) Train on domain only (760M tok) — skip GK mix, add at SFT stage
#      b) Stretch budget $50-100 — worth it for FFT quality over QLoRA
```
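The Dataset comments say every sample was pre-chunked to at most 4096 Gemma tokens with a 256-token overlap inside each document. The preprocessing script is not part of this card, so the following is only a minimal sketch of one way to build such windows, assuming a single document has already been tokenized into a list of token IDs (for example with the Gemma tokenizer). The function name `chunk_with_overlap` is hypothetical; it works on one document at a time, so windows never span two documents.

```python
from typing import List, Sequence


def chunk_with_overlap(
    token_ids: Sequence[int], max_len: int = 4096, overlap: int = 256
) -> List[List[int]]:
    """Split one document's token IDs into windows of at most `max_len` tokens.

    Consecutive windows share `overlap` tokens, so each window starts
    `max_len - overlap` tokens after the previous one. Only the last window
    can be shorter than `max_len`.
    """
    if not 0 <= overlap < max_len:
        raise ValueError("overlap must satisfy 0 <= overlap < max_len")
    if len(token_ids) == 0:
        return []
    stride = max_len - overlap
    chunks: List[List[int]] = []
    start = 0
    while True:
        chunks.append(list(token_ids[start:start + max_len]))
        if start + max_len >= len(token_ids):  # this window reached the end of the document
            break
        start += stride
    return chunks


if __name__ == "__main__":
    doc = list(range(9_000))  # stand-in for a tokenized document
    for i, chunk in enumerate(chunk_with_overlap(doc)):
        print(i, "starts at token", chunk[0], "length", len(chunk))
```

With the defaults, a 9,000-token document yields windows of 4,096, 4,096, and 1,320 tokens, each starting 3,840 tokens after the previous one. Because each window is stored as its own `completion` sample, `sample_packing` can fill the short tail window with other chunks instead of padding it.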

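The batch arithmetic in the Training Hyperparameters comments and step 3 of the Benchmark First! checklist can be checked with a few lines of Python. This is a minimal sketch: the helper names are hypothetical, the $1.49 per GPU-hour price comes from the config header, and the measured throughput below is a placeholder. It assumes `tokens_per_sec` is the aggregate rate across all 8 GPUs; if your logs report a per-GPU rate, multiply it by the GPU count before plugging it in.

```python
import math


def tokens_per_step(micro_batch_size: int, grad_accum: int, num_gpus: int, seq_len: int) -> int:
    """Token slots per optimizer step, treating every packed sequence as `seq_len` tokens."""
    return micro_batch_size * grad_accum * num_gpus * seq_len


def estimated_steps(total_tokens: int, step_tokens: int) -> int:
    """Optimizer steps needed to see `total_tokens` once (the last step may be partial)."""
    return math.ceil(total_tokens / step_tokens)


def estimated_cost(total_tokens: int, tokens_per_sec: float, dollars_per_hour: float) -> float:
    """Checklist step 3: total_tokens / tok_per_sec / 3600 * hourly rate."""
    hours = total_tokens / tokens_per_sec / 3600
    return hours * dollars_per_hour


if __name__ == "__main__":
    step_tokens = tokens_per_step(micro_batch_size=1, grad_accum=16, num_gpus=8, seq_len=4096)
    print(f"tokens/step: {step_tokens:,}")
    print(f"steps for 1.013B tokens: {estimated_steps(1_013_000_000, step_tokens):,}")
    hourly_rate = 8 * 1.49              # 8 GPUs at $1.49/GPU-hr, per the config header
    measured_tok_per_sec = 5_000.0      # hypothetical: replace with the number from your 50-step benchmark
    cost = estimated_cost(1_013_000_000, measured_tok_per_sec, hourly_rate)
    print(f"estimated cost: ${cost:,.0f}")
```

With the config's values this gives 524,288 token slots per step and about 1,933 steps for 1.013B tokens, matching the `~524K tokens/step` and `~1,932 steps` comments.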
# workspace/data/axolotl_output/gemma4-31b-cpt

This model is a fine-tuned version of [google/gemma-4-31B](https://huggingface.co/google/gemma-4-31B) on the AiForgeMaster/gemma4-31b-cpt-data dataset.

## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 1
- eval_batch_size: 1
- seed: 42
- distributed_type: multi-GPU
- num_devices: 8
- gradient_accumulation_steps: 16
- total_train_batch_size: 128
- total_eval_batch_size: 8
- optimizer: Use OptimizerNames.ADAMW_BNB with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 16
- training_steps: 1624

### Training results



### Framework versions

- Transformers 5.5.4
- Pytorch 2.10.0+cu128
- Datasets 4.8.4
- Tokenizers 0.22.2
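The training hyperparameters report a cosine schedule with 16 warmup steps over 1,624 training steps at a peak learning rate of 5e-05. The sketch below assumes the common shape implemented by `get_cosine_schedule_with_warmup` in transformers (linear warmup from 0, then a half-cosine decay to 0); Axolotl may wrap or modify the scheduler, so treat it as an approximation. It also assumes the warmup count is the truncated value of `warmup_ratio × total_steps`, which reproduces the reported 16 for `warmup_ratio: 0.01`. Both helper names are hypothetical.

```python
import math


def warmup_steps_from_ratio(total_steps: int, warmup_ratio: float) -> int:
    # Assumption: truncate ratio * steps (1% of 1,624 = 16.24 -> 16 warmup steps).
    return int(total_steps * warmup_ratio)


def cosine_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to `peak_lr`, then half-cosine decay to 0 (valid for 0 <= step <= total_steps)."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    total = 1624
    warmup = warmup_steps_from_ratio(total, 0.01)
    for step in (0, 8, warmup, (warmup + total) // 2, total):
        print(step, f"{cosine_lr(step, 5e-5, warmup, total):.3e}")
```

The printed points trace the schedule's shape: 0 at step 0, half the peak midway through warmup, the full 5e-05 at step 16, about half the peak at step 820 (midway through the decay), and 0 at the final step.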