WCNegentropy committed on
Commit
eab0cde
·
verified ·
1 Parent(s): 4fb71c6

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. bit_transformer/model.py +55 -14
bit_transformer/model.py CHANGED
@@ -1,20 +1,21 @@
1
- import math
 
2
  import contextlib
3
  import logging
4
- from typing import Dict, List, Tuple, Optional
 
 
5
 
6
  import torch
7
  import torch.distributed as dist
8
- import sys
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  import torch.utils.checkpoint as checkpoint
12
 
13
- from .torch_utils import cpu_autocast
14
-
15
- from .optimization import configure_optimizer
16
  from .compression import decompress_bits
 
17
  from .parity import enforce_parity
 
18
 
19
  _mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
20
  _attention_cache: Dict[str, torch.Tensor] = {} # For caching attention patterns
@@ -29,7 +30,15 @@ def clear_cache():
29
 
30
 
31
  def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
32
- """Return or create a cached upper-triangular mask with memory management."""
 
 
 
 
 
 
 
 
33
  key = (seq_len, device)
34
 
35
  # Clear cache if it gets too large
@@ -56,7 +65,12 @@ except Exception: # pragma: no cover - handle missing torch or unsupported vers
56
 
57
 
58
  class PositionalEncoding(nn.Module):
59
- """Sinusoidal positional encoding."""
 
 
 
 
 
60
 
61
  def __init__(self, d_model: int, max_len: int = 1024) -> None:
62
  super().__init__()
@@ -70,7 +84,14 @@ class PositionalEncoding(nn.Module):
70
  self.register_buffer("pe", pe.unsqueeze(1))
71
 
72
  def forward(self, x: torch.Tensor) -> torch.Tensor:
73
- """Add positional encoding to input tensor."""
 
 
 
 
 
 
 
74
  return x + self.pe[: x.size(0)]
75
 
76
 
@@ -325,7 +346,14 @@ class ReversibleLoggingTransformerEncoderLayer(nn.Module):
325
 
326
 
327
  class BitTransformerLM(nn.Module):
328
- """Transformer language model that operates on raw bits (0/1) with telemetry."""
 
 
 
 
 
 
 
329
 
330
  def __init__(
331
  self,
@@ -349,10 +377,23 @@ class BitTransformerLM(nn.Module):
349
  """Create a BitTransformer language model.
350
 
351
  Args:
352
- full_attn_logging: When ``False`` and ``chunk_size`` is
353
- smaller than the sequence length, the model skips
354
- reconstructing the full ``T×T`` attention matrices for
355
- telemetry to reduce memory use.
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  """
357
  super().__init__()
358
  self.d_model = d_model
 
1
+ """BitTransformerLM model implementation with reversible layers and telemetry."""
2
+
3
  import contextlib
4
  import logging
5
+ import math
6
+ import sys
7
+ from typing import Dict, List, Optional, Tuple
8
 
9
  import torch
10
  import torch.distributed as dist
 
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  import torch.utils.checkpoint as checkpoint
14
 
 
 
 
15
  from .compression import decompress_bits
16
+ from .optimization import configure_optimizer
17
  from .parity import enforce_parity
18
+ from .torch_utils import cpu_autocast
19
 
20
  _mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
21
  _attention_cache: Dict[str, torch.Tensor] = {} # For caching attention patterns
 
30
 
31
 
32
  def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
33
+ """Return or create a cached upper-triangular mask with memory management.
34
+
35
+ Args:
36
+ seq_len: Sequence length for the mask.
37
+ device: PyTorch device for tensor allocation.
38
+
39
+ Returns:
40
+ Upper-triangular boolean mask tensor.
41
+ """
42
  key = (seq_len, device)
43
 
44
  # Clear cache if it gets too large
 
65
 
66
 
67
  class PositionalEncoding(nn.Module):
68
+ """Sinusoidal positional encoding for transformer inputs.
69
+
70
+ Args:
71
+ d_model: Model dimension for embedding.
72
+ max_len: Maximum sequence length to precompute.
73
+ """
74
 
75
  def __init__(self, d_model: int, max_len: int = 1024) -> None:
76
  super().__init__()
 
84
  self.register_buffer("pe", pe.unsqueeze(1))
85
 
86
  def forward(self, x: torch.Tensor) -> torch.Tensor:
87
+ """Add positional encoding to input tensor.
88
+
89
+ Args:
90
+ x: Input tensor of shape (seq_len, batch_size, d_model).
91
+
92
+ Returns:
93
+ Input tensor with positional encoding added.
94
+ """
95
  return x + self.pe[: x.size(0)]
96
 
97
 
 
346
 
347
 
348
  class BitTransformerLM(nn.Module):
349
+ """Bit-native transformer language model with reversible layers and telemetry.
350
+
351
+ A transformer architecture that processes binary sequences directly with:
352
+ - Reversible layers for memory efficiency
353
+ - Built-in safety telemetry (K/C/S metrics)
354
+ - Chunked attention for long sequences
355
+ - Causal and diffusion training modes
356
+ """
357
 
358
  def __init__(
359
  self,
 
377
  """Create a BitTransformer language model.
378
 
379
  Args:
380
+ d_model: Model dimension for embeddings and attention.
381
+ nhead: Number of attention heads.
382
+ num_layers: Number of transformer layers.
383
+ dim_feedforward: Dimension of feedforward networks.
384
+ max_seq_len: Maximum sequence length for positional encoding.
385
+ lambda_K: Weight for negentropy metric in telemetry.
386
+ lambda_C: Weight for complexity metric in telemetry.
387
+ lambda_S: Weight for symbiosis metric in telemetry.
388
+ reversible: Enable reversible layers for memory efficiency.
389
+ use_checkpoint: Use gradient checkpointing.
390
+ use_autocast: Use automatic mixed precision.
391
+ use_act: Enable Adaptive Computation Time.
392
+ act_threshold: ACT halting threshold.
393
+ chunk_size: Chunk size for chunked attention (None for full attention).
394
+ overlap: Overlap size for chunked attention.
395
+ full_attn_logging: When False and chunk_size is smaller than sequence
396
+ length, skip reconstructing full attention matrices for telemetry.
397
  """
398
  super().__init__()
399
  self.d_model = d_model