---
license: apache-2.0
language:
- en
tags:
- diffusion
- speculative-decoding
- rectified-flow
- dit
- qwen
- math-reasoning
datasets:
- AI-MO/NuminaMath-CoT
base_model:
- Qwen/Qwen3.5-9B
---
# Continuous Latent Speculative Decoding (CLSD)

- **Architecture:** ~4.0B hybrid causal DiT (Rectified Flow) + 9B frozen verifier
- **Target:** SOTA mathematical reasoning via continuous latent speculative decoding
- **Key innovation:** First hybrid DeltaNet/Attention causal diffusion transformer
## Thesis
Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) — a strided 12-layer slice of Qwen3.5-9B — operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.
Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps causal masking -- DeltaNet linear recurrence is strictly causal by design and cannot be flipped to bidirectional.
The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass. The DiT is aligned via Cross-Entropy backpropagation through the frozen verifier.
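The 2-step drafting loop can be sketched as a plain Euler integration of the learned velocity field. This is a minimal illustration, not the repository's implementation; the `dit(x, t, c)` interface returning a velocity is an assumption.

```python
import torch

@torch.no_grad()
def draft_branches(dit, c, n_branches=32, seq_len=128, dim=4096, n_steps=2):
    """Sketch: draft candidate embedding branches by integrating the learned
    velocity field v_theta(x_t, t, C) with a few Euler steps.
    Assumes `dit(x, t, c)` returns the predicted velocity (illustrative API)."""
    x = torch.randn(n_branches, seq_len, dim)   # x_0 ~ N(0, I), one per branch
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = torch.full((n_branches,), k * dt)   # current time for every branch
        v = dit(x, t, c)                        # predicted straight-line velocity
        x = x + dt * v                          # Euler step toward the data end x_1
    return x                                    # [n_branches, seq_len, dim] embeddings
```

All 32 branches ride through the DiT as a single batch, which is what makes the draft step effectively constant-time in the number of candidates.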
**Why causal diffusion works:** The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. Token 1 does not need to see token 128 -- C already carries the full prompt context. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors what the frozen verifier expects.
## Architecture

### Models
| Role | Model | Params | Dim | Layers | Attn Heads | KV Heads |
|---|---|---|---|---|---|---|
| Generator (DiT) | Qwen3.5-9B -> strided 12-layer slice | ~4.0B | 4096 | 12 | 16 | 4 |
| Verifier (frozen) | Qwen3.5-9B (text tower) | 9B | 4096 | 32 | 16 | 4 |
### The Strided Graft
```
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D,  D,  A,  D,  D,  D,  D,  D,  A]
DiT indices:   [0, 1, 2, 3,  4,  5,  6,  7,  8,  9, 10, 11]
```

`D` = DeltaNet (`linear_attention`), `A` = `full_attention`. Result: 9 DeltaNet + 3 full_attention layers.
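The layer types above follow directly from the repeating [3xDeltaNet, 1xAttention] pattern, which places full attention at every fourth layer (0-indexed: 3, 7, 11, ...). A small sketch to derive them:

```python
# Sketch: derive the type of each grafted layer from Qwen3.5-9B's repeating
# [3 x DeltaNet, 1 x Attention] pattern.
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]

def layer_type(i: int) -> str:
    # Every fourth layer (indices 3, 7, 11, ...) is full attention
    return "A" if i % 4 == 3 else "D"

types = [layer_type(i) for i in STRIDE_INDICES]
print("".join(types))                       # -> DADDDADDDDDA
print(types.count("D"), types.count("A"))   # -> 9 3
```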
### Modifications to Grafted Layers
- **Strip the LM head** -- the DiT outputs continuous embeddings, not logits
- **Keep causal masking** -- preserves 100% of pre-trained weight integrity
- **Inject adaLN-Zero modulators** -- one per block, `nn.Linear(4096, 24576)`
- **Zero-initialize** -- at step 0 the network acts as the identity
- **Timestep conditioning** -- sinusoidal embedding + conditioning vector C
- **Learned local positional embedding** -- `nn.Parameter(torch.zeros(1, 128, 4096))`
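The modulator shape is explained by adaLN-Zero's six outputs: shift, scale, and gate for the attention branch, and the same trio for the MLP branch, so 6 x 4096 = 24576. A minimal sketch, assuming the standard DiT-style adaLN-Zero layout (the class name and exact factorization are illustrative):

```python
import torch
import torch.nn as nn

class AdaLNZeroModulator(nn.Module):
    """Sketch of one per-block adaLN-Zero modulator: nn.Linear(dim, 6*dim)
    maps the conditioning signal (timestep embedding + C) to shift/scale/gate
    for the attention branch and shift/scale/gate for the MLP branch.
    Zero-initialization makes the gated residual branches vanish, so the
    whole block is the identity at step 0."""
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)   # 4096 -> 24576 in the real model
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, c: torch.Tensor):
        # c: [B, dim]; returns six [B, dim] modulation tensors
        return self.proj(c).chunk(6, dim=-1)
```

Because the modulation is a function of C alone and is broadcast across all 128 positions, every token receives the full prompt context even under a causal attention mask.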
## Training Pipeline

### Pre-Flight: Embedding Extraction
Target embeddings pre-computed from AI-MO/NuminaMath-CoT (mathematical chain-of-thought reasoning):
1. Tokenize reasoning paths with the Qwen tokenizer
2. Look up embeddings in Qwen3.5-9B's frozen embedding matrix E (248320 x 4096)
3. Chunk into fixed 128-token windows
4. Save as [64, 128, 4096] safetensors shards
Result: 2,294 shard files of up to 64 chunks each = 146,790 total chunks (~144 GB)
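The lookup-and-chunk steps can be sketched with toy sizes and a stand-in embedding matrix; in the real pipeline `E` is Qwen3.5-9B's frozen embedding table (248320 x 4096), the token ids come from the Qwen tokenizer over NuminaMath-CoT, and shards are written as safetensors files rather than kept in memory.

```python
import torch

# Toy sizes for illustration; real values: VOCAB=248320, DIM=4096
VOCAB, DIM, WIN, SHARD = 1000, 64, 128, 64

E = torch.randn(VOCAB, DIM)                           # stand-in frozen embedding matrix
token_ids = torch.randint(0, VOCAB, (5 * WIN + 17,))  # one tokenized reasoning path

# 1) Look up embeddings, 2) chunk into fixed 128-token windows (drop the remainder)
embs = E[token_ids]                                   # [T, DIM]
n_win = embs.shape[0] // WIN
chunks = embs[: n_win * WIN].reshape(n_win, WIN, DIM)

# 3) Group chunks into [<=64, 128, DIM] shards (real run: one safetensors file each)
shards = [chunks[i : i + SHARD] for i in range(0, n_win, SHARD)]
print(chunks.shape)   # torch.Size([5, 128, 64]) here; [*, 128, 4096] in the real run
```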
### Stage A: Rectified Flow (Velocity Regression)
Teach the DiT the straight-line velocity field from noise to embeddings using Rectified Flow:
```
x_t  = (1 - t) * x_0 + t * x_1,    t in [0, 1]
L_RF = || v_theta(x_t, t, C) - (x_1 - x_0) ||^2
```
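One training step of this objective is a few lines: sample a noise endpoint, interpolate linearly, and regress the constant straight-line velocity. A minimal sketch, where the `dit(x_t, t, c)` interface is an assumption:

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(dit, x1, c):
    """Sketch of the Stage A objective: regress the straight-line velocity
    (x1 - x0) at a random time t. x1: [B, T, D] target embeddings;
    `dit(x_t, t, c)` is assumed to return v_theta (illustrative API)."""
    x0 = torch.randn_like(x1)            # noise endpoint x_0 ~ N(0, I)
    t = torch.rand(x1.shape[0], 1, 1)    # t ~ U[0, 1], broadcast over [T, D]
    x_t = (1 - t) * x0 + t * x1          # linear interpolant between noise and data
    v_target = x1 - x0                   # velocity is constant along a straight line
    v_pred = dit(x_t, t.squeeze(), c)
    return F.mse_loss(v_pred, v_target)
```

Because the target trajectory is a straight line, the same model can be sampled with 1-2 Euler steps without any distillation stage.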
| Property | DDPM + LCM (old) | Rectified Flow (this work) |
|---|---|---|
| Training objective | Noise prediction | Velocity prediction (v) |
| Trajectory shape | Curved (needs 1000 steps) | Straight line |
| Distillation required? | Yes | No |
| Native inference steps | 2 (after distillation) | 1-2 Euler steps natively |
This release: Stage A trained on 1x NVIDIA B200 for 50,000 steps:
| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, warmup 100 steps, cosine decay) |
| Batch size | 32 |
| Steps | 50,000 |
| Wall-clock | 154.8 minutes |
| Final MSE loss | ~0.013 (converged by step 5K) |
| Checkpoints included | 5K, 10K, 20K, 30K, 40K, final |
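The optimizer row above (AdamW, 100-step warmup, cosine decay) can be reproduced with a standard `LambdaLR` multiplier. This is a sketch of the schedule shape only; whether the cosine decays all the way to zero is an assumption.

```python
import math
import torch

TOTAL, WARMUP = 50_000, 100
params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for the DiT parameters
opt = torch.optim.AdamW(params, lr=1e-4)

def lr_lambda(step: int) -> float:
    # Linear warmup for the first 100 steps, then cosine decay (assumed to 0)
    if step < WARMUP:
        return step / WARMUP
    progress = (step - WARMUP) / (TOTAL - WARMUP)
    return 0.5 * (1 + math.cos(math.pi * progress))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```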
### Stage C: CE Alignment (Next)
Shift the DiT from outputs that look like embeddings to outputs that make the 9B verifier produce correct tokens:
```
z ~ N(0, I) -> DiT(z, C) -> [2 Euler steps] -> X (128 x 4096)
                         -> Qwen_frozen(X, past_kv) -> logits (128 x 248320)

L_total = alpha * CE(logits, targets) + beta * MSE(X, E(targets))
```
- alpha = 1.0 (CE drives alignment)
- beta = 0.1 -> 0 over training (MSE regularizer anneals)
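The combined objective is straightforward once the frozen verifier has produced logits for the drafted embeddings. A minimal sketch, assuming `X`, `logits`, and `targets` are already computed (the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def stage_c_loss(X, logits, targets, E, alpha=1.0, beta=0.1):
    """Sketch of the Stage C objective.
    X: [B, T, D] DiT output embeddings; logits: [B, T, V] from the frozen
    verifier; targets: [B, T] token ids; E: [V, D] frozen embedding matrix.
    alpha weights the CE alignment term; beta (annealed toward 0 during
    training) weights the embedding-space MSE regularizer."""
    ce = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
    mse = F.mse_loss(X, E[targets])   # pull X toward the target token embeddings
    return alpha * ce + beta * mse
```

Gradients flow from the CE term back through the frozen verifier into the DiT; only the DiT's parameters are updated.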
## Live Inference (Target)
1. User submits a reasoning prompt
2. The 9B verifier runs a forward pass -> extracts C (4096-d) + KV cache
3. The DiT samples 32 noise vectors and generates 32 candidate 128-token branches in 2 Euler steps
4. The 9B verifier evaluates all 32 branches in one batched forward pass
5. **Causal Guillotine:** scan the top-1 draft left to right and truncate at the first position where the log-prob drops below a threshold
6. Qwen samples the correct token, a new C is generated, and the loop repeats

**Target latency:** <500 ms per 128-token block
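The guillotine step reduces to a single scan over the verifier's per-token log-probs. A minimal sketch (function name and return convention are illustrative):

```python
import torch

def causal_guillotine(log_probs: torch.Tensor, threshold: float) -> int:
    """Sketch of the acceptance rule: scan the top-1 draft left to right and
    return the number of accepted tokens -- everything strictly before the
    first position whose verifier log-prob falls below the threshold."""
    below = log_probs < threshold        # [T] bool mask over drafted positions
    if not below.any():
        return log_probs.shape[0]        # whole block accepted
    return int(below.float().argmax())   # index of the first failing position

lp = torch.tensor([-0.1, -0.3, -0.2, -5.0, -0.1])
print(causal_guillotine(lp, threshold=-2.0))   # -> 3 (tokens 0-2 accepted)
```

Because all positions after the first failure depend (causally) on a token the verifier rejected, everything to the right of the cut must be discarded, not just the failing token.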
## Repository Contents
```
embeddings/                      # Pre-computed NuminaMath-CoT embeddings (146K chunks)
    batch_0000.safetensors       # Each: [64, 128, 4096]
    ...
checkpoints/
    dit_stage_a_step_5000.pt
    dit_stage_a_step_10000.pt
    dit_stage_a_step_20000.pt
    dit_stage_a_step_30000.pt
    dit_stage_a_step_40000.pt
    dit_stage_a_final.pt         # 50K steps, converged
```
### Loading a Checkpoint
```python
import torch
from transformers import AutoModelForCausalLM

from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES

# Build the DiT architecture from the frozen 9B
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Load trained weights
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)
```
## Key Architectural Decisions
- **Shared 4096-d space:** Generator and verifier operate natively in the same embedding geometry. No projection layers, no information bottlenecks.
- **Strided layer slice:** The DiT inherits geometric knowledge from early, middle, and late layers of the 9B.
- **Rectified Flow over DDPM:** Linear trajectories -> no distillation stage -> native 2-step generation.
- **Instruct/Instruct architecture:** Drafter and verifier are sliced from the same model. Zero distributional gap at initialization.
- **Monte Carlo parallel search:** 32 branches x 128 tokens = 4,096 candidate tokens per inference step.
## Citation
```bibtex
@misc{clsd2026,
  title={Continuous Latent Speculative Decoding: A Hybrid Causal DiT for Parallel Reasoning},
  year={2026},
  url={https://huggingface.co/datasysdev/clsd}
}
```
## License
Apache 2.0