WCNegentropy committed on
Commit
3a300bc
·
verified ·
1 Parent(s): 73b0816

Remove nested directory: BitTransformerLM/bit_transformer/model.py

Browse files
BitTransformerLM/bit_transformer/model.py DELETED
@@ -1,875 +0,0 @@
1
- import math
2
- import contextlib
3
- import logging
4
- from typing import Dict, List, Tuple
5
-
6
- import torch
7
- import torch.distributed as dist
8
- import sys
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- import torch.utils.checkpoint as checkpoint
12
-
13
- from .torch_utils import cpu_autocast
14
-
15
- from .optimization import configure_optimizer
16
- from .compression import decompress_bits
17
- from .parity import enforce_parity
18
-
19
# Memoised causal masks, keyed by (sequence length, device).
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}


def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Fetch the boolean causal mask for ``seq_len``, building it on first use.

    The mask is upper-triangular with the diagonal excluded, so position
    ``i`` is blocked from attending to any position ``j > i``.  Masks are
    cached per ``(seq_len, device)`` pair to avoid rebuilding them on every
    forward pass.
    """
    cache_key = (seq_len, device)
    cached = _mask_cache.get(cache_key)
    if cached is None:
        ones = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)
        cached = torch.triu(ones, diagonal=1)
        _mask_cache[cache_key] = cached
    return cached
30
-
31
# Select the JIT wrapper used on hot methods: the real ``torch.compile`` when
# the installed torch/Python combination supports it (torch >= 2.0 and
# Python < 3.11 per the check below), otherwise a no-op stand-in.
try:  # torch.compile may not work on all Python versions
    if torch.__version__ and tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 0) and sys.version_info < (3, 11):
        compile_fn = torch.compile
    else:
        # Unsupported combination: jump to the fallback definition below.
        raise RuntimeError
except Exception:  # pragma: no cover - handle missing torch or unsupported version

    def compile_fn(fn=None, **kwargs):
        """No-op replacement for ``torch.compile``.

        Mirrors both usage forms: ``@compile_fn`` (fn given, returned
        unchanged) and ``@compile_fn(**opts)`` (fn is None, return an
        identity decorator).
        """
        if fn is None:
            return lambda f: f
        return fn
42
-
43
-
44
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position embeddings.

    A ``(max_len, 1, d_model)`` table is precomputed once at construction;
    ``forward`` adds the first ``T`` rows to a time-major input of shape
    ``(T, B, d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 1024) -> None:
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even channels: sine
        table[:, 1::2] = torch.cos(angles)  # odd channels: cosine
        # Buffer (not a Parameter): moves with .to()/.cuda() but is never
        # touched by the optimizer.
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with positional encodings added along dim 0."""
        return x + self.pe[: x.size(0)]
