LLaDA Distilled 24L Draft Model (v2)

A 24-layer structurally pruned + distilled version of LLaDA-8B-Instruct (32 layers), designed as a draft model for speculative decoding in masked diffusion language models.

Model Details

|                       | Teacher | Student (this model) |
|-----------------------|---------|----------------------|
| Layers                | 32      | 24                   |
| Attention heads       | 32      | 32                   |
| FFN intermediate dim  | 12,288  | 9,216 (25% pruned)   |
| Hidden dim (d_model)  | 4,096   | 4,096                |
| Vocab size            | 126,464 | 126,464              |
| Total params          | 8.0B    | 5.36B                |
| Parameter reduction   | –       | 33.1%                |
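
The parameter counts above can be sanity-checked with back-of-the-envelope arithmetic. The sketch below is an illustration, not the actual counting script: it ignores layer-norm parameters and assumes an untied embedding/LM head plus a three-projection SwiGLU FFN (the ff_proj/up_proj/ff_out layout referenced later in this card). It reproduces the 8.0B, 5.36B, and 33.1% figures:

```python
# Rough parameter accounting: ignores layer norms, assumes untied
# embedding + LM head and a SwiGLU FFN with three projections.
def transformer_params(n_layers, d_model, ffn_dim, vocab):
    attn = 4 * d_model * d_model   # q, k, v, o projections
    ffn = 3 * d_model * ffn_dim    # ff_proj, up_proj, ff_out
    embed = 2 * vocab * d_model    # input embedding + LM head
    return n_layers * (attn + ffn) + embed

teacher = transformer_params(32, 4096, 12288, 126464)  # ~8.0B
student = transformer_params(24, 4096, 9216, 126464)   # ~5.36B
print(f"reduction: {100 * (1 - student / teacher):.1f}%")  # ~33.1%
```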

How It Was Made

Step 1: Structured Pruning (Data-Driven Layer & Neuron Selection)

We ran 12 ablation experiments on the teacher model to identify redundancies:

  • Layer removal: Layers 12–19 (a contiguous middle block) were identified as the most expendable via leave-one-out argmax agreement analysis. Removing them preserves ~60% argmax agreement with the full model.
  • FFN pruning: Neuron importance was ranked by mean gated activation magnitude (|silu(ff_proj) × up_proj|). The bottom 25% (3,072 of 12,288 neurons) were removed per layer, with different neurons dropped in each layer.
  • Attention heads: All 32 heads kept (d_model=4096 must remain divisible by n_heads).

Layers kept (original indices): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Layers removed: [12, 13, 14, 15, 16, 17, 18, 19]
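
The FFN importance criterion can be sketched as follows. This is a minimal illustration with random activations standing in for real calibration data; the function name and tensor shapes are assumptions, not the actual pruning script:

```python
import torch
import torch.nn.functional as F

# Score each FFN neuron by its mean gated activation magnitude,
# |silu(ff_proj(x)) * up_proj(x)|, averaged over a calibration batch.
def ffn_neuron_importance(ff_act, up_act):
    gated = F.silu(ff_act) * up_act   # [n_tokens, ffn_dim]
    return gated.abs().mean(dim=0)    # one score per neuron

ff_act = torch.randn(512, 12288)  # stand-in for ff_proj outputs
up_act = torch.randn(512, 12288)  # stand-in for up_proj outputs
scores = ffn_neuron_importance(ff_act, up_act)

# Keep the 9,216 highest-scoring neurons (drop the bottom 3,072).
keep_idx = torch.topk(scores, k=9216).indices.sort().values
```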

Step 2: Weight Initialization

Student weights are initialized by directly copying from the teacher:

  • Attention projections (q, k, v, o): copied in full
  • FFN projections (ff_proj, up_proj, ff_out): sliced to keep only the top-9,216 most important neurons
  • Embeddings, LM head, layer norms: copied directly
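
The slicing step can be sketched like this, assuming nn.Linear-style weight layouts (ff_proj/up_proj weights of shape [ffn_dim, d_model], ff_out of shape [d_model, ffn_dim]); the key names follow the card, but the function and the placeholder indices are illustrative:

```python
import torch

# Slice a teacher layer's FFN weights down to the kept neurons.
# Attention projections, embeddings, and norms copy over unchanged.
def init_student_ffn(teacher_sd, keep_idx):
    sd = dict(teacher_sd)
    sd["ff_proj.weight"] = teacher_sd["ff_proj.weight"][keep_idx, :]
    sd["up_proj.weight"] = teacher_sd["up_proj.weight"][keep_idx, :]
    sd["ff_out.weight"] = teacher_sd["ff_out.weight"][:, keep_idx]
    return sd

teacher_sd = {
    "ff_proj.weight": torch.randn(12288, 4096),
    "up_proj.weight": torch.randn(12288, 4096),
    "ff_out.weight": torch.randn(4096, 12288),
}
keep_idx = torch.arange(9216)  # placeholder for the importance-ranked indices
student_sd = init_student_ffn(teacher_sd, keep_idx)
```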

Step 3: Knowledge Distillation

  • Loss: 0.4 × KL(student ∥ teacher, T=1.0) + 0.2 × CE(student, ground_truth) + 0.4 × CE(student, teacher_argmax)
  • The top-1 CE term directly optimizes for speculative decoding acceptance rate
  • Data: C4 (60%) + StarCoderData Python (40%)
  • Masking: Beta(2,2)-distributed mask ratios ∈ [0.15, 0.85]
  • Optimizer: AdamW (lr=5e-5, cosine schedule, 5% warmup)
  • Precision: bf16 mixed
  • Sequence length: 256
  • Checkpoint: step 6500, argmax agreement = 0.8808
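
The combined loss can be sketched as below. This is not the training code: variable names and shapes are illustrative, and the KL term is written in the direction stated above, KL(student ∥ teacher), although many distillation setups use the reverse direction:

```python
import torch
import torch.nn.functional as F

# 0.4 * KL(student || teacher, T) + 0.2 * CE(ground truth)
# + 0.4 * CE(teacher argmax), over the masked positions.
def distill_loss(student_logits, teacher_logits, targets, T=1.0):
    s_logp = F.log_softmax(student_logits / T, dim=-1)
    t_logp = F.log_softmax(teacher_logits / T, dim=-1)
    # F.kl_div(input, target) computes KL(target || input), so passing the
    # teacher as input and the student as target gives KL(student || teacher).
    kl = F.kl_div(t_logp, s_logp, log_target=True, reduction="batchmean") * T * T
    ce_gt = F.cross_entropy(student_logits, targets)
    ce_top1 = F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))
    return 0.4 * kl + 0.2 * ce_gt + 0.4 * ce_top1
```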

Intended Use

This model is a draft model for speculative decoding with LLaDA-8B-Instruct as the target. It is NOT intended for standalone generation.

In speculative decoding:

  1. This draft model proposes token unmaskings cheaply (24 layers, 9,216 FFN dim vs. 32 layers, 12,288 FFN dim)
  2. The full LLaDA-8B teacher verifies proposals in one forward pass
  3. Accepted tokens are kept; rejected ones are resampled from the teacher
  4. Result: same output quality as the teacher, faster wall-clock time
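
The verify-then-accept step can be sketched as below. This is a hypothetical acceptance rule for illustration only; the actual logic lives in the repo's generate_speculative and may differ:

```python
import torch

# Accept a drafted unmasking when the teacher itself assigns the drafted
# token at least `threshold` probability; otherwise fall back to the
# teacher's own argmax prediction for that position.
def verify_draft(teacher_logits, draft_tokens, threshold=0.95):
    probs = teacher_logits.softmax(dim=-1)
    p_draft = probs.gather(-1, draft_tokens.unsqueeze(-1)).squeeze(-1)
    accept = p_draft >= threshold
    final = torch.where(accept, draft_tokens, teacher_logits.argmax(dim=-1))
    return final, accept
```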

How to Use

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load as draft model
draft_model = AutoModel.from_pretrained(
    "jaygala223/llada-distilled-24L-v2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

# Load teacher (target)
target_model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)

# Use in speculative decoding (CtV or VtC)
from generate import generate_speculative

out = generate_speculative(
    target_model, draft_model, input_ids, attention_mask,
    steps=256, gen_length=256, block_length=32,
    verification_frequency=2, threshold=0.95,
)
```

Limitations

  • Not standalone: designed for speculative decoding draft, not direct text generation
  • Same tokenizer as LLaDA-8B-Instruct (128K vocab, mask_id=126336)
  • Distillation is ongoing; later checkpoints may improve