# LLaDA Distilled 24L Draft Model (v2)
A 24-layer structurally pruned + distilled version of LLaDA-8B-Instruct (32 layers), designed as a draft model for speculative decoding in masked diffusion language models.
## Model Details

| | Teacher | Student (this model) |
|---|---|---|
| Layers | 32 | 24 |
| Attention heads | 32 | 32 |
| FFN intermediate dim | 12,288 | 9,216 (25% pruned) |
| Hidden dim (d_model) | 4,096 | 4,096 |
| Vocab size | 126,464 | 126,464 |
| Total params | 8.0B | 5.36B |
| Parameter reduction | – | 33.1% |
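The parameter counts in the table can be roughly reproduced from the architecture dimensions. A back-of-envelope sketch (ignoring layer norms and biases, and assuming untied input embedding and LM head):

```python
def approx_params(layers, d_model=4096, ffn_dim=12288, vocab=126464):
    """Back-of-envelope parameter count for a LLaDA-style transformer:
    4 attention projections + 3 gated-FFN projections per layer,
    plus separate input embedding and LM head (norms/biases ignored)."""
    attn = 4 * d_model * d_model   # q, k, v, o projections
    ffn = 3 * d_model * ffn_dim    # ff_proj, up_proj, ff_out
    emb = 2 * vocab * d_model      # embedding + LM head
    return layers * (attn + ffn) + emb

teacher = approx_params(32, ffn_dim=12288)
student = approx_params(24, ffn_dim=9216)
print(f"{teacher/1e9:.2f}B, {student/1e9:.2f}B, -{1 - student/teacher:.1%}")
# 8.02B, 5.36B, -33.1%
```

The estimate lands within rounding distance of the reported 8.0B / 5.36B / 33.1% figures.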
## How It Was Made

### Step 1: Structured Pruning (Data-Driven Layer & Neuron Selection)
We ran 12 ablation experiments on the teacher model to identify redundancies:
- Layer removal: Layers 12–19 (contiguous middle block) were identified as the most expendable via leave-one-out argmax agreement analysis. Removing them preserves ~60% argmax agreement with the full model.
- FFN pruning: Neuron importance was ranked by mean gated activation magnitude, |silu(ff_proj) × up_proj|. The bottom 25% (3,072 / 12,288 neurons) were removed per layer, with different neurons selected for each layer.
- Attention heads: All 32 heads kept (d_model=4096 requires divisibility by n_heads).

Layers kept (original indices): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
Layers removed: [12, 13, 14, 15, 16, 17, 18, 19]
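A minimal sketch of the neuron-importance ranking described above. Function and variable names here are illustrative, not the actual pruning code; it assumes activations have been collected from the two FFN input projections over a calibration batch:

```python
import torch
import torch.nn.functional as F

def rank_ffn_neurons(ff_proj_acts, up_proj_acts, keep=9216):
    """Rank gated-FFN neurons by mean |silu(ff_proj) * up_proj| over a
    calibration batch and return the indices of the `keep` most
    important neurons, in their original order.

    ff_proj_acts, up_proj_acts: (num_tokens, ffn_dim) pre-activation
    outputs of the two FFN input projections (e.g. via forward hooks).
    """
    gated = F.silu(ff_proj_acts) * up_proj_acts    # (num_tokens, ffn_dim)
    importance = gated.abs().mean(dim=0)           # (ffn_dim,)
    kept = torch.topk(importance, k=keep).indices  # top-`keep` neurons
    return torch.sort(kept).values                 # restore original order

# Toy demo: 12,288-dim FFN, keep 9,216 (bottom 25% pruned)
acts_a = torch.randn(256, 12288)
acts_b = torch.randn(256, 12288)
kept = rank_ffn_neurons(acts_a, acts_b)
print(kept.shape)  # torch.Size([9216])
```

Because the ranking is computed per layer, each layer ends up keeping a different set of 9,216 neurons, as noted above.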
### Step 2: Weight Initialization
Student weights are initialized by directly copying from the teacher:
- Attention projections (q, k, v, o): copied in full
- FFN projections (ff_proj, up_proj, ff_out): sliced to keep only the top-9,216 most important neurons
- Embeddings, LM head, layer norms: copied directly
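The FFN slicing step can be sketched as follows. Weight names and the dict layout are illustrative (the real LLaDA module names may differ), but the row/column conventions match standard gated-FFN layers: the two input projections are sliced along their output rows, the output projection along its input columns:

```python
import torch

def slice_ffn_weights(teacher_layer, kept_idx):
    """Initialize a smaller student FFN from a teacher FFN by keeping
    only the weights for the `kept_idx` neurons.

    teacher_layer: dict of weight tensors
      ff_proj, up_proj: (ffn_dim, d_model) -- produce neuron activations
      ff_out:           (d_model, ffn_dim) -- consume neuron activations
    """
    return {
        "ff_proj": teacher_layer["ff_proj"][kept_idx, :].clone(),  # slice output rows
        "up_proj": teacher_layer["up_proj"][kept_idx, :].clone(),
        "ff_out":  teacher_layer["ff_out"][:, kept_idx].clone(),   # slice input cols
    }

# Toy dims (the real model uses d_model=4096, ffn 12288 -> 9216)
d_model, ffn, keep = 64, 192, 144
teacher = {
    "ff_proj": torch.randn(ffn, d_model),
    "up_proj": torch.randn(ffn, d_model),
    "ff_out":  torch.randn(d_model, ffn),
}
kept_idx = torch.arange(keep)  # in practice: indices from the importance ranking
student = slice_ffn_weights(teacher, kept_idx)
print(student["ff_proj"].shape, student["ff_out"].shape)
```

Attention projections, embeddings, LM head, and layer norms need no slicing since their dimensions are unchanged; they are copied verbatim.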
### Step 3: Knowledge Distillation
- Loss: 0.4 × KL(student ‖ teacher, T=1.0) + 0.2 × CE(student, ground_truth) + 0.4 × CE(student, teacher_argmax)
- The top-1 CE term directly optimizes for speculative decoding acceptance rate
- Data: C4 (60%) + StarCoderData Python (40%)
- Masking: Beta(2,2)-distributed mask ratios ∈ [0.15, 0.85]
- Optimizer: AdamW (lr=5e-5, cosine schedule, 5% warmup)
- Precision: bf16 mixed
- Sequence length: 256
- Checkpoint: step 6500, argmax agreement = 0.8808
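The three-term loss above can be sketched like this. This is an illustrative implementation, not the actual training code; logits are assumed to be gathered at masked positions only, with shape (num_masked, vocab):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, T=1.0):
    """0.4 * KL(student || teacher) + 0.2 * CE(ground truth)
       + 0.4 * CE(teacher argmax), as described above."""
    # KL(student || teacher): F.kl_div(input, target, log_target=True)
    # computes KL(target || input), so the teacher goes in the input slot
    kl = F.kl_div(
        F.log_softmax(teacher_logits / T, dim=-1),
        F.log_softmax(student_logits / T, dim=-1),
        log_target=True, reduction="batchmean",
    ) * (T ** 2)
    # CE against ground-truth tokens
    ce_gt = F.cross_entropy(student_logits, targets)
    # CE against the teacher's argmax: pushes the student to match the
    # teacher's top-1 choice, which drives speculative acceptance
    ce_top1 = F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))
    return 0.4 * kl + 0.2 * ce_gt + 0.4 * ce_top1

# Toy demo with a small vocab
s = torch.randn(32, 1000)
t = torch.randn(32, 1000)
y = torch.randint(0, 1000, (32,))
loss = distill_loss(s, t, y)
```

The 0.4 weight on the teacher-argmax CE term is what makes top-1 agreement (0.8808 at step 6500) the natural checkpoint-selection metric.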
## Intended Use
This model is a draft model for speculative decoding with LLaDA-8B-Instruct as the target. It is NOT intended for standalone generation.
In speculative decoding:
- This draft model proposes token unmaskings cheaply (24 layers, 9216 FFN vs 32 layers, 12288 FFN)
- The full LLaDA-8B teacher verifies proposals in one forward pass
- Accepted tokens are kept; rejected ones are resampled from the teacher
- Result: same output quality as the teacher, faster wall-clock time
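The verify step can be sketched as a per-position accept rule. This is a simplified illustration assuming greedy threshold-based acceptance; the actual `generate_speculative` logic (and its `threshold` semantics) may differ:

```python
import torch

def verify_proposals(draft_tokens, target_logits, threshold=0.95):
    """Accept a draft token when the target model assigns it probability
    >= threshold; otherwise fall back to the target's own prediction
    (greedy here for simplicity).

    draft_tokens:  (n,) token ids proposed by the draft model
    target_logits: (n, vocab) target-model logits at those positions
    """
    probs = torch.softmax(target_logits, dim=-1)
    draft_prob = probs.gather(-1, draft_tokens.unsqueeze(-1)).squeeze(-1)
    accept = draft_prob >= threshold
    fallback = target_logits.argmax(dim=-1)  # target's resampled token
    return torch.where(accept, draft_tokens, fallback), accept

# Toy demo: target is confident in token 3 (pos 0) and token 1 (pos 1)
logits = torch.full((2, 5), -10.0)
logits[0, 3] = 10.0
logits[1, 1] = 10.0
draft = torch.tensor([3, 4])  # pos 0 agrees with target; pos 1 does not
out, accept = verify_proposals(draft, logits)
print(out.tolist(), accept.tolist())  # [3, 1] [True, False]
```

One target forward pass scores every drafted position at once, which is why acceptance rate (hence the top-1 distillation term) governs the speedup.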
## How to Use

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Load as draft model
draft_model = AutoModel.from_pretrained(
    "REPO_ID_HERE",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

# Load teacher (target)
target_model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True)

# Use in speculative decoding (CtV or VtC)
from generate import generate_speculative

out = generate_speculative(
    target_model, draft_model, input_ids, attention_mask,
    steps=256, gen_length=256, block_length=32,
    verification_frequency=2, threshold=0.95,
)
```
## Limitations
- Not standalone: designed for speculative decoding draft, not direct text generation
- Same tokenizer as LLaDA-8B-Instruct (126,464-token vocab, mask_id=126336)
- Distillation is ongoing; later checkpoints may improve