61
-
62
-
63
class LoggingTransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that exposes attention weights.

    It optionally performs chunked attention with a fixed window size.

    Post-norm layout: attention -> residual -> LayerNorm, then
    feed-forward -> residual -> LayerNorm.  ``forward`` returns both the
    layer output and a detached attention map for telemetry.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: int | None = None,
        overlap: int = 0,
        full_attn_logging: bool | None = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        # Default: only reconstruct the full T x T attention map when we are
        # not chunking (reconstruction is the expensive path).
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def _chunked_attn(
        self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform chunked self attention with overlap.

        ``src`` is time-major ``[T, B, D]``.  The sequence is padded and
        split into ``n_chunks`` windows of ``chunk_size + 2 * overlap``
        tokens that stride by ``chunk_size``; each window attends only
        within itself.  Returns the recombined sequence (time-major) and
        either a reassembled full attention map or an empty tensor when
        ``full_attn_logging`` is off.
        """
        T, B, D = src.shape
        src_b = src.transpose(0, 1)  # [B, T, D]
        C = self.chunk_size or T
        O = self.overlap
        n_chunks = (T + C - 1) // C
        pad_len = n_chunks * C - T
        # Pad ``overlap`` tokens on the left and ``pad_len + overlap`` on the
        # right so every chunk has full left/right context.
        src_pad = F.pad(src_b, (0, 0, O, pad_len + O))
        chunk_len = C + 2 * O
        chunks = src_pad.unfold(1, chunk_len, C)  # [B, n_chunks, chunk_len, D]
        # A mask was requested (causal mode): use a per-chunk causal mask.
        mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None
        out, weights = self.self_attn(
            chunks.reshape(B * n_chunks, chunk_len, D),
            chunks.reshape(B * n_chunks, chunk_len, D),
            chunks.reshape(B * n_chunks, chunk_len, D),
            attn_mask=mask,
            need_weights=True,
            average_attn_weights=False,
        )
        # Keep only the central ``C`` positions of each chunk; the overlap
        # halo was context only.
        out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
        weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
            :, :, :, O : O + C
        ]
        seq = out.reshape(B, n_chunks * C, D)[:, :T]
        if self.full_attn_logging and C < T:
            # Scatter each chunk's local weights back into a global T x T map
            # (zero outside each chunk's visible span) for telemetry.
            full_attn = torch.zeros(
                B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=src.device
            )
            for idx in range(n_chunks):
                s = idx * C
                start = max(s - O, 0)
                end = min(s + C, n_chunks * C)
                src_start = O - (s - start)
                src_end = src_start + (end - start)
                full_attn[:, :, s : s + C, start:end] = weights[:, idx, :, src_start:src_end]
            full_attn = full_attn[:, :, :T, :T]
            attn_out = full_attn.detach()
        else:
            # Logging disabled: signal "no map" with an empty tensor.
            attn_out = torch.empty(0, device=src.device)
        return seq.transpose(0, 1), attn_out

    def forward(
        self, src: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return output and attention map.

        ``src`` is time-major ``[T, B, D]``; the attention map is detached
        so telemetry never feeds gradients.
        """
        if self.chunk_size is not None:
            attn_output, attn_weights = self._chunked_attn(src, attn_mask)
        else:
            # MultiheadAttention is batch_first, so flip to [B, T, D] and back.
            qkv = src.transpose(0, 1)
            attn_output, attn_weights = self.self_attn(
                qkv,
                qkv,
                qkv,
                attn_mask=attn_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            attn_output = attn_output.transpose(0, 1)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(out)
        src = self.norm2(src)
        return src, attn_weights.detach()
162
-
163
-
164
class ReversibleLoggingTransformerEncoderLayer(nn.Module):
    """Reversible transformer encoder layer with checkpointing.

    Operates on two activation streams ``(x1, x2)`` in the reversible
    residual pattern: ``y1 = x1 + SA(x2)``, ``y2 = x2 + FF(y1)``.  The
    sub-blocks are wrapped with ``compile_fn`` (``torch.compile`` when
    available, identity otherwise).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        chunk_size: int | None = None,
        overlap: int = 0,
        full_attn_logging: bool | None = None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.chunk_size = chunk_size
        self.overlap = overlap
        # Default: only rebuild full attention maps when not chunking.
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    @compile_fn
    def _sa_block(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Self-attention sub-block (chunked or full) with post-norm residual.

        Mirrors ``LoggingTransformerEncoderLayer._chunked_attn`` for the
        chunked path; returns the normed output and a detached attention
        map (empty tensor when map logging is skipped).
        """
        if self.chunk_size is not None:
            T, B, D = x.shape
            x_b = x.transpose(0, 1)
            C = self.chunk_size or T
            O = self.overlap
            n_chunks = (T + C - 1) // C
            pad_len = n_chunks * C - T
            # Pad with an ``overlap`` halo on each side (plus right padding
            # up to a whole number of chunks).
            src_pad = F.pad(x_b, (0, 0, O, pad_len + O))
            chunk_len = C + 2 * O
            chunks = src_pad.unfold(1, chunk_len, C)
            mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None
            out, weights = self.self_attn(
                chunks.reshape(B * n_chunks, chunk_len, D),
                chunks.reshape(B * n_chunks, chunk_len, D),
                chunks.reshape(B * n_chunks, chunk_len, D),
                attn_mask=mask,
                need_weights=True,
                average_attn_weights=False,
            )
            # Keep only the central C positions of each chunk.
            out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C]
            weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[
                :, :, :, O : O + C
            ]
            seq = out.reshape(B, n_chunks * C, D)[:, :T]
            if self.full_attn_logging and C < T:
                # Reassemble per-chunk weights into a global T x T map.
                full_attn = torch.zeros(
                    B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device
                )
                for idx in range(n_chunks):
                    s = idx * C
                    start = max(s - O, 0)
                    end = min(s + C, n_chunks * C)
                    src_start = O - (s - start)
                    src_end = src_start + (end - start)
                    full_attn[:, :, s : s + C, start:end] = weights[
                        :, idx, :, src_start:src_end
                    ]
                full_attn = full_attn[:, :, :T, :T]
                weights = full_attn.detach()
            else:
                weights = torch.empty(0, device=x.device)
            attn_out = seq.transpose(0, 1)
        else:
            # Full attention over the whole sequence (batch_first module).
            qkv = x.transpose(0, 1)
            attn_out, weights = self.self_attn(
                qkv,
                qkv,
                qkv,
                attn_mask=attn_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            attn_out = attn_out.transpose(0, 1)
        x = self.norm1(x + self.dropout1(attn_out))
        return x, weights.detach()

    @compile_fn
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block with post-norm residual."""
        out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = self.norm2(x + self.dropout2(out))
        return x

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reversible update of the two streams; also returns attention map."""
        y1, weights = self._sa_block(x2, attn_mask)
        y1 = x1 + y1
        y2 = x2 + self._ff_block(y1)
        return y1, y2, weights
269
-
270
-
271
class BitTransformerLM(nn.Module):
    """Transformer language model that operates on raw bits (0/1) with telemetry.

    Every forward pass returns next-bit logits together with a telemetry
    dictionary (negentropy, LZ-complexity proxy, symbiosis score, attention
    entropy, per-layer activations).  Optional features: reversible layers,
    activation checkpointing, CPU autocast, ACT-style early halting, and
    chunked attention for long sequences.
    """

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 512,
        max_seq_len: int = 1024,
        lambda_K: float = 1.0,
        lambda_C: float = 1.0,
        lambda_S: float = 1.0,
        reversible: bool = False,
        use_checkpoint: bool = True,
        use_autocast: bool = False,
        use_act: bool = False,
        act_threshold: float = 0.9,
        chunk_size: int | None = None,
        overlap: int = 0,
        full_attn_logging: bool | None = None,
    ) -> None:
        """Create a BitTransformer language model.

        Args:
            lambda_K: Weight of the negentropy term in the symbiosis score.
            lambda_C: Weight of the LZ-complexity term in the symbiosis score.
            lambda_S: Weight of the KL-divergence penalty in the symbiosis score.
            use_act: Enable per-layer halting (see ``_act_step``).
            full_attn_logging: When ``False`` and ``chunk_size`` is
                smaller than the sequence length, the model skips
                reconstructing the full ``T×T`` attention matrices for
                telemetry to reduce memory use.
        """
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S
        self.reversible = reversible
        self.use_checkpoint = use_checkpoint
        self.use_autocast = use_autocast
        self.use_act = use_act
        self.act_threshold = act_threshold
        self.chunk_size = chunk_size
        self.overlap = overlap
        if full_attn_logging is None:
            full_attn_logging = False if chunk_size is not None else True
        self.full_attn_logging = full_attn_logging

        # Bit embedding: two possible input values
        self.embedding = nn.Embedding(2, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len)

        layer_cls = (
            ReversibleLoggingTransformerEncoderLayer
            if reversible
            else LoggingTransformerEncoderLayer
        )
        self.layers = nn.ModuleList(
            [
                layer_cls(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    chunk_size=chunk_size,
                    overlap=overlap,
                    full_attn_logging=full_attn_logging,
                )
                for _ in range(num_layers)
            ]
        )

        if self.use_act:
            # One halting projection per layer, producing a scalar halt
            # probability per position.
            self.halt_projs = nn.ModuleList(
                [nn.Linear(d_model, 1) for _ in range(num_layers)]
            )

        self.out_head = nn.Linear(d_model, 2)  # output logits for bit=0 or bit=1

    def expand_positional_encoding(self, new_len: int) -> None:
        """Expand positional encoding to at least ``new_len``.

        Existing rows are kept byte-for-byte; only positions
        ``[cur_len, new_len)`` are computed fresh.  NOTE: this rebinds
        ``pos_enc.pe`` as a plain attribute rather than re-registering a
        buffer.
        """
        cur_len = self.pos_enc.pe.size(0)
        if new_len <= cur_len:
            return
        device = self.pos_enc.pe.device
        d_model = self.d_model
        pe = torch.zeros(new_len, d_model, device=device)
        pe[:cur_len] = self.pos_enc.pe.squeeze(1)
        pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1)
        inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model))
        pe[cur_len:, 0::2] = torch.sin(pos * inv)
        pe[cur_len:, 1::2] = torch.cos(pos * inv)
        self.pos_enc.pe = pe.unsqueeze(1)

    def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None:
        """Update weighting coefficients for telemetry metrics."""
        self.lambda_K = lambda_K
        self.lambda_C = lambda_C
        self.lambda_S = lambda_S

    def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor:
        """Return raw bit sequences, decompressing if input appears run-length encoded.

        Heuristics: values > 1 definitely mean RLE codes; an even-width
        tensor whose even-indexed values strictly alternate is also treated
        as RLE.  Decompressed rows are right-padded with zeros to a common
        length.
        """
        if codes.dim() <= 1:
            return codes
        needs_decompress = codes.max().item() > 1
        if not needs_decompress and codes.size(1) % 2 == 0:
            vals = codes[:, 0::2]
            if torch.all(vals[:, 1:] != vals[:, :-1]):
                needs_decompress = True
        if not needs_decompress:
            return codes
        seqs = [decompress_bits(row.to(torch.uint8)) for row in codes]
        max_len = max(seq.numel() for seq in seqs)
        padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs]
        return torch.stack(padded)

    def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor:
        """Approximate negentropy of bit sequences.

        Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered
        sequence (all zeros or ones) and ``0`` reflects maximal entropy.
        """
        codes = self._maybe_decompress(codes)
        p = codes.float().mean(dim=1)
        # Binary entropy of the per-sequence bit rate; epsilon avoids log(0).
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor:
        """Differentiable proxy for Lempel–Ziv complexity.

        Values near ``0`` indicate highly compressible sequences while values
        approaching ``1`` correspond to rapid bit alternation.
        """
        codes = self._maybe_decompress(codes)
        # Fraction of adjacent positions whose bits differ.
        diffs = torch.abs(codes[:, 1:] - codes[:, :-1])
        return diffs.float().mean(dim=1)

    def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """Negentropy computed from model logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        p = prob[..., 1].mean(dim=1)
        entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9))
        max_e = math.log(2.0)
        return 1 - entropy / max_e

    def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor:
        """LZ complexity proxy computed from logits.

        Parameters
        ----------
        logits: ``torch.Tensor``
            Logit tensor of shape ``(B, T, 2)``.
        detach: bool, default ``True``
            When ``True`` the computation is detached from the autograd graph.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        prob = logits.softmax(-1)
        if detach:
            prob = prob.detach()
        prob1 = prob[..., 1]
        diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1])
        return diffs.mean(dim=1)

    def symbiosis_kl_logits(
        self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True
    ) -> torch.Tensor:
        """Symbiosis score from KL divergence to a reference distribution.

        Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with
        the reference distribution and ``0`` indicating maximal divergence.
        """
        assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]"
        probs = logits.softmax(-1)
        if detach:
            probs = probs.detach()
        ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device)
        kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1)
        max_kl = math.log(2.0)
        return 1 - kl / max_kl

    def _act_step(
        self,
        hidden: torch.Tensor,
        idx: int,
        halt_prob: torch.Tensor,
        act_state: torch.Tensor,
        halt_history: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Apply one step of ACT halting logic.

        Accumulates halting probability and a halt-weighted hidden state;
        reports ``True`` as the third element when every position (across
        all distributed ranks, if initialized) has reached
        ``act_threshold`` so the layer loop can stop early.
        """
        p = torch.sigmoid(self.halt_projs[idx](hidden))
        # Remaining probability mass assigned to this layer.
        delta = (1 - halt_prob) * p
        halt_prob = halt_prob + delta
        act_state = act_state + hidden * delta
        halt_history.append(halt_prob.detach())
        min_prob = halt_prob.detach().min()
        if dist.is_initialized():
            # All ranks must agree before halting, otherwise collectives
            # inside later layers would desynchronize.
            dist.all_reduce(min_prob, op=dist.ReduceOp.MIN)
        return halt_prob, act_state, min_prob.item() >= self.act_threshold

    def forward(
        self, bit_seq: torch.Tensor, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass returning logits and telemetry from the same graph.

        By default the model uses causal masking and (optional) chunked
        attention. When ``causal`` is ``False`` the model operates in
        "Diffusion LM" mode. In this mode chunked attention is temporarily
        disabled so that every token can attend to the full sequence
        bidirectionally. The original chunking configuration is restored after
        the forward pass.
        """

        # Disable chunking when running in bidirectional (non-causal) mode
        orig_chunks = None
        orig_model_chunk = None
        if not causal and self.chunk_size is not None:
            orig_model_chunk = self.chunk_size
            orig_chunks = [layer.chunk_size for layer in self.layers]
            self.chunk_size = None
            for layer in self.layers:
                layer.chunk_size = None

        try:
            ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext()
            with ctx:
                # Embed bits, flip to time-major [T, B, D] and scale as usual.
                x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model)
                x = self.pos_enc(x)

                attn_mask = get_tri_mask(x.size(0), x.device) if causal else None

                activations: List[torch.Tensor] = []
                attn_maps: List[torch.Tensor] = []
                halt_history: List[torch.Tensor] = []
                if self.use_act:
                    halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device)
                    act_state = torch.zeros_like(x)
                if self.reversible:
                    # Reversible layers carry two streams; both start as x.
                    x1, x2 = x, x
                    for idx, layer in enumerate(self.layers):
                        if self.use_checkpoint:
                            x1, x2, attn = checkpoint.checkpoint(
                                layer, x1, x2, attn_mask
                            )
                        else:
                            x1, x2, attn = layer(x1, x2, attn_mask)
                        combined = (x1 + x2) / 2
                        activations.append(combined)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                combined, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                    x = (x1 + x2) / 2
                else:
                    for idx, layer in enumerate(self.layers):
                        if self.use_checkpoint:
                            x, attn = checkpoint.checkpoint(layer, x, attn_mask)
                        else:
                            x, attn = layer(x, attn_mask)
                        activations.append(x)
                        if attn.numel() > 0:
                            attn_maps.append(attn)
                        if self.use_act:
                            halt_prob, act_state, should_break = self._act_step(
                                x, idx, halt_prob, act_state, halt_history
                            )
                            if should_break:
                                break
                if self.use_act:
                    # Fold in the un-halted remainder of the final state.
                    act_state = act_state + x * (1 - halt_prob)
                    x = act_state
                logits = self.out_head(x)

                # Per-layer entropy of activations
                entropies = []
                for act in activations:
                    prob = act.softmax(-1)
                    ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean()
                    entropies.append(ent)

                attn_entropies = []
                for attn in attn_maps:
                    prob = attn  # weights are already softmaxed
                    ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1)
                    ent = ent.mean(1)
                    attn_entropies.append(ent)
                if attn_entropies:
                    attn_entropy_map = torch.stack(attn_entropies).mean(0)
                else:
                    # No maps were logged (e.g. chunked attention without
                    # full_attn_logging): report zero entropy.
                    attn_entropy_map = torch.zeros(
                        bit_seq.size(0), bit_seq.size(1), device=bit_seq.device
                    )
                # Normalize by the maximum possible entropy log(T).
                max_ent = math.log(attn_entropy_map.size(-1))
                attn_entropy_map = attn_entropy_map / max_ent
                attn_entropy = attn_entropy_map.mean(1)

                logits_bt = logits.transpose(0, 1)
                negentropy_in = self.negentropy_kpi(bit_seq)
                lz_in = self.lz_complexity(bit_seq.float())
                # Logit-based metrics stay attached (detach=False) so they can
                # be used as training signals; telemetry detaches them below.
                negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False)
                lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False)
                kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False)

                raw_sym = (
                    (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2
                    + negentropy_logits_b * lz_logits_b
                    - self.lambda_S * kl_div_b
                    - 0.1 * attn_entropy
                )
                # Small weight-norm penalty keeps the score bounded.
                weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach()
                raw_sym = raw_sym - 0.01 * weight_norm
                sym_score = torch.sigmoid(raw_sym)

                B, T = bit_seq.shape
                assert logits_bt.shape[:2] == (B, T)
                assert attn_entropy_map.shape == (B, T)

                telemetry = {
                    "activations": activations,
                    "attention_maps": attn_maps,
                    "attention_entropy": attn_entropy_map,
                    "entropy": entropies,
                    "attention_entropy_mean": attn_entropy,
                    "negentropy_input": negentropy_in.detach(),
                    "lz_complexity_input": lz_in.detach(),
                    "negentropy_logits": negentropy_logits_b.detach(),
                    "lz_complexity_logits": lz_logits_b.detach(),
                    "symbiosis_kl": kl_div_b.detach(),
                    "symbiosis_score": sym_score.detach(),
                }
                if self.use_act:
                    telemetry["halt_probs"] = halt_history

                return logits_bt, telemetry
        finally:
            # Restore chunking configuration changed for non-causal mode.
            if orig_chunks is not None:
                self.chunk_size = orig_model_chunk
                for layer, chunk in zip(self.layers, orig_chunks):
                    layer.chunk_size = chunk

    def forward_compressed(
        self, compressed_bits, causal: bool = True
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Decompress bit sequences then run the normal forward pass.

        Raises ``ValueError`` if the decompressed rows have differing lengths.
        """
        if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1:
            sequences = [decompress_bits(compressed_bits).to(torch.long)]
        else:
            sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits]
        lengths = [seq.numel() for seq in sequences]
        if len(set(lengths)) != 1:
            raise ValueError("Sequences decompress to different lengths")
        bits = torch.stack(sequences)
        return self.forward(bits, causal=causal)

    def _current_params(self) -> Dict:
        """Return a dictionary with the current model hyperparameters."""
        return {
            "d_model": self.d_model,
            "nhead": self.layers[0].self_attn.num_heads,
            "num_layers": self.num_layers,
            "dim_feedforward": self.layers[0].linear1.out_features,
            "max_seq_len": self.pos_enc.pe.size(0),
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "reversible": self.reversible,
            "use_checkpoint": self.use_checkpoint,
            "use_autocast": self.use_autocast,
            "use_act": self.use_act,
            "act_threshold": self.act_threshold,
            "chunk_size": self.chunk_size,
            "overlap": self.overlap,
        }

    def double_width(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled hidden size."""
        from .scale import expand_model

        params = self._current_params()
        params["d_model"] *= 2
        params["dim_feedforward"] *= 2
        return expand_model(self, params)

    def double_layers(self) -> "BitTransformerLM":
        """Return a copy of the model with twice as many layers."""
        from .scale import expand_model

        params = self._current_params()
        params["num_layers"] *= 2
        return expand_model(self, params)

    def double_length(self) -> "BitTransformerLM":
        """Return a copy of the model with doubled maximum sequence length."""
        from .scale import expand_model

        params = self._current_params()
        params["max_seq_len"] *= 2
        params["chunk_size"] = params["max_seq_len"]
        return expand_model(self, params)

    def train_full_sequence(
        self,
        bits: torch.Tensor,
        *,
        ctx_bits: int = 4096,
        detach_every_n: int = 1_048_576,
    ) -> float:
        """Train on a long bit tensor using sliding windows.

        Parameters
        ----------
        bits: ``torch.Tensor``
            1D tensor containing the full bit sequence.
        ctx_bits: int
            Size of the training context window.
        detach_every_n: int
            Interval in bits for optimizer updates and graph detachment.
        Returns
        -------
        float
            Mean loss over all windows.
        """
        self.train()
        optimizer, scheduler = configure_optimizer(
            self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits)
        )
        accum = 0
        total_loss = 0.0
        count = 0
        for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits):
            # Window of ctx_bits + 1 bits: inputs plus next-bit targets.
            segment = bits[start : start + ctx_bits + 1].unsqueeze(0)
            logits, _ = self(segment)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = segment[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)
            loss.backward()
            accum += ctx_bits
            total_loss += loss.item()
            count += 1
            if accum >= detach_every_n:
                # Gradients accumulate across windows until this threshold.
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accum = 0
        if accum > 0:
            # Flush any remaining accumulated gradients.
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        return total_loss / max(1, count)
735
-
736
-
737
def infer_long_sequence(
    model: "BitTransformerLM",
    bits: torch.Tensor,
    *,
    ctx_bits: int = 4096,
    overlap: int = 256,
) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]:
    """Infer a long bit sequence using sliding windows with overlap.

    Windows of ``ctx_bits`` bits advance by ``ctx_bits - overlap`` positions.
    The first ``overlap`` predictions of every window after the first are
    discarded: those positions were already predicted by the previous window,
    and the overlapping prefix only serves as context.

    Parameters
    ----------
    model: ``BitTransformerLM``
        Model used for prediction; run in causal mode under ``eval()``.
    bits: ``torch.Tensor``
        1D tensor containing the full bit sequence.
    ctx_bits: int
        Window size in bits.
    overlap: int
        Number of bits shared between consecutive windows; must be smaller
        than ``ctx_bits``.

    Returns
    -------
    Tuple of the predicted bit tensor (same length as ``bits``) and the list
    of per-window telemetry dictionaries.

    Raises
    ------
    ValueError
        If ``overlap >= ctx_bits`` (the window could never advance).
    """
    if overlap >= ctx_bits:
        raise ValueError("overlap must be smaller than ctx_bits")
    model.eval()
    device = next(model.parameters()).device
    bits = bits.to(device)
    step = ctx_bits - overlap
    outputs: List[torch.Tensor] = []
    logs: List[Dict[str, torch.Tensor]] = []
    for start in range(0, bits.numel(), step):
        window = bits[start : start + ctx_bits].unsqueeze(0)
        logits, tele = model(window, causal=True)
        pred = logits.argmax(-1).squeeze(0)
        # Bug fix: the original appended every full window, so each overlap
        # region appeared twice in the concatenation and all predictions past
        # the first window were misaligned.  Keep only the new positions.
        outputs.append(pred if start == 0 else pred[overlap:])
        logs.append(tele)
    out = torch.cat(outputs)[: bits.numel()]
    return out, logs
