datasysdev committed on
Commit 5b14326 · verified · 1 Parent(s): 40551f7

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +99 -103
README.md CHANGED
@@ -9,6 +9,7 @@ tags:
9
  - dit
10
  - qwen
11
  - math-reasoning
 
12
  datasets:
13
  - AI-MO/NuminaMath-CoT
14
  base_model:
@@ -18,64 +19,44 @@ base_model:
18
  # Continuous Latent Speculative Decoding (CLSD)
19
 
20
  **Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
21
- **Target**: SOTA mathematical reasoning via continuous latent speculative decoding
22
- **Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer
23
 
24
  ---
25
 
26
  ## Thesis
27
 
28
- Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a
29
- hybrid causal Diffusion Transformer (DiT) — a strided 12-layer slice of Qwen3.5-9B —
30
- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier.
31
- Both models share the exact same 4096-dimensional manifold, the same tokenizer,
32
- and the same attention geometry. No projection bridges, no dimensional translation loss.
33
 
34
- Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8
35
- standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern.
36
- The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear
37
- recurrence is strictly causal by design and cannot be flipped to bidirectional.
38
 
39
- The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps.
40
- The verifier evaluates them in a single batched forward pass. The DiT is aligned via
41
- Cross-Entropy backpropagation through the frozen verifier.
42
 
43
- > **Why causal diffusion works**: The conditioning vector C is injected via adaLN into
44
- > every position simultaneously, providing global context regardless of attention mask.
45
- > Token 1 does not need to see token 128 -- C already carries the full prompt context.
46
- > The causal constraint actually forces the DiT to learn autoregressive-like internal
47
- > logic, which mirrors the frozen verifier expectations.
48
 
49
  ---
50
 
51
  ## Architecture
52
 
53
- ### Models
54
-
55
- | Role | Model | Params | Dim | Layers | Attn Heads | KV Heads |
56
- |------|-------|--------|-----|--------|-----------|----------|
57
- | **Generator (DiT)** | Qwen3.5-9B -> strided 12-layer slice | ~4.0B | 4096 | 12 | 16 | 4 |
58
- | **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 | 16 | 4 |
59
 
60
  ### The Strided Graft
61
 
62
  ```
63
  Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
64
  Layer types: [D, A, D, D, D, A, D, D, D, D, D, A ]
65
- DiT indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
66
 
67
  D = DeltaNet (linear_attention), A = full_attention
68
- Result: 9 DeltaNet + 3 full_attention layers
69
  ```
70
 
71
- ### Modifications to Grafted Layers
72
-
73
- 1. **Strip the LM head** -- the DiT outputs continuous embeddings, not logits
74
- 2. **Keep causal masking** -- preserves 100% of pre-trained weight integrity
75
- 3. **Inject adaLN-Zero modulators** -- one per block, nn.Linear(4096, 24576)
76
- 4. **Zero-initialize** -- at step 0 the network acts as identity
77
- 5. **Timestep conditioning** -- sinusoidal embedding + conditioning vector C
78
- 6. **Learned local positional embedding** -- nn.Parameter(zeros(1, 128, 4096))
79
 
80
  ---
81
 
@@ -83,83 +64,105 @@ Result: 9 DeltaNet + 3 full_attention layers
83
 
84
  ### Pre-Flight: Embedding Extraction
85
 
86
- Target embeddings pre-computed from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought reasoning):
87
- - Tokenize reasoning paths with Qwen tokenizer
88
- - Lookup embeddings via Qwen3.5-9B frozen embedding matrix E (248320 x 4096)
89
- - Chunk into fixed 128-token windows
90
- - Save as [64, 128, 4096] safetensors shards
91
-
92
- **Result**: 2,294 shard files of up to 64 chunks each = **146,790 total chunks** (~144 GB)
93
-
94
- ### Stage A: Rectified Flow (Velocity Regression)
95
 
96
- Teach the DiT the straight-line velocity field from noise to embeddings using Rectified Flow:
97
 
98
- x_t = (1 - t) * x_0 + t * x_1, t in [0, 1]
99
 
100
- L_RF = ||v_theta(x_t, t, C) - (x_1 - x_0)||^2
101
-
102
- | Property | DDPM + LCM (old) | Rectified Flow (this work) |
103
- |----------|-------------------|---------------------------|
104
- | Training objective | Noise prediction | Velocity prediction (v) |
105
- | Trajectory shape | Curved (needs 1000 steps) | **Straight line** |
106
- | Distillation required? | Yes | **No** |
107
- | Native inference steps | 2 (after distillation) | **1-2 Euler steps natively** |
108
-
109
- **This release**: Stage A trained on 1x NVIDIA B200 for 50,000 steps:
110
 
111
  | Parameter | Value |
112
  |-----------|-------|
113
- | Optimizer | AdamW (lr=1e-4, warmup 100 steps, cosine decay) |
114
- | Batch size | 32 |
115
  | Steps | 50,000 |
 
 
116
  | Wall-clock | 154.8 minutes |
117
- | Final MSE loss | ~0.013 (converged by step 5K) |
118
- | Checkpoints included | 5K, 10K, 20K, 30K, 40K, final |
119
 
120
- ### Stage C: CE Alignment (Next)
121
 
122
- Shift the DiT from outputs that look like embeddings to outputs that make
123
- the 9B verifier produce correct tokens:
124
 
125
  ```
126
- z ~ N(0,I) -> DiT(z, C) -> [2 Euler steps] -> X (128x4096)
127
- -> Qwen_frozen(X, past_kv) -> logits (128x248320)
128
  ```
129
 
130
- L_total = alpha * CE(logits, targets) + beta * MSE(X, E(targets))
 
 
 
 
 
 
131
 
132
- - alpha = 1.0 (CE drives alignment)
133
- - beta = 0.1 -> 0 over training (MSE regularizer anneals)
134
 
135
  ---
136
 
137
- ## Live Inference (Target)
138
 
139
- 1. User submits a reasoning prompt
140
- 2. 9B Verifier runs forward pass -> extracts C (4096-d) + KV cache
141
- 3. DiT samples 32 noise vectors, generates 32 candidate 128-token branches in **2 Euler steps**
142
- 4. 9B Verifier evaluates all 32 branches in one batched forward pass
143
- 5. **Causal Guillotine**: Scan Top-1 draft left-to-right, truncate at first position where log-prob drops below threshold
144
- 6. Qwen samples the correct token, new C generated, loop repeats
 
145
 
146
  **Target latency**: <500ms per 128-token block
147
 
148
  ---
149
 
150
  ## Repository Contents
151
 
152
  ```
153
- embeddings/ # Pre-computed NuminaMath-CoT embeddings (146K chunks)
154
- batch_0000.safetensors # Each: [64, 128, 4096]
155
- ...
156
  checkpoints/
157
- dit_stage_a_step_5000.pt
158
- dit_stage_a_step_10000.pt
159
- dit_stage_a_step_20000.pt
160
- dit_stage_a_step_30000.pt
161
- dit_stage_a_step_40000.pt
162
- dit_stage_a_final.pt # 50K steps, converged
 
163
  ```
164
 
165
  ### Loading a Checkpoint
@@ -169,36 +172,29 @@ from clsd.grafted_dit import graft_dit_from_qwen, STRIDE_INDICES
169
  from transformers import AutoModelForCausalLM
170
  import torch
171
 
172
- # Build the DiT architecture
173
  qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
174
  dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)
175
-
176
- # Load trained weights
177
  state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
178
  dit.load_state_dict(state_dict)
179
  ```
180
 
181
  ---
182
 
183
- ## Key Architectural Decisions
184
 
185
- 1. **Shared 4096-d space**: Generator and verifier operate in the same embedding geometry natively. No projection layers, no information bottlenecks.
186
- 2. **Strided layer slice**: DiT inherits geometric knowledge from early, middle, and late layers of the 9B.
187
- 3. **Rectified Flow over DDPM**: Linear trajectories -> no distillation stage -> native 2-step generation.
188
- 4. **Instruct/Instruct architecture**: Both drafter and verifier sliced from the same model. Zero distributional gap at initialization.
189
- 5. **Monte Carlo parallel search**: 32 branches x 128 tokens = 4,096 candidate tokens per inference step.
 
 
 
190
 
191
- ---
192
-
193
- ## Citation
194
 
195
- ```bibtex
196
- @misc{clsd2026,
197
- title={Continuous Latent Speculative Decoding: A Hybrid Causal DiT for Parallel Reasoning},
198
- year={2026},
199
- url={https://huggingface.co/datasysdev/clsd}
200
- }
201
- ```
202
 
203
  ## License
204
 
 
9
  - dit
10
  - qwen
11
  - math-reasoning
12
+ - deltanet
13
  datasets:
14
  - AI-MO/NuminaMath-CoT
15
  base_model:
 
19
  # Continuous Latent Speculative Decoding (CLSD)
20
 
21
  **Architecture**: ~4.0B Hybrid Causal DiT (Rectified Flow) + 9B Frozen Verifier
22
+ **Key Innovation**: First hybrid DeltaNet/Attention causal diffusion transformer for parallel token generation
23
+ **Status**: Stage A converged, Stage C alignment in progress
24
 
25
  ---
26
 
27
  ## Thesis
28
 
29
+ Autoregressive language models are bottlenecked by sequential generation. CLSD deploys a hybrid causal Diffusion Transformer (DiT) -- a strided 12-layer slice of Qwen3.5-9B -- operating in the continuous embedding space of the same frozen Qwen3.5-9B verifier. Both models share the exact same 4096-dimensional manifold, the same tokenizer, and the same attention geometry. No projection bridges, no dimensional translation loss.
 
 
 
 
30
 
31
+ Qwen3.5-9B uses a hybrid architecture: 24 Gated DeltaNet (linear attention) layers + 8 standard quadratic attention layers in a repeating [3xDeltaNet, 1xAttention] pattern. The DiT preserves this hybrid structure and keeps **causal masking** -- DeltaNet linear recurrence is strictly causal by design.
 
 
 
32
 
33
+ The DiT drafts 32 candidate 128-token embedding sequences simultaneously in 2 Euler steps. The verifier evaluates them in a single batched forward pass.
 
 
34
 
35
+ > **Why causal diffusion works**: The conditioning vector C is injected via adaLN into every position simultaneously, providing global context regardless of the attention mask. The causal constraint forces the DiT to learn autoregressive-like internal logic, which mirrors the frozen verifier's expectations.
 
 
 
 
36
 
37
  ---
38
 
39
  ## Architecture
40
 
41
+ | Role | Model | Params | Dim | Layers |
42
+ |------|-------|--------|-----|--------|
43
+ | **Generator (DiT)** | Qwen3.5-9B strided slice | ~4.0B | 4096 | 12 (9 DeltaNet + 3 FullAttn) |
44
+ | **Verifier (frozen)** | Qwen3.5-9B (text tower) | 9B | 4096 | 32 |
 
 
45
 
46
  ### The Strided Graft
47
 
48
  ```
49
  Source layers: [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]
50
  Layer types: [D, A, D, D, D, A, D, D, D, D, D, A ]
 
51
 
52
  D = DeltaNet (linear_attention), A = full_attention
 
53
  ```
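
The layer-type row above can be re-derived from the source indices. The sketch below assumes that the repeating [3xDeltaNet, 1xAttention] pattern places full attention on every fourth layer of the 32-layer stack (0-based indices 3, 7, ..., 31); the helper name `layer_type` is hypothetical and not part of the repository.

```python
# Assumption: in the 32-layer Qwen3.5-9B stack, layer i is full attention when
# (i + 1) % 4 == 0, i.e. it closes a [3xDeltaNet, 1xAttention] group.
SOURCE_LAYERS = [0, 3, 6, 9, 12, 15, 18, 21, 24, 26, 28, 31]


def layer_type(index: int, period: int = 4) -> str:
    """Return 'A' for a full-attention layer and 'D' for a DeltaNet layer."""
    return "A" if (index + 1) % period == 0 else "D"


# Types of the 12 grafted layers, in DiT order.
layer_types = [layer_type(i) for i in SOURCE_LAYERS]
counts = {"D": layer_types.count("D"), "A": layer_types.count("A")}

# Types of the full verifier stack, for comparison with the 24 + 8 split.
full_stack = [layer_type(i) for i in range(32)]

print(layer_types)
print(counts)
```

Under that assumption the slice contains 9 DeltaNet and 3 full-attention layers, which agrees with the architecture table.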
54
 
55
+ ### DiT Modifications
56
+ 1. **adaLN-Zero modulators** per block: nn.Linear(4096, 24576), zero-initialized
57
+ 2. **Timestep conditioning**: sinusoidal embedding + conditioning vector C
58
+ 3. **Learned local positional embedding**: nn.Parameter(zeros(1, 128, 4096))
59
+ 4. Causal masking preserved from original Qwen weights
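
A small pure-Python sketch of why zero-initialized adaLN-Zero modulators leave a block as the identity at step 0. It assumes the 24576-wide modulator output is split into six 4096-d chunks (shift, scale, gate for each of two branches), as in the standard DiT formulation; the repository's actual split and normalization may differ, and `adaln_zero_branch` is a hypothetical name.

```python
import math

HIDDEN = 4096
NUM_CHUNKS = 6  # assumed: shift/scale/gate for the token-mixer branch and the MLP branch
MODULATOR_OUT = HIDDEN * NUM_CHUNKS  # matches nn.Linear(4096, 24576)


def layer_norm(x, eps=1e-6):
    """Plain normalization without learned affine parameters (an assumption)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def adaln_zero_branch(x, shift, scale, gate, branch):
    """Compute x + gate * branch(norm(x) * (1 + scale) + shift), elementwise."""
    modulated = [n * (1.0 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]
    return [xi + g * oi for xi, g, oi in zip(x, gate, branch(modulated))]


# With zero-initialized modulators, shift = scale = gate = 0 for every input.
x = [0.5, -1.0, 2.0, 0.25]
zeros = [0.0] * len(x)
out = adaln_zero_branch(x, zeros, zeros, zeros, lambda h: [3.0 * v for v in h])
print(out)
```

Because the gate is zero, the branch output is multiplied away and only the residual path survives, whatever the branch computes.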
 
 
 
60
 
61
  ---
62
 
 
64
 
65
  ### Pre-Flight: Embedding Extraction
66
 
67
+ Target embeddings from **AI-MO/NuminaMath-CoT** (mathematical chain-of-thought):
68
+ - Tokenized with Qwen tokenizer, embeddings looked up via frozen embedding matrix
69
+ - Chunked into 128-token windows: [64, 128, 4096] safetensors shards
70
+ - **146,790 total chunks** across 2,294 files
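
The shard arithmetic can be checked with a few lines of Python. This is a sketch, not the extraction script: `chunk_tokens` is a hypothetical helper, and dropping the trailing partial window is an assumption about how chunking was done.

```python
import math

CHUNK_LEN = 128          # tokens per window
CHUNKS_PER_SHARD = 64    # first dimension of each [64, 128, 4096] shard
TOTAL_CHUNKS = 146_790   # figure quoted above


def chunk_tokens(token_ids, chunk_len=CHUNK_LEN):
    """Split a token list into full windows; the trailing remainder is dropped (assumed)."""
    n_full = len(token_ids) // chunk_len
    return [token_ids[i * chunk_len:(i + 1) * chunk_len] for i in range(n_full)]


# Number of shard files needed if every shard except possibly the last is full.
num_shards = math.ceil(TOTAL_CHUNKS / CHUNKS_PER_SHARD)
last_shard_chunks = TOTAL_CHUNKS - (num_shards - 1) * CHUNKS_PER_SHARD

windows = chunk_tokens(list(range(300)))
print(num_shards, last_shard_chunks, len(windows))
```

The shard count reproduces the 2,294 files quoted above; the partially filled final shard is an inference from the totals, not something the README states.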
 
 
 
 
 
71
 
72
+ ### Stage A: Rectified Flow (Velocity Regression) -- COMPLETE
73
 
74
+ The DiT learns the straight-line velocity field v = x1 - x0, where x0 is the noise sample and x1 is the target embedding:
75
 
76
+ ```
77
+ x_t = (1-t)*noise + t*target, t in [0,1]
78
+ L = ||v_pred - (target - noise)||^2
79
+ ```
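
The interpolation, the loss, and a 2-step Euler sampler can be sketched on plain Python lists. The `oracle` velocity function below stands in for a trained DiT and simply returns the true straight-line velocity; the mean reduction in `rf_loss` is an assumption based on the reported MSE.

```python
import random


def interpolate(noise, target, t):
    """x_t = (1 - t) * noise + t * target."""
    return [(1.0 - t) * n + t * x for n, x in zip(noise, target)]


def rf_loss(v_pred, noise, target):
    """Mean squared error between predicted velocity and (target - noise)."""
    return sum((vp - (x - n)) ** 2 for vp, n, x in zip(v_pred, noise, target)) / len(target)


def euler_sample(velocity_fn, noise, num_steps=2):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-size Euler steps."""
    x = list(noise)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        v = velocity_fn(x, k * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


random.seed(0)
target = [0.3, -1.2, 0.7, 2.0]
noise = [random.gauss(0.0, 1.0) for _ in target]


def oracle(x, t):
    """Hypothetical perfect model: the constant velocity target - noise."""
    return [xt - n for xt, n in zip(target, noise)]


x_half = interpolate(noise, target, 0.5)
loss = rf_loss(oracle(x_half, 0.5), noise, target)
sample = euler_sample(oracle, noise, num_steps=2)
print(loss, sample)
```

When the velocity field is exactly straight, Euler integration lands on the target for any step count, which is the motivation for sampling in only 2 steps.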
 
 
 
 
 
 
80
 
81
  | Parameter | Value |
82
  |-----------|-------|
83
+ | Hardware | 1x NVIDIA B200 (183 GB) |
 
84
  | Steps | 50,000 |
85
+ | Batch size | 32 |
86
+ | Optimizer | AdamW (lr=1e-4, cosine decay) |
87
  | Wall-clock | 154.8 minutes |
88
+ | Final MSE | ~0.013 (converged by step 5K) |
 
89
 
90
+ ### Stage C: CE Alignment -- IN PROGRESS
91
 
92
+ Backpropagate through the frozen 9B verifier to teach the DiT semantic correctness:
 
93
 
94
  ```
95
+ noise -> DiT (2 Euler steps) -> draft_embeds
96
+ -> frozen Qwen 32 layers -> logits -> CE loss vs ground truth tokens
97
  ```
98
 
99
+ L_total = CE(logits, targets) + beta * MSE(drafts, true_embeddings)
100
+
101
+ Beta anneals from 0.1 to 0, gradually shifting from geometric to semantic alignment.
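
A minimal sketch of the combined loss with an annealed beta. The README only says that beta goes from 0.1 to 0; the linear shape used here is an assumption, and `beta_schedule` and `total_loss` are hypothetical names.

```python
def beta_schedule(step, total_steps, beta_start=0.1):
    """Linearly decay beta from beta_start at step 0 to 0 at total_steps (assumed shape)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return beta_start * (1.0 - frac)


def total_loss(ce, mse, step, total_steps):
    """L_total = CE + beta(step) * MSE, with scalar CE and MSE values."""
    return ce + beta_schedule(step, total_steps) * mse


start_beta = beta_schedule(0, 2000)
mid_beta = beta_schedule(1000, 2000)
end_beta = beta_schedule(2000, 2000)
final = total_loss(6.1, 0.013, 2000, 2000)
print(start_beta, mid_beta, end_beta, final)
```

At the end of the schedule the MSE term vanishes and the loss reduces to the cross-entropy alone.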
102
+
103
+ **Smoke test results** (50 steps, batch=1):
104
+ - CE dropped from 12.8 to 6.1, suggesting the verifier is starting to read DiT output
105
+ - Gradients flow correctly through frozen verifier
106
 
107
+ **Current run**: 2000 steps, batch=8, grad_accum=4 (effective batch 32) on B200 -- streaming to wandb
 
108
 
109
  ---
110
 
111
+ ## Step 4: Live Inference (The Parallel Rollout)
112
 
113
+ 1. User submits a reasoning prompt
114
+ 2. 9B Verifier forward pass -> conditioning vector C + KV cache
115
+ 3. DiT generates **32 candidate 128-token branches** in 2 Euler steps
116
+ 4. 9B Verifier evaluates all 32 branches in one batched pass (shared prompt KV via PagedAttention)
117
+ 5. Score by mean log-probability across 128 positions
118
+ 6. **Causal Guillotine**: scan Top-1 left-to-right, truncate at first low-confidence position
119
+ 7. Qwen samples the correct token, a new C is generated, and the loop repeats
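
Branch scoring and the Causal Guillotine can be sketched as follows. The input is assumed to be one list of per-position log-probabilities per branch, as returned by the verifier; the threshold value and the function name `select_and_truncate` are hypothetical.

```python
def select_and_truncate(branch_logprobs, threshold):
    """Pick the branch with the highest mean log-prob, then count its accepted prefix.

    Returns (best_branch_index, accepted_length), where accepted_length is the
    number of leading positions whose log-prob is at or above the threshold.
    """
    means = [sum(lp) / len(lp) for lp in branch_logprobs]
    best = max(range(len(means)), key=means.__getitem__)
    accepted = 0
    for lp in branch_logprobs[best]:
        if lp < threshold:
            break  # first low-confidence position: cut here
        accepted += 1
    return best, accepted


# Toy example with 3 branches of 4 positions instead of 32 branches of 128.
branches = [
    [-0.1, -0.2, -5.0, -0.1],
    [-0.3, -0.2, -0.4, -3.0],
    [-2.0, -2.0, -2.0, -2.0],
]
best, accepted = select_and_truncate(branches, threshold=-1.0)
print(best, accepted)
```

In this toy case the second branch has the best mean score, and its first three positions are kept before the cut.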
120
 
121
  **Target latency**: <500ms per 128-token block
122
 
123
  ---
124
 
125
+ ## Step 5: The Shadow Loop (Async RL -- Continuous Improvement)
126
+
127
+ The Primary Node never stops drafting. A Shadow Node continuously improves the DiT:
128
+
129
+ ```
130
+ Primary Node --[Redis: 32 trajectories/cycle]--> Shadow Node
131
+ Shadow Node --[Weight sync every 1000 steps]--> Primary Node
132
+ ```
133
+
134
+ ### Objective Verification (Reward Signal)
135
+
136
+ Feed Top-1 decoded tokens through:
137
+ - **Lean 4**: formal mathematical proof verification
138
+ - **Python sandbox**: code execution for correctness
139
+
140
+ If verified -> reward the continuous vectors (positive signal)
141
+ If failed -> penalize (negative signal)
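
A sketch of the reward assignment, assuming a +1/-1 scheme (the README only specifies positive and negative signals). `toy_verify` is a stand-in for the Lean 4 or sandbox checker and only validates simple addition claims; both names are hypothetical.

```python
def assign_rewards(decoded_candidates, verify):
    """Map each decoded candidate to +1.0 if the external checker accepts it, else -1.0."""
    return [1.0 if verify(c) else -1.0 for c in decoded_candidates]


def toy_verify(claim: str) -> bool:
    """Stand-in checker: accepts strings like '2+3=5' when the arithmetic holds."""
    lhs, rhs = claim.split("=")
    a, b = lhs.split("+")
    return int(a) + int(b) == int(rhs)


rewards = assign_rewards(["2+3=5", "2+3=6"], toy_verify)
print(rewards)
```

The key property is that the reward depends only on the external check, not on the verifier's log-probabilities.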
142
+
143
+ This breaks the log-prob echo chamber. The DiT learns "alien intuition" -- solutions the 9B verifier would score as correct but would never stumble upon autoregressively.
144
+
145
+ ### RL Objective
146
+
147
+ Policy gradient from objective verification creates a reward signal independent of the verifier log-probs. The DiT explores the embedding space for novel solutions that:
148
+ 1. The verifier accepts (high log-prob)
149
+ 2. Actually solve the problem (Lean4/sandbox verification)
150
+
151
+ This is an **infinite background process** -- the system improves continuously as long as compute is available.
152
+
153
+ ---
154
+
155
  ## Repository Contents
156
 
157
  ```
 
 
 
158
  checkpoints/
159
+ dit_stage_a_step_5000.pt # Early training
160
+ dit_stage_a_step_10000.pt # Mid training
161
+ dit_stage_a_step_30000.pt # Late training
162
+ dit_stage_a_final.pt # 50K steps, converged (MSE=0.013)
163
+ dit_stage_c_*.pt # CE alignment checkpoints (when available)
164
+ embeddings_sample/ # 50 representative embedding shards
165
+ batch_*.safetensors # Each: [64, 128, 4096]
166
  ```
167
 
168
  ### Loading a Checkpoint
 
172
  from transformers import AutoModelForCausalLM
173
  import torch
174
 
 
175
  qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-9B", dtype=torch.bfloat16)
176
  dit, embed_tokens = graft_dit_from_qwen(qwen, slice_indices=STRIDE_INDICES)
 
 
177
  state_dict = torch.load("checkpoints/dit_stage_a_final.pt", weights_only=True)
178
  dit.load_state_dict(state_dict)
179
  ```
180
 
181
  ---
182
 
183
+ ## Roadmap
184
 
185
+ - [x] Pre-flight: embedding extraction (146K chunks from NuminaMath-CoT)
186
+ - [x] Step 1: Frankenstein graft (4.0B hybrid DiT from 9B)
187
+ - [x] Step 2: Stage A rectified flow (50K steps, converged)
188
+ - [x] Stage C smoke test (50 steps, pipeline validated)
189
+ - [ ] Step 3: Stage C full alignment (2000+ steps on B200)
190
+ - [ ] Step 4: Live inference with Causal Guillotine
191
+ - [ ] Step 5: Shadow Loop async RL with Lean4/sandbox verification
192
+ - [ ] Scale to 8x H200 cluster for production training
193
 
194
+ ## Wandb
 
 
195
 
196
+ - Stage A: [clsd-speedrun](https://wandb.ai/dalletest123/clsd-speedrun)
197
+ - Stage C smoke: [clsd-speedrun-smoke](https://wandb.ai/dalletest123/clsd-speedrun-smoke)
198
 
199
  ## License
200