WCNegentropy commited on
Commit
0e50ac6
·
verified ·
1 Parent(s): a3f3b9e

Remove bit_transformer_lm_codex_playbook.md - cleanup for OS launch

Browse files
Files changed (1) hide show
  1. bit_transformer_lm_codex_playbook.md +0 -278
bit_transformer_lm_codex_playbook.md DELETED
@@ -1,278 +0,0 @@
1
- ---
2
-
3
- # 🧭 BitTransformerLM Codex Playbook (Merged)
4
-
5
- A single, actionable playbook that **implements optimizations first**, then **trains/ships the models**. Drop these prompts into your Codex/agent and run top-to-bottom.
6
-
7
- ---
8
-
9
- ## Phase 1 — Training Loop & Runtime Optimizations (apply these first)
10
-
11
- ### Task 1 — Make batch size configurable & fix OneCycle accounting — COMPLETED ✅
12
-
13
- **Prompt:**
14
-
15
- ```bash
16
- codex run bittransformerlm/patch \
17
- --file bit_transformer/training.py \
18
- --edit "Replace data.split(8) with DataLoader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, persistent_workers=True); compute steps_per_epoch=len(loader); set total_updates=epochs*(steps_per_epoch+extra_steps); pass total_updates into configure_optimizer"
19
- ```
20
-
21
- ✅ OneCycle’s horizon matches reality across runs.
22
-
23
- ---
24
-
25
- ### Task 2 — Remove hardcoded `total_steps=100` in dashboard/MCP — COMPLETED ✅
26
-
27
- **Prompt:**
28
-
29
- ```bash
30
- codex run bittransformerlm/patch \
31
- --file dashboard/manager.py \
32
- --edit "When (re)creating OneCycleLR after init/scale_up/download, use computed total_steps from the upcoming training plan instead of hardcoded 100"
33
- ```
34
-
35
- ✅ Aligns scheduler behavior between direct loop and MCP/dashboard.
36
-
37
- ---
38
-
39
- ### Task 3 — Add mixed-precision autocast (AMP, BF16) — COMPLETED ✅
40
-
41
- **Prompt (pseudo-patch):**
42
-
43
- ```python
44
- with torch.amp.autocast(device_type=("cuda" if torch.cuda.is_available() else "cpu"), dtype=torch.bfloat16):
45
- logits = model(batch)
46
- loss = criterion(logits, labels)
47
- loss.backward()
48
- ```
49
-
50
- ✅ 1.2–1.8× throughput on attention-heavy training. Keep grad-clip.
51
-
52
- ---
53
-
54
- ### Task 4 — Add gradient accumulation — COMPLETED ✅
55
-
56
- **Prompt:**
57
-
58
- ```bash
59
- codex run bittransformerlm/patch \
60
- --file bit_transformer/training.py \
61
- --edit "Introduce --accum_steps; scale loss by 1/accum_steps; optimizer.step() every accum_steps; scheduler.step() every accum_steps"
62
- ```
63
-
64
- ✅ Simulates larger effective batch sizes without extra memory.
65
-
66
- ---
67
-
68
- ### Task 5 — Optimize dataset pipeline (mmap + streaming) — COMPLETED ✅
69
-
70
- **Prompt:**
71
-
72
- ```bash
73
- codex run bittransformerlm/patch \
74
- --file data/wikitext_schedule.py \
75
- --edit "Precompute text->bit tensors aligned to max_seq_len; store in memory-mapped file; implement Dataset with __len__/__getitem__; use DataLoader(num_workers>0, persistent_workers=True)"
76
- ```
77
-
78
- ✅ Removes conversion bottlenecks on large corpora.
79
-
80
- ---
81
-
82
- ### Task 6 — Schedule compression probability (safer ramp) — COMPLETED ✅
83
-
84
- **Prompt (pseudo-code):**
85
-
86
- ```python
87
- compress_prob = cosine_ramp(global_step, start=0.0, end=0.5, total_steps=warmup_steps)
88
- ```
89
-
90
- ✅ Prevents early instability from aggressive compression.
91
-
92
- ---
93
-
94
- ### Task 7 — Stabilize safety gate (EMA + burn‑in) — COMPLETED ✅
95
-
96
- **Prompt (pseudo-patch):**
97
-
98
- ```python
99
- ema_val = ema(val_loss, decay=0.9)
100
- if step < burn_in_steps:
101
- allow_training = True
102
- elif ema_val > threshold:
103
- trigger_gate()
104
- ```
105
-
106
- ✅ Reduces false positives from noisy early validations.
107
-
108
- ---
109
-
110
- ### Task 8 — Enable `torch.compile` selectively — COMPLETED ✅
111
-
112
- **Prompt:**
113
-
114
- ```bash
115
- codex run bittransformerlm/patch \
116
- --file bit_transformer/training.py \
117
- --edit "Enable torch.compile only if torch.__version__>=\"2.1\" and python<3.12; else skip with a clear warning"
118
- ```
119
-
120
- ✅ Opportunistic speedup where supported.
121
-
122
- ---
123
-
124
- ### Task 9 — Integrate FlashAttention / SDPA
125
-
126
- **Prompt (pseudo-patch):**
127
-
128
- ```python
129
- from torch.nn import functional as F
130
-
131
- def forward_attention(q, k, v, is_causal=True):
132
- return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
133
- ```
134
-
135
- ✅ Unlocks fused kernels; prefer `is_causal=True` over boolean masks.
136
-
137
- ---
138
-
139
- ### Task 10 — Cache causal masks — COMPLETED ✅
140
-
141
- **Prompt (pseudo-code):**
142
-
143
- ```python
144
- mask_cache = {}
145
-
146
- def get_tri_mask(seq_len, device):
147
- key = (seq_len, device)
148
- if key not in mask_cache:
149
- mask_cache[key] = torch.triu(
150
- torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1
151
- )
152
- return mask_cache[key]
153
- ```
154
-
155
- ✅ Avoids repeated `triu` allocations when masks are still needed.
156
-
157
- ---
158
-
159
- ### Task 11 — Fix stitched attention negative indexing — COMPLETED ✅
160
-
161
- **Prompt (pseudo-code):**
162
-
163
- ```python
164
- start = max(s - overlap, 0)
165
- end = min(s + chunk_size, T)
166
- canvas[..., start:end] = attn_chunk[..., : end - start]
167
- ```
168
-
169
- ✅ Prevents wrap-around misplacement during T×T map reconstruction.
170
-
171
- ---
172
-
173
- ### Task 12 — Default off: full T×T attention logging in chunked runs — COMPLETED ✅
174
-
175
- **Prompt:**
176
-
177
- ```bash
178
- codex run bittransformerlm/patch \
179
- --file bit_transformer/model.py \
180
- --edit "Set full_attn_logging=False by default when chunk_size is set"
181
- ```
182
-
183
- ✅ Big memory/time savings without losing training signal.
184
-
185
- ---
186
-
187
- ## Phase 2 — Model Creation & Training Tasks (run after Phase 1)
188
-
189
- ### Task A — Train the best current baseline (8×256 with ACT)
190
-
191
- **Prompt:**
192
-
193
- ```bash
194
- codex run bittransformerlm/train \
195
- --layers 8 \
196
- --d_model 256 \
197
- --nhead 8 \
198
- --causal true \
199
- --chunk_size 128 \
200
- --act true \
201
- --reversible true \
202
- --checkpointing true \
203
- --batch_size 64 \
204
- --accum_steps 2 \
205
- --amp bf16 \
206
- --lr_schedule progressive_plateau \
207
- --full_attn_logging false
208
- ```
209
-
210
- ✅ Reproduces the validated **sweet spot** with newly enabled efficiency features.
211
-
212
- ---
213
-
214
- ### Task B — CPU‑friendly deployment (8×128, INT8 + optional QAT)
215
-
216
- **Prompt:**
217
-
218
- ```bash
219
- codex run bittransformerlm/train \
220
- --layers 8 \
221
- --d_model 128 \
222
- --nhead 8 \
223
- --causal true \
224
- --chunk_size 128 \
225
- --quantization int8 \
226
- --qat true \
227
- --reversible true \
228
- --checkpointing true \
229
- --batch_size 128 \
230
- --accum_steps 1 \
231
- --amp bf16
232
- ```
233
-
234
- ✅ Efficient CPU target; QAT optional based on deployment constraints.
235
-
236
- ---
237
-
238
- ### Task C — Cautious scale‑up candidate (16×256)
239
-
240
- **Prompt:**
241
-
242
- ```bash
243
- codex run bittransformerlm/train \
244
- --layers 16 \
245
- --d_model 256 \
246
- --nhead 8 \
247
- --causal true \
248
- --chunk_size 128 \
249
- --act true \
250
- --reversible true \
251
- --checkpointing true \
252
- --batch_size 48 \
253
- --accum_steps 3 \
254
- --amp bf16 \
255
- --lr_schedule progressive_plateau
256
- ```
257
-
258
- ⚠️ Use only after data expansion and schedule retune.
259
-
260
- ---
261
-
262
- ## Recommended Execution Order
263
-
264
- 1. **Phase 1 Tasks 1–12** (apply all optimizations).
265
- 2. **Task A** baseline → validate.
266
- 3. **Task B** CPU build → validate + (optional) QAT.
267
- 4. **Task C** scale‑up **only** when data/schedule allow.
268
-
269
- ---
270
-
271
- ### Notes
272
-
273
- - Pair Phase 1 changes with CI that runs a short sanity fit (few hundred steps) to confirm loss decreases and no scheduler drift.
274
- - Keep `full_attn_logging=false` in chunked runs; enable selectively when inspecting attention.
275
- - When using SDPA, prefer `is_causal=True` and avoid passing dense masks unless required.
276
-
277
- ---
278
-