759
-
760
-
761
def diffusion_inference(
    model: BitTransformerLM,
    *,
    length: int,
    steps: int = 8,
    batch_size: int = 1,
    init_bits: torch.Tensor | None = None,
    schedule: str = "linear",
) -> torch.Tensor:
    """Generate bit sequences using iterative denoising diffusion.

    Parameters
    ----------
    model: ``BitTransformerLM``
        The model used for denoising. It is run in non-causal mode with
        chunked attention disabled, enabling full-context bidirectional
        attention.
    length: int
        Length of the bit sequences to generate.
    steps: int, default ``8``
        Number of denoising iterations. More steps generally yield sharper
        samples at the cost of compute.
    batch_size: int, default ``1``
        Number of sequences to generate in parallel.
    init_bits: ``torch.Tensor`` | ``None``
        Optional initial noisy bits of shape ``(batch_size, length)``. When
        ``None`` random noise is used.
    schedule: str, default ``"linear"``
        Noise schedule for the denoising mask probability. Options are
        ``"linear"``, ``"cosine"``, and ``"exp"``.

    Returns
    -------
    ``torch.Tensor``
        A tensor of shape ``(batch_size, length)`` containing generated bits.

    Raises
    ------
    ValueError
        If ``init_bits`` has the wrong shape or ``schedule`` is unknown.
    """

    model.eval()
    device = next(model.parameters()).device
    if init_bits is None:
        bits = torch.randint(0, 2, (batch_size, length), device=device)
    else:
        bits = init_bits.to(device)
        if bits.shape != (batch_size, length):
            raise ValueError("init_bits must have shape (batch_size, length)")

    for step in range(steps):
        logits, _ = model(bits, causal=False)
        # Per-position probability that the bit is 1.
        prob = logits.softmax(-1)[..., 1]
        # Normalized time in (0, 1]; mask_prob shrinks as t grows, so fewer
        # positions are resampled at later steps.
        t = (step + 1) / steps
        if schedule == "linear":
            mask_prob = 1.0 - t
        elif schedule == "cosine":
            mask_prob = math.cos(math.pi * t / 2)
        elif schedule == "exp":
            mask_prob = math.exp(-5 * t)
        else:
            raise ValueError(f"unknown schedule: {schedule}")
        mask = (torch.rand_like(bits.float()) < mask_prob).long()
        sampled = torch.bernoulli(prob).long()
        # Resample only masked positions; the rest keep their current bits.
        bits = torch.where(mask.bool(), sampled, bits)
    # NOTE(review): parity is only enforced when the length is a multiple of
    # the 9-bit parity frame — presumably the frame size used by
    # ``enforce_parity``; verify against bit_transformer.parity.
    if bits.shape[-1] % 9 == 0:
        bits, corrections = enforce_parity(bits)
        if corrections:
            logging.info("Parity corrections applied: %d", corrections)
    try:
        from .safety import hil_safe_inference

        # Best-effort safety gate: a RuntimeError is downgraded to a warning
        # rather than aborting generation.
        hil_safe_inference(model, bits, causal=False, strict=False)
    except RuntimeError as exc:
        logging.warning("Safety gate warning: %s", exc)
    return bits
833
-
834
-
835
def example_usage() -> float:
    """Run the example from the README and return the loss.

    Builds a small model, feeds it one random bit batch and computes the
    next-bit cross-entropy.  No optimizer step is taken.
    """
    batch, seq_len = 4, 16
    model = BitTransformerLM(
        d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=seq_len
    )
    bits = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, _ = model(bits)
    # Shift by one position: the logit at t predicts the bit at t + 1.
    flat_logits = logits[:, :-1, :].reshape(-1, 2)
    flat_targets = bits[:, 1:].reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets).item()
847
-
848
-
849
def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]:
    """Demonstrate a training step where metrics do not affect gradients.

    Performs one optimizer update on a tiny model and returns the scalar
    loss together with the telemetry dictionary from the forward pass
    (telemetry tensors are detached, so they never feed the backward pass).
    """
    batch, seq_len = 4, 16
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=seq_len
    )
    optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1)

    bits = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, telemetry = model(bits)

    # Next-bit objective: the logit at position t predicts the bit at t + 1.
    flat_logits = logits[:, :-1, :].reshape(-1, 2)
    flat_targets = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(flat_logits, flat_targets)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item(), telemetry
870
-
871
-
872
if __name__ == "__main__":
    # Smoke test: run a single training step and show which telemetry
    # metrics the forward pass emits.
    loss, telemetry = example_training_step()
    print("Composite loss:", loss)
    print("Telemetry keys:", list(telemetry.keys()))