Partial-BP 1B checkpoints: Which Layers Need Backpropagation?

Companion checkpoints for the preprint by Raen2264 (pseudonym). Paper: Zenodo DOI 10.5281/zenodo.20392068 · Code: github.com/2264K/hybrid-zo-pretrain

Llama-architecture 1B, trained from scratch on FineWeb-Edu (sample-10BT), GPT-2 tokenizer (vocab 50257), 10B tokens, seed 42.

Files (all bf16 model state_dict, no optimizer)

file                            what
frozen_pos6_10b_s42_init.pt     random init (before training), seed 42
frozen_pos6_10b_s42_final.pt    frozen-partial final: BP window = layers [6,12) trained, all other layers frozen at random init
fullbp_10b_s42_final.pt         full backprop final: all layers trained

Model config

LlamaConfig(vocab_size=50257, hidden_size=2048, intermediate_size=5632,
            num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
            max_position_embeddings=512, tie_word_embeddings=False)
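
As a rough sanity check on the "1B" label, the parameter count implied by this config can be worked out by hand. The sketch below assumes the standard Hugging Face Llama layout with no bias terms (the LlamaConfig defaults) and one weight vector per RMSNorm; the function name llama_param_counts is illustrative and not part of the released code.

```python
def llama_param_counts(vocab_size=50257, hidden_size=2048, intermediate_size=5632,
                       num_hidden_layers=24, tie_word_embeddings=False):
    """Count parameters of a bias-free Llama-style model.

    Assumes num_key_value_heads == num_attention_heads (true for this config),
    so the q/k/v/o projections are all hidden_size x hidden_size.
    """
    attn = 4 * hidden_size * hidden_size          # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * hidden_size * intermediate_size     # gate_proj, up_proj, down_proj
    norms = 2 * hidden_size                       # input and post-attention RMSNorm
    per_layer = attn + mlp + norms
    embed = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else embed # untied here, so counted separately
    final_norm = hidden_size
    total = num_hidden_layers * per_layer + embed + lm_head + final_norm
    return {"per_layer": per_layer, "embed": embed, "lm_head": lm_head, "total": total}


counts = llama_param_counts()
window = 6 * counts["per_layer"]   # the BP window, layers [6,12)
print(counts["total"])             # 1439078400
print(window)                      # 308305920
```

Under these assumptions the model has about 1.44B parameters in total (about 1.23B in the 24 transformer blocks, the rest in the two untied vocabulary matrices), and the BP window [6,12) covers about 0.31B of them.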

⚠️ Note on the init checkpoint (important)

The init was not saved during the original run; it is reproduced via torch.manual_seed(42) + LlamaForCausalLM(config).to(bfloat16), and verified byte-exact:

  • In frozen-partial training, the frozen region (layers 0–5 and 12–23, plus embed / final-norm / lm_head) is never updated, so its weights in the final checkpoint are still the original init. We confirmed that all 162 frozen-region tensors of the regenerated seed-42 model match frozen_pos6_10b_s42_final.pt exactly (torch 2.9, bf16).
  • The initialization is fully deterministic given the seed, so the trained-region layers (6–11) of the regenerated seed-42 model are also the exact original init. Hence frozen_pos6_10b_s42_init.pt is the byte-exact pre-training weight set.
  • If you regenerate the init on a different torch version, run the frozen-region equality check above to confirm your RNG matches before relying on layers 6–11.
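
The frozen-region equality check described above could be sketched roughly as follows. This is a minimal illustration, not the authors' verification script: it assumes the checkpoints use the standard Hugging Face key names (model.layers.N. ..., model.embed_tokens.weight, model.norm.weight, lm_head.weight), and the helper names is_frozen_key and mismatched_frozen_tensors are hypothetical. Non-layer tensors are treated as frozen, following the note; set non_layer_frozen=False to compare transformer blocks only.

```python
import re

# Matches decoder-layer keys such as "model.layers.7.self_attn.q_proj.weight".
_LAYER_RE = re.compile(r"^model\.layers\.(\d+)\.")


def is_frozen_key(key, bp_start=6, bp_end=12, non_layer_frozen=True):
    """True if `key` lies outside the BP window [bp_start, bp_end)."""
    m = _LAYER_RE.match(key)
    if m is None:
        # Embeddings, final norm, lm_head: frozen according to the note above.
        return non_layer_frozen
    return not (bp_start <= int(m.group(1)) < bp_end)


def mismatched_frozen_tensors(init_sd, final_sd, **window):
    """Return the frozen-region keys whose tensors differ between two state dicts."""
    import torch  # imported lazily so the key logic works without torch installed

    bad = []
    for key, final_t in final_sd.items():
        if not is_frozen_key(key, **window):
            continue
        init_t = init_sd.get(key)
        if init_t is None or init_t.dtype != final_t.dtype or not torch.equal(init_t, final_t):
            bad.append(key)
    return bad


# Usage (not run here; requires the checkpoint files):
# import torch
# init_sd = torch.load("frozen_pos6_10b_s42_init.pt", map_location="cpu")
# final_sd = torch.load("frozen_pos6_10b_s42_final.pt", map_location="cpu")
# assert mismatched_frozen_tensors(init_sd, final_sd) == []
```

To test a freshly regenerated init instead of the shipped file, pass the state_dict of your own seed-42 model (cast to bf16) as init_sd. An empty result suggests your RNG stream matches the original run.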

Load

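import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Must match the "Model config" section exactly, or load_state_dict fails on shape mismatches.
cfg = LlamaConfig(vocab_size=50257, hidden_size=2048, intermediate_size=5632,
                  num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
                  max_position_embeddings=512, tie_word_embeddings=False)
m = LlamaForCausalLM(cfg)  # created in the default dtype (fp32 unless changed)
# The file is a plain bf16 state_dict; load_state_dict copies it into the model's own dtype.
state_dict = torch.load("frozen_pos6_10b_s42_final.pt", map_location="cpu", weights_only=True)
m.load_state_dict(state_dict)
m.eval()  # disable training-only behaviour before evaluation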

Headline result (10B, FineWeb-Edu val PPL)

frozen-partial 35.2 < full BP 39.0 (lower is better): backpropagating through a well-chosen ~25% of layers (the early-middle window [6,12), i.e. 6 of 24 layers) while freezing the rest at random init beats tuned full backprop at the 10B-token horizon.
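
For readers who want to experiment with a similar setup, one simple way to restrict backpropagation to a layer window is to toggle requires_grad by parameter name. The sketch below is an assumption about how this could be done, not the training code from the repository; apply_bp_window is a hypothetical helper, and it relies on the standard Hugging Face parameter names (model.layers.N. ...).

```python
import re

# Matches decoder-layer parameters such as "model.layers.7.mlp.up_proj.weight".
_LAYER_RE = re.compile(r"^model\.layers\.(\d+)\.")


def apply_bp_window(model, bp_start=6, bp_end=12):
    """Enable gradients only for decoder layers in [bp_start, bp_end).

    Every other parameter (remaining layers, embeddings, final norm, lm_head)
    has requires_grad set to False, so it stays at its current values.
    Returns the names of the parameters left trainable.
    """
    trainable = []
    for name, param in model.named_parameters():
        m = _LAYER_RE.match(name)
        in_window = m is not None and bp_start <= int(m.group(1)) < bp_end
        param.requires_grad = in_window
        if in_window:
            trainable.append(name)
    return trainable
```

After calling apply_bp_window(m) on a freshly initialized model, only the parameters with requires_grad=True would be handed to the optimizer, for example via (p for p in m.parameters() if p.requires_grad).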


Author name Raen2264 is a pseudonym; please retain it. License: Apache-2.0.
