WCNegentropy committed · Commit 7d0df52 · verified · 1 Parent(s): c1c18dc

Remove nested directory: BitTransformerLM/integration_schedule.py

BitTransformerLM/integration_schedule.py DELETED
@@ -1,379 +0,0 @@
-import os
-import time
-import math
-from itertools import cycle
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from bit_transformer import (
-    BitTransformerLM,
-    text_to_bits,
-    quantize_dynamic,
-    prepare_qat_fx,
-    convert_qat_fx,
-    hil_safe_inference,
-    collapse_submodel,
-    diffusion_inference,
-    TelemetrySynthesizer,
-    save_distilled_model,
-)
-from bit_transformer.training import train_loop as train
-from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
-from bit_transformer.utils import save_model, load_model, set_dropout
-from bit_transformer.torch_utils import cpu_autocast
-
-
-def lines_to_tensor(lines, max_len):
-    seqs = []
-    for text in lines:
-        bits = text_to_bits(text)[:max_len]
-        if len(bits) < max_len:
-            bits.extend([0] * (max_len - len(bits)))
-        seqs.append(bits)
-    return torch.tensor(seqs, dtype=torch.long)
-
-
-def load_wikitext(dataset_size=128, max_len=64):
-    try:
-        from datasets import load_dataset
-
-        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-        train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
-        valid_split = max(1, dataset_size // 4)
-        valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
-        train = lines_to_tensor(train_lines, max_len)
-        valid = lines_to_tensor(valid_lines, max_len)
-        return train, valid, train_lines
-    except Exception as e:
-        print("Dataset load failed, using random bits", e)
-        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
-        valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
-        return train, valid, ["" for _ in range(len(train))]
-
-
-def _warmup(
-    model: BitTransformerLM,
-    data: torch.Tensor,
-    steps: int = 5,
-    freeze_old: bool = False,
-    old_layers: int = 0,
-    *,
-    diffusion: bool = False,
-    curriculum: bool = False,
-    optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
-    """Run a short warm-up loop after expansion."""
-    model.train()
-    set_dropout(model, 0.1)
-    if freeze_old:
-        for idx, layer in enumerate(model.layers):
-            if idx < old_layers:
-                for p in layer.parameters():
-                    p.requires_grad_(False)
-    if optimizer is None or scheduler is None:
-        optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
-    it = iter(data.split(8))
-    for idx in range(steps):
-        try:
-            batch = next(it)
-        except StopIteration:
-            it = iter(data.split(8))
-            batch = next(it)
-        if diffusion:
-            p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
-            noise = (torch.rand_like(batch.float()) < p).long()
-            noisy = batch ^ noise
-            logits, _ = model(noisy, causal=False)
-            pred = logits.reshape(-1, 2)
-            target = batch.reshape(-1)
-        else:
-            logits, _ = model(batch)
-            pred = logits[:, :-1, :].reshape(-1, 2)
-            target = batch[:, 1:].reshape(-1)
-        loss = F.cross_entropy(pred, target)
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-        optimizer.step()
-        scheduler.step()
-        optimizer.zero_grad()
-    for p in model.parameters():
-        p.requires_grad_(True)
-    model.eval()
-    set_dropout(model, 0.0)
-
-
-def integration_schedule(
-    steps: int = 10,
-    max_len: int = 64,
-    dataset_size: int = 128,
-    *,
-    weights_path: str = "weights/model.pt.gz",
-    plateau_steps: int = 0,
-    collapsed_path: str | None = None,
-    epochs_per_step: int = 2,
-    extra_steps: int = 3,
-    collapse: bool = True,
-    diffusion: bool = False,
-    noise_schedule: str = "linear",
-    diffusion_steps: int = 8,
-    diffusion_curriculum: bool = False,
-    use_checkpoint: bool = True,
-    reversible: bool = True,
-    improve_thresh: float = 0.01,
-    qat: bool = False,
-):
-    start = time.time()
-    train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
-    if os.path.exists(weights_path):
-        try:
-            model = load_model(weights_path)
-            print(f"Loaded model from {weights_path}")
-        except Exception as e:
-            print("Failed to load weights, initializing new model", e)
-            model = BitTransformerLM(
-                d_model=32,
-                nhead=4,
-                num_layers=1,
-                dim_feedforward=64,
-                max_seq_len=max_len,
-                use_act=True,
-                act_threshold=0.7,
-                reversible=reversible,
-                chunk_size=max_len,
-                use_autocast=True,
-                use_checkpoint=use_checkpoint,
-            )
-    else:
-        model = BitTransformerLM(
-            d_model=32,
-            nhead=4,
-            num_layers=1,
-            dim_feedforward=64,
-            max_seq_len=max_len,
-            use_act=True,
-            act_threshold=0.7,
-            reversible=reversible,
-            chunk_size=max_len,
-            use_autocast=True,
-            use_checkpoint=use_checkpoint,
-        )
-    if qat:
-        model = prepare_qat_fx(model)
-    results = []
-    scale_cycle = cycle(["layers", "width", "context"])
-    base_lr = 1e-3
-    prev_val_loss: Optional[float] = None
-    for step in range(steps):
-        model.train()
-        set_dropout(model, 0.1)
-        opt, sched = configure_optimizer(
-            model, lr=base_lr, total_steps=epochs_per_step
-        )
-        train(
-            model,
-            train_bits,
-            epochs=epochs_per_step,
-            extra_steps=extra_steps,
-            compress_prob=0.0 if diffusion else 1.0,
-            log=True,
-            diffusion=diffusion,
-            diffusion_curriculum=diffusion_curriculum,
-            optimizer=opt,
-            scheduler=sched,
-        )
-
-        model.eval()
-        set_dropout(model, 0.0)
-        with torch.no_grad():
-            logits, telemetry = model(valid_bits, causal=not diffusion)
-            if diffusion:
-                pred = logits.reshape(-1, 2)
-                target = valid_bits.reshape(-1)
-            else:
-                pred = logits[:, :-1, :].reshape(-1, 2)
-                target = valid_bits[:, 1:].reshape(-1)
-            val_loss = F.cross_entropy(pred, target).item()
-        k = telemetry["negentropy_logits"].mean().item()
-        c = telemetry["lz_complexity_logits"].mean().item()
-        s = telemetry["symbiosis_score"].mean().item()
-        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
-        results.append((step, val_loss, k, c, s))
-
-        if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
-            strategy = next(scale_cycle)
-            base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
-            if strategy == "layers":
-                old_layers = model.num_layers
-                model = model.double_layers()
-                warm_opt, warm_sched = configure_optimizer(
-                    model, lr=base_lr, total_steps=100
-                )
-                _warmup(
-                    model,
-                    train_bits,
-                    steps=100,
-                    freeze_old=True,
-                    old_layers=old_layers,
-                    diffusion=diffusion,
-                    curriculum=diffusion_curriculum,
-                    optimizer=warm_opt,
-                    scheduler=warm_sched,
-                )
-            elif strategy == "width":
-                model = model.double_width()
-                warm_opt, warm_sched = configure_optimizer(
-                    model, lr=base_lr, total_steps=100
-                )
-                _warmup(
-                    model,
-                    train_bits,
-                    steps=100,
-                    diffusion=diffusion,
-                    curriculum=diffusion_curriculum,
-                    optimizer=warm_opt,
-                    scheduler=warm_sched,
-                )
-            else:
-                max_len *= 2
-                train_bits, valid_bits, train_lines = load_wikitext(
-                    dataset_size, max_len
-                )
-                model = model.double_length()
-                warm_opt, warm_sched = configure_optimizer(
-                    model, lr=base_lr, total_steps=100
-                )
-                _warmup(
-                    model,
-                    train_bits,
-                    steps=100,
-                    diffusion=diffusion,
-                    curriculum=diffusion_curriculum,
-                    optimizer=warm_opt,
-                    scheduler=warm_sched,
-                )
-
-        prev_val_loss = val_loss
-        if time.time() - start > 8 * 60:
-            print("Time limit reached")
-            break
-
-    # optional plateau phase at final size
-    for p in range(plateau_steps):
-        model.train()
-        set_dropout(model, 0.1)
-        train(
-            model,
-            train_bits,
-            epochs=epochs_per_step,
-            extra_steps=extra_steps,
-            compress_prob=0.0 if diffusion else 1.0,
-            log=True,
-            diffusion=diffusion,
-            diffusion_curriculum=diffusion_curriculum,
-        )
-        model.eval()
-        set_dropout(model, 0.0)
-        with torch.no_grad():
-            logits, telemetry = model(valid_bits, causal=not diffusion)
-            if diffusion:
-                pred = logits.reshape(-1, 2)
-                target = valid_bits.reshape(-1)
-            else:
-                pred = logits[:, :-1, :].reshape(-1, 2)
-                target = valid_bits[:, 1:].reshape(-1)
-            val_loss = F.cross_entropy(pred, target).item()
-        k = telemetry["negentropy_logits"].mean().item()
-        c = telemetry["lz_complexity_logits"].mean().item()
-        s = telemetry["symbiosis_score"].mean().item()
-        idx = steps + p
-        print(
-            f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
-        )
-        results.append((idx, val_loss, k, c, s))
-        if time.time() - start > 8 * 60:
-            print("Time limit reached")
-            break
-
-    # final validation after last step
-    model.eval()
-    set_dropout(model, 0.0)
-    with torch.no_grad():
-        logits, telemetry = model(valid_bits, causal=not diffusion)
-        if diffusion:
-            pred = logits.reshape(-1, 2)
-            target = valid_bits.reshape(-1)
-        else:
-            pred = logits[:, :-1, :].reshape(-1, 2)
-            target = valid_bits[:, 1:].reshape(-1)
-        val_loss = F.cross_entropy(pred, target).item()
-    k = telemetry["negentropy_logits"].mean().item()
-    c = telemetry["lz_complexity_logits"].mean().item()
-    s = telemetry["symbiosis_score"].mean().item()
-
-    print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
-    results.append((steps + plateau_steps, val_loss, k, c, s))
-
-    # persist final model weights for future runs
-    save_model(model, weights_path)
-
-    input_bits = valid_bits[:1]
-    if qat:
-        qmodel = convert_qat_fx(model)
-    else:
-        with cpu_autocast():
-            model(input_bits)
-        qmodel = quantize_dynamic(model)
-    qmodel.eval()
-    try:
-        hil_safe_inference(
-            qmodel,
-            input_bits,
-            c_floor=0.3,
-            s_floor=0.5,
-            causal=not diffusion,
-            strict=not diffusion,
-        )
-    except RuntimeError as e:
-        print("Safety gate triggered", e)
-    collapsed = None
-    if collapse:
-        synth = TelemetrySynthesizer(n_clusters=8)
-        reps = synth.cluster_sequences(model, train_bits[:64])
-        floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
-        collapsed, metrics = collapse_submodel(
-            reps,
-            target_params=dict(
-                d_model=16,
-                nhead=4,
-                num_layers=1,
-                dim_feedforward=32,
-                max_seq_len=max_len,
-            ),
-            floors=floors,
-        )
-        collapsed.eval()
-        with torch.no_grad():
-            logits, _ = collapsed(valid_bits)
-            pred = logits[:, :-1, :].reshape(-1, 2)
-            target = valid_bits[:, 1:].reshape(-1)
-            c_loss = F.cross_entropy(pred, target).item()
-        print("Collapsed model validation loss:", c_loss)
-        if collapsed_path is not None:
-            save_distilled_model(
-                collapsed,
-                collapsed_path,
-                {**metrics, "val_loss": c_loss},
-                floors=floors,
-            )
-    if diffusion:
-        sample = diffusion_inference(
-            model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
-        )
-        print("Diffusion sample:", sample[0].tolist())
-    return results, collapsed
-
-
-if __name__ == "__main__":
-    integration_schedule()
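
For reference, a minimal sketch of how the deleted entry point could have been invoked from Python. The keyword names come from the `integration_schedule` signature in the diff above; the specific values and the `weights/distilled.pt.gz` output path are illustrative assumptions, not part of this commit.

# Illustrative invocation (values hypothetical; signature taken from the diff above).
from integration_schedule import integration_schedule

results, collapsed = integration_schedule(
    steps=10,                    # scaling decisions to attempt
    max_len=64,                  # starting bit-sequence length
    dataset_size=128,            # WikiText-2 lines to load
    plateau_steps=2,             # extra training rounds at the final size
    collapsed_path="weights/distilled.pt.gz",  # hypothetical output path
)

# Each result entry is (step, val_loss, negentropy K, LZ complexity C, symbiosis S).
for step, val_loss, k, c, s in results:
    print(f"step={step} loss={val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")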