---
license: apache-2.0
language:
  - en
tags:
  - diffusion
  - speculative-decoding
  - rectified-flow
  - dit
  - qwen
  - math-reasoning
datasets:
  - AI-MO/NuminaMath-CoT
base_model:
  - Qwen/Qwen3.5-9B
---

# Continuous Latent Speculative Decoding (CLSD)

- **Architecture:** ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B frozen verifier
- **Target:** SOTA mathematical reasoning via continuous latent speculative decoding
- **Key innovation:** first hybrid DeltaNet/Attention causal diffusion transformer


## Thesis

Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) — a strided 12-layer slice of Qwen3.5-9B — operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.

Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps causal masking -- DeltaNet linear recurrence is strictly causal by design and cannot be flipped to bidirectional.

The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass. The DiT is aligned via Cross-Entropy backpropagation through the frozen verifier.

Why causal diffusion works: the conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. Token 1 does not need to see token 128 -- C already carries the full prompt context. The causal constraint forces the DiT to learn autoregressive-like internal logic, mirroring the frozen verifier's expectations.


## Architecture

### Models

| Role | Model | Params | Dim | Layers | Attn Heads | KV Heads |
|---|---|---|---|---|---|---|
| Generator (DiT) | Qwen3.5-9B -> strided 12-layer slice | ~4.0B | 4096 | 12 | 16 | 4 |
| Verifier (frozen) | Qwen3.5-9B (text tower) | 9B | 4096 | 32 | 16 | 4 |

### The Strided Graft

```text
Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
Layer types:   [D, A, D, D, D,  A,  D,  D,  D,  D,  D,  A ]
DiT indices:   [0, 1, 2, 3, 4,  5,  6,  7,  8,  9,  10, 11]
```

D = DeltaNet (linear_attention), A = full_attention. Result: 9 DeltaNet + 3 full_attention layers.
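The layer-type row follows directly from the backbone's repeating [3xDeltaNet, 1xAttention] pattern: every fourth source layer (index = 3 mod 4) is full attention. A quick sketch to verify the counts:

```python
# Strided slice of the 32-layer backbone (indices from the table above)
STRIDE_INDICES = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]

# In a repeating [3x DeltaNet, 1x Attention] stack, every source layer
# with index = 3 (mod 4) is full attention; the rest are DeltaNet.
types = ["A" if i % 4 == 3 else "D" for i in STRIDE_INDICES]

print("".join(types))                                   # DADDDADDDDDA
print(types.count("D"), "DeltaNet +", types.count("A"), "full_attention")
```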

### Modifications to Grafted Layers

  1. Strip the LM head -- the DiT outputs continuous embeddings, not logits
  2. Keep causal masking -- preserves 100% of pre-trained weight integrity
  3. Inject adaLN-Zero modulators -- one per block, nn.Linear(4096, 24576)
  4. Zero-initialize -- at step 0 the network acts as identity
  5. Timestep conditioning -- sinusoidal embedding + conditioning vector C
  6. Learned local positional embedding -- nn.Parameter(zeros(1, 128, 4096))
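Items 3 and 4 can be sketched as a minimal adaLN-Zero modulator (a hypothetical `AdaLNZero` module; the 24576 output dimension is 6 x 4096, i.e. shift/scale/gate for both the attention and MLP sub-blocks, zero-initialized so every block starts as the identity):

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Per-block adaLN-Zero modulator: conditioning vector C -> six
    modulation tensors (shift/scale/gate for attention, then for MLP)."""
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.mod = nn.Linear(dim, 6 * dim)  # 4096 -> 24576 in the real model
        nn.init.zeros_(self.mod.weight)     # zero-init: every block acts as
        nn.init.zeros_(self.mod.bias)       # the identity at step 0

    def forward(self, c: torch.Tensor):
        return self.mod(c).chunk(6, dim=-1)

# Toy width for illustration; the DiT uses dim=4096
mod = AdaLNZero(dim=8)
outs = mod(torch.randn(2, 8))
assert len(outs) == 6 and all(o.abs().max() == 0 for o in outs)  # gates start closed
```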

## Training Pipeline

### Pre-Flight: Embedding Extraction

Target embeddings pre-computed from AI-MO/NuminaMath-CoT (mathematical chain-of-thought reasoning):

  • Tokenize reasoning paths with Qwen tokenizer
  • Lookup embeddings via Qwen3.5-9B frozen embedding matrix E (248320 x 4096)
  • Chunk into fixed 128-token windows
  • Save as [64, 128, 4096] safetensors shards

Result: 2,294 shard files of up to 64 chunks each (the final shard is partial) = 146,790 total chunks (~144 GB)
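The extraction steps above reduce to an embedding lookup plus reshaping; a minimal sketch (hypothetical `chunk_embeddings` helper, with `E` standing in for the frozen embedding matrix; toy dimensions for illustration):

```python
import torch

def chunk_embeddings(token_ids, E, chunk_len=128, shard_size=64):
    """Look up frozen embeddings for a token stream, slice into fixed
    chunk_len windows, and group into shards of [shard_size, chunk_len, dim]."""
    embs = E[token_ids]                          # frozen lookup, no grad needed
    n = embs.size(0) // chunk_len                # drop the trailing remainder
    chunks = embs[: n * chunk_len].view(n, chunk_len, -1)
    return [chunks[s : s + shard_size] for s in range(0, n, shard_size)]

# Toy example: the real E is 248320 x 4096; here 1000 x 8
E = torch.randn(1000, 8)
ids = torch.randint(0, 1000, (130 * 128,))
shards = chunk_embeddings(ids, E)
print(len(shards), shards[0].shape)  # 3 shards; the first two are [64, 128, 8]
```

In the real pipeline each shard would then be written out as one `[64, 128, 4096]` safetensors file.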

### Stage A: Rectified Flow (Velocity Regression)

Teach the DiT the straight-line velocity field from noise to embeddings using Rectified Flow:

```text
x_t = (1 - t) * x_0 + t * x_1,    t in [0, 1]

L_RF = || v_theta(x_t, t, C) - (x_1 - x_0) ||^2
```
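The two equations translate directly into a loss sketch (hypothetical names; the actual training loop is not part of this card):

```python
import torch

def rectified_flow_loss(v_theta, x1, c):
    """x1: target embedding chunk [B, T, D]; x0: Gaussian noise.
    Interpolate x_t = (1 - t) x0 + t x1 and regress the constant
    straight-line velocity x1 - x0."""
    x0 = torch.randn_like(x1)
    t = torch.rand(x1.size(0), 1, 1, device=x1.device)  # per-sample timestep
    x_t = (1 - t) * x0 + t * x1
    v_target = x1 - x0
    return ((v_theta(x_t, t, c) - v_target) ** 2).mean()

# Smoke test with a dummy velocity network (toy dim)
loss = rectified_flow_loss(lambda x_t, t, c: torch.zeros_like(x_t),
                           torch.randn(4, 128, 16), None)
print(loss.item())
```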

| Property | DDPM + LCM (old) | Rectified Flow (this work) |
|---|---|---|
| Training objective | Noise prediction | Velocity prediction (v) |
| Trajectory shape | Curved (needs 1000 steps) | Straight line |
| Distillation required? | Yes | No |
| Native inference steps | 2 (after distillation) | 1-2 Euler steps natively |
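Native few-step sampling falls out of the straight trajectory; a 2-step Euler integrator is all the sampler needs (sketch with a hypothetical `v_theta` signature and toy dimensions):

```python
import torch

@torch.no_grad()
def euler_sample(v_theta, c, shape, steps=2):
    """Integrate dx/dt = v_theta(x, t, C) from t=0 (noise) to t=1 (embeddings)."""
    x = torch.randn(shape)
    dt = 1.0 / steps
    for k in range(steps):
        t = torch.full((shape[0], 1, 1), k * dt)
        x = x + dt * v_theta(x, t, c)
    return x

# 32 candidate branches of 128 tokens each (real dim is 4096; 16 here)
drafts = euler_sample(lambda x, t, c: torch.zeros_like(x), None, (32, 128, 16))
print(drafts.shape)  # torch.Size([32, 128, 16])
```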

This release: Stage A trained on 1x NVIDIA B200 for 50,000 steps:

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, warmup 100 steps, cosine decay) |
| Batch size | 32 |
| Steps | 50,000 |
| Wall-clock | 154.8 minutes |
| Final MSE loss | ~0.013 (converged by step 5K) |
| Checkpoints included | 5K, 10K, 20K, 30K, 40K, final |

### Stage C: CE Alignment (Next)

Shift the DiT from outputs that look like embeddings to outputs that make the 9B verifier produce correct tokens:

```text
z ~ N(0,I) -> DiT(z, C) -> [2 Euler steps] -> X (128 x 4096)
  -> Qwen_frozen(X, past_kv) -> logits (128 x 248320)

L_total = alpha * CE(logits, targets) + beta * MSE(X, E(targets))
```

  • alpha = 1.0 (CE drives alignment)
  • beta = 0.1 -> 0 over training (MSE regularizer anneals)
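Under those definitions the combined objective is a few lines (sketch; `E` stands in for the frozen embedding matrix, and annealing `beta` toward 0 happens outside this function):

```python
import torch
import torch.nn.functional as F

def stage_c_loss(logits, X, targets, E, alpha=1.0, beta=0.1):
    """Cross-entropy through the frozen verifier, plus an annealed MSE
    anchor pulling X toward the target token embeddings E[targets]."""
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                         targets.reshape(-1))
    mse = F.mse_loss(X, E[targets])
    return alpha * ce + beta * mse

# Toy shapes: vocab 100, dim 16, one 8-token block
logits = torch.randn(1, 8, 100)
targets = torch.randint(0, 100, (1, 8))
E = torch.randn(100, 16)
X = torch.randn(1, 8, 16)
loss = stage_c_loss(logits, X, targets, E)
```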

## Live Inference (Target)

  1. User submits a reasoning prompt
  2. 9B Verifier runs forward pass -> extracts C (4096-d) + KV cache
  3. DiT samples 32 noise vectors, generates 32 candidate 128-token branches in 2 Euler steps
  4. 9B Verifier evaluates all 32 branches in one batched forward pass
  5. Causal Guillotine: Scan Top-1 draft left-to-right, truncate at first position where log-prob drops below threshold
  6. Qwen samples the correct token, new C generated, loop repeats

Target latency: <500ms per 128-token block
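Step 5, the causal guillotine, is simple to state in code (sketch; the threshold value is a hypothetical placeholder):

```python
import torch

def causal_guillotine(logprobs: torch.Tensor, threshold: float = -4.0) -> int:
    """logprobs[i] = verifier log-prob of the drafted token at position i.
    Return the accept length: the tokens before the first position whose
    log-prob drops below threshold."""
    below = logprobs < threshold
    if not below.any():
        return logprobs.numel()        # accept the whole 128-token block
    return int(below.nonzero()[0])     # truncate at the first failure

lp = torch.tensor([-0.1, -0.3, -0.2, -6.0, -0.1])
print(causal_guillotine(lp))  # 3: positions 0-2 accepted, position 3 fails
```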


## Repository Contents

```text
embeddings/                    # Pre-computed NuminaMath-CoT embeddings (146K chunks)
  batch_0000.safetensors       # Each: [64, 128, 4096]
  ...
checkpoints/
  dit_stage_a_step_5000.pt
  dit_stage_a_step_10000.pt
  dit_stage_a_step_20000.pt
  dit_stage_a_step_30000.pt
  dit_stage_a_step_40000.pt
  dit_stage_a_final.pt         # 50K steps, converged
```

### Loading a Checkpoint

```python
from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
from transformers import AutoModelForCausalLM
import torch

# Build the DiT architecture from the frozen Qwen backbone
qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)

# Load trained Stage A weights
state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
dit.load_state_dict(state_dict)
```

## Key Architectural Decisions

  1. Shared 4096-d space: Generator and verifier operate in the same embedding geometry natively. No projection layers, no information bottlenecks.
  2. Strided layer slice: DiT inherits geometric knowledge from early, middle, and late layers of the 9B.
  3. Rectified Flow over DDPM: Linear trajectories -> no distillation stage -> native 2-step generation.
  4. Instruct/Instruct architecture: Both drafter and verifier sliced from the same model. Zero distributional gap at initialization.
  5. Monte Carlo parallel search: 32 branches x 128 tokens = 4,096 candidate tokens per inference step.

## Citation

```bibtex
@misc{clsd2026,
  title={Continuous Latent Speculative Decoding: A Hybrid Causal DiT for Parallel Reasoning},
  year={2026},
  url={https://huggingface.co/datasysdev/clsd}
}
```

## License

Apache 2.0