🚀 Refined BitTransformerLM: Organized codebase with best practices
Browse files
- bit_transformer/model.py +55 -14
bit_transformer/model.py
CHANGED
|
@@ -1,20 +1,21 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
import contextlib
|
| 3 |
import logging
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.distributed as dist
|
| 8 |
-
import sys
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
import torch.utils.checkpoint as checkpoint
|
| 12 |
|
| 13 |
-
from .torch_utils import cpu_autocast
|
| 14 |
-
|
| 15 |
-
from .optimization import configure_optimizer
|
| 16 |
from .compression import decompress_bits
|
|
|
|
| 17 |
from .parity import enforce_parity
|
|
|
|
| 18 |
|
| 19 |
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
|
| 20 |
_attention_cache: Dict[str, torch.Tensor] = {} # For caching attention patterns
|
|
@@ -29,7 +30,15 @@ def clear_cache():
|
|
| 29 |
|
| 30 |
|
| 31 |
def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
|
| 32 |
-
"""Return or create a cached upper-triangular mask with memory management."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
key = (seq_len, device)
|
| 34 |
|
| 35 |
# Clear cache if it gets too large
|
|
@@ -56,7 +65,12 @@ except Exception: # pragma: no cover - handle missing torch or unsupported vers
|
|
| 56 |
|
| 57 |
|
| 58 |
class PositionalEncoding(nn.Module):
|
| 59 |
-
"""Sinusoidal positional encoding."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def __init__(self, d_model: int, max_len: int = 1024) -> None:
|
| 62 |
super().__init__()
|
|
@@ -70,7 +84,14 @@ class PositionalEncoding(nn.Module):
|
|
| 70 |
self.register_buffer("pe", pe.unsqueeze(1))
|
| 71 |
|
| 72 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
-
"""Add positional encoding to input tensor."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
return x + self.pe[: x.size(0)]
|
| 75 |
|
| 76 |
|
|
@@ -325,7 +346,14 @@ class ReversibleLoggingTransformerEncoderLayer(nn.Module):
|
|
| 325 |
|
| 326 |
|
| 327 |
class BitTransformerLM(nn.Module):
|
| 328 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
def __init__(
|
| 331 |
self,
|
|
@@ -349,10 +377,23 @@ class BitTransformerLM(nn.Module):
|
|
| 349 |
"""Create a BitTransformer language model.
|
| 350 |
|
| 351 |
Args:
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
"""
|
| 357 |
super().__init__()
|
| 358 |
self.d_model = d_model
|
|
|
|
| 1 |
+
"""BitTransformerLM model implementation with reversible layers and telemetry."""
|
| 2 |
+
|
| 3 |
import contextlib
|
| 4 |
import logging
|
| 5 |
+
import math
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.distributed as dist
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import torch.nn.functional as F
|
| 13 |
import torch.utils.checkpoint as checkpoint
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
from .compression import decompress_bits
|
| 16 |
+
from .optimization import configure_optimizer
|
| 17 |
from .parity import enforce_parity
|
| 18 |
+
from .torch_utils import cpu_autocast
|
| 19 |
|
| 20 |
_mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}
|
| 21 |
_attention_cache: Dict[str, torch.Tensor] = {} # For caching attention patterns
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor:
|
| 33 |
+
"""Return or create a cached upper-triangular mask with memory management.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
seq_len: Sequence length for the mask.
|
| 37 |
+
device: PyTorch device for tensor allocation.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Upper-triangular boolean mask tensor.
|
| 41 |
+
"""
|
| 42 |
key = (seq_len, device)
|
| 43 |
|
| 44 |
# Clear cache if it gets too large
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
class PositionalEncoding(nn.Module):
|
| 68 |
+
"""Sinusoidal positional encoding for transformer inputs.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
d_model: Model dimension for embedding.
|
| 72 |
+
max_len: Maximum sequence length to precompute.
|
| 73 |
+
"""
|
| 74 |
|
| 75 |
def __init__(self, d_model: int, max_len: int = 1024) -> None:
|
| 76 |
super().__init__()
|
|
|
|
| 84 |
self.register_buffer("pe", pe.unsqueeze(1))
|
| 85 |
|
| 86 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
"""Add positional encoding to input tensor.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
x: Input tensor of shape (seq_len, batch_size, d_model).
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Input tensor with positional encoding added.
|
| 94 |
+
"""
|
| 95 |
return x + self.pe[: x.size(0)]
|
| 96 |
|
| 97 |
|
|
|
|
| 346 |
|
| 347 |
|
| 348 |
class BitTransformerLM(nn.Module):
|
| 349 |
+
"""Bit-native transformer language model with reversible layers and telemetry.
|
| 350 |
+
|
| 351 |
+
A transformer architecture that processes binary sequences directly with:
|
| 352 |
+
- Reversible layers for memory efficiency
|
| 353 |
+
- Built-in safety telemetry (K/C/S metrics)
|
| 354 |
+
- Chunked attention for long sequences
|
| 355 |
+
- Causal and diffusion training modes
|
| 356 |
+
"""
|
| 357 |
|
| 358 |
def __init__(
|
| 359 |
self,
|
|
|
|
| 377 |
"""Create a BitTransformer language model.
|
| 378 |
|
| 379 |
Args:
|
| 380 |
+
d_model: Model dimension for embeddings and attention.
|
| 381 |
+
nhead: Number of attention heads.
|
| 382 |
+
num_layers: Number of transformer layers.
|
| 383 |
+
dim_feedforward: Dimension of feedforward networks.
|
| 384 |
+
max_seq_len: Maximum sequence length for positional encoding.
|
| 385 |
+
lambda_K: Weight for negentropy metric in telemetry.
|
| 386 |
+
lambda_C: Weight for complexity metric in telemetry.
|
| 387 |
+
lambda_S: Weight for symbiosis metric in telemetry.
|
| 388 |
+
reversible: Enable reversible layers for memory efficiency.
|
| 389 |
+
use_checkpoint: Use gradient checkpointing.
|
| 390 |
+
use_autocast: Use automatic mixed precision.
|
| 391 |
+
use_act: Enable Adaptive Computation Time.
|
| 392 |
+
act_threshold: ACT halting threshold.
|
| 393 |
+
chunk_size: Chunk size for chunked attention (None for full attention).
|
| 394 |
+
overlap: Overlap size for chunked attention.
|
| 395 |
+
full_attn_logging: When False and chunk_size is smaller than sequence
|
| 396 |
+
length, skip reconstructing full attention matrices for telemetry.
|
| 397 |
"""
|
| 398 |
super().__init__()
|
| 399 |
self.d_model = d_model